pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +83 -62
- pycityagent/agent/agent_base.py +81 -54
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +45 -57
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/message_intercept.py +99 -0
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +325 -127
- pycityagent/cli/wrapper.py +4 -0
- pycityagent/economy/econ_client.py +0 -2
- pycityagent/environment/__init__.py +7 -1
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/sim/sim_env.py +34 -46
- pycityagent/environment/simulator.py +18 -14
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/llm/llm.py +18 -10
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +733 -247
- pycityagent/message/__init__.py +8 -1
- pycityagent/message/message_interceptor.py +322 -0
- pycityagent/message/messager.py +42 -11
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +137 -96
- pycityagent/simulation/simulation.py +184 -38
- pycityagent/simulation/storage/pg.py +2 -2
- pycityagent/tools/tool.py +7 -9
- pycityagent/utils/__init__.py +7 -2
- pycityagent/utils/pg_query.py +1 -0
- pycityagent/utils/survey_util.py +26 -23
- pycityagent/workflow/block.py +14 -7
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
@@ -18,13 +18,15 @@ from ..cityagent import (BankAgent, FirmAgent, GovernmentAgent, NBSAgent,
|
|
18
18
|
memory_config_government, memory_config_nbs,
|
19
19
|
memory_config_societyagent)
|
20
20
|
from ..cityagent.initial import bind_agent_info, initialize_social_network
|
21
|
-
from ..environment
|
22
|
-
from ..llm import SimpleEmbedding
|
21
|
+
from ..environment import Simulator
|
22
|
+
from ..llm import LLM, LLMConfig, SimpleEmbedding
|
23
23
|
from ..memory import Memory
|
24
|
-
from ..message
|
24
|
+
from ..message import (MessageBlockBase, MessageBlockListenerBase,
|
25
|
+
MessageInterceptor, Messager)
|
25
26
|
from ..metrics import init_mlflow_connection
|
27
|
+
from ..metrics.mlflow_client import MlflowClient
|
26
28
|
from ..survey import Survey
|
27
|
-
from ..utils import TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
29
|
+
from ..utils import SURVEY_SENDER_UUID, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
28
30
|
from .agentgroup import AgentGroup
|
29
31
|
from .storage.pg import PgWriter, create_pg_tables
|
30
32
|
|
@@ -39,6 +41,7 @@ class AgentSimulation:
|
|
39
41
|
config: dict,
|
40
42
|
agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
|
41
43
|
agent_config_file: Optional[dict] = None,
|
44
|
+
metric_extractor: Optional[list[tuple[int, Callable]]] = None,
|
42
45
|
enable_economy: bool = True,
|
43
46
|
agent_prefix: str = "agent_",
|
44
47
|
exp_name: str = "default_experiment",
|
@@ -80,6 +83,15 @@ class AgentSimulation:
|
|
80
83
|
self.config = config
|
81
84
|
self.exp_name = exp_name
|
82
85
|
self._simulator = Simulator(config["simulator_request"])
|
86
|
+
if enable_economy:
|
87
|
+
self._economy_env = self._simulator._sim_env
|
88
|
+
_req_dict = self.config["simulator_request"]
|
89
|
+
if "economy" in _req_dict:
|
90
|
+
_req_dict["economy"]["server"] = self._economy_env.sim_addr
|
91
|
+
else:
|
92
|
+
_req_dict["economy"] = {
|
93
|
+
"server": self._economy_env.sim_addr,
|
94
|
+
}
|
83
95
|
self.agent_prefix = agent_prefix
|
84
96
|
self._groups: dict[str, AgentGroup] = {} # type:ignore
|
85
97
|
self._agent_uuid2group: dict[str, AgentGroup] = {} # type:ignore
|
@@ -89,6 +101,8 @@ class AgentSimulation:
|
|
89
101
|
self._user_survey_topics: dict[str, str] = {}
|
90
102
|
self._user_interview_topics: dict[str, str] = {}
|
91
103
|
self._loop = asyncio.get_event_loop()
|
104
|
+
self._total_steps = 0
|
105
|
+
self._simulator_day = 0
|
92
106
|
# self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
93
107
|
|
94
108
|
self._messager = Messager.remote(
|
@@ -102,6 +116,7 @@ class AgentSimulation:
|
|
102
116
|
_storage_config: dict[str, Any] = config.get("storage", {})
|
103
117
|
if _storage_config is None:
|
104
118
|
_storage_config = {}
|
119
|
+
|
105
120
|
# avro
|
106
121
|
_avro_config: dict[str, Any] = _storage_config.get("avro", {})
|
107
122
|
self._enable_avro = _avro_config.get("enabled", False)
|
@@ -112,6 +127,27 @@ class AgentSimulation:
|
|
112
127
|
self._avro_path = Path(_avro_config["path"]) / f"{self.exp_id}"
|
113
128
|
self._avro_path.mkdir(parents=True, exist_ok=True)
|
114
129
|
|
130
|
+
# mlflow
|
131
|
+
_mlflow_config: dict[str, Any] = config.get("metric_request", {}).get("mlflow")
|
132
|
+
mlflow_run_id, _ = init_mlflow_connection(
|
133
|
+
config=_mlflow_config,
|
134
|
+
mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
|
135
|
+
experiment_name=self.exp_name,
|
136
|
+
)
|
137
|
+
if _mlflow_config:
|
138
|
+
logger.info(f"-----Creating Mlflow client...")
|
139
|
+
self.mlflow_client = MlflowClient(
|
140
|
+
config=_mlflow_config,
|
141
|
+
mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
|
142
|
+
experiment_name=exp_name,
|
143
|
+
run_id=mlflow_run_id,
|
144
|
+
)
|
145
|
+
self.metric_extractor = metric_extractor
|
146
|
+
else:
|
147
|
+
logger.warning("Mlflow is not enabled, NO MLFLOW STORAGE")
|
148
|
+
self.mlflow_client = None
|
149
|
+
self.metric_extractor = None
|
150
|
+
|
115
151
|
# pg
|
116
152
|
_pgsql_config: dict[str, Any] = _storage_config.get("pgsql", {})
|
117
153
|
self._enable_pgsql = _pgsql_config.get("enabled", False)
|
@@ -167,11 +203,11 @@ class AgentSimulation:
|
|
167
203
|
- workflow:
|
168
204
|
- list[Step]
|
169
205
|
- Step:
|
170
|
-
- type: str, "step", "run", "interview", "survey", "intervene"
|
206
|
+
- type: str, "step", "run", "interview", "survey", "intervene", "pause", "resume"
|
171
207
|
- day: int if type is "run", else None
|
172
|
-
-
|
208
|
+
- times: int if type is "step", else None
|
173
209
|
- description: Optional[str], description of the step
|
174
|
-
-
|
210
|
+
- func: Optional[Callable[AgentSimulation, None]], only used when type is "interview", "survey" and "intervene"
|
175
211
|
- logging_level: Optional[int]
|
176
212
|
- exp_name: Optional[str]
|
177
213
|
"""
|
@@ -201,6 +237,7 @@ class AgentSimulation:
|
|
201
237
|
agent_count.append(config["agent_config"]["number_of_government"])
|
202
238
|
agent_count.append(config["agent_config"]["number_of_bank"])
|
203
239
|
agent_count.append(config["agent_config"]["number_of_nbs"])
|
240
|
+
# TODO(yanjunbo): support MessageInterceptor
|
204
241
|
await simulation.init_agents(
|
205
242
|
agent_count=agent_count,
|
206
243
|
group_size=config["agent_config"].get("group_size", 10000),
|
@@ -224,10 +261,15 @@ class AgentSimulation:
|
|
224
261
|
if step["type"] == "run":
|
225
262
|
await simulation.run(step.get("day", 1))
|
226
263
|
elif step["type"] == "step":
|
227
|
-
|
228
|
-
|
264
|
+
times = step.get("times", 1)
|
265
|
+
for _ in range(times):
|
266
|
+
await simulation.step()
|
267
|
+
elif step["type"] == "pause":
|
268
|
+
await simulation.pause_simulator()
|
269
|
+
elif step["type"] == "resume":
|
270
|
+
await simulation.resume_simulator()
|
229
271
|
else:
|
230
|
-
await step["
|
272
|
+
await step["func"](simulation)
|
231
273
|
logger.info("Simulation finished")
|
232
274
|
|
233
275
|
@property
|
@@ -261,9 +303,13 @@ class AgentSimulation:
|
|
261
303
|
return self._agent_uuid2group
|
262
304
|
|
263
305
|
@property
|
264
|
-
def messager(self):
|
306
|
+
def messager(self) -> ray.ObjectRef:
|
265
307
|
return self._messager
|
266
308
|
|
309
|
+
@property
|
310
|
+
def message_interceptor(self) -> ray.ObjectRef:
|
311
|
+
return self._message_interceptors[0] # type:ignore
|
312
|
+
|
267
313
|
async def _save_exp_info(self) -> None:
|
268
314
|
"""异步保存实验信息到YAML文件"""
|
269
315
|
try:
|
@@ -331,11 +377,21 @@ class AgentSimulation:
|
|
331
377
|
# 如果没有发生异常且状态不是错误,则更新为完成
|
332
378
|
await self._update_exp_status(2)
|
333
379
|
|
380
|
+
async def pause_simulator(self):
|
381
|
+
await self._simulator.pause()
|
382
|
+
|
383
|
+
async def resume_simulator(self):
|
384
|
+
await self._simulator.resume()
|
385
|
+
|
334
386
|
async def init_agents(
|
335
387
|
self,
|
336
388
|
agent_count: Union[int, list[int]],
|
337
389
|
group_size: int = 10000,
|
338
390
|
pg_sql_writers: int = 32,
|
391
|
+
message_interceptors: int = 1,
|
392
|
+
message_interceptor_blocks: Optional[list[MessageBlockBase]] = None,
|
393
|
+
social_black_list: Optional[list[tuple[str, str]]] = None,
|
394
|
+
message_listener: Optional[MessageBlockListenerBase] = None,
|
339
395
|
embedding_model: Embeddings = SimpleEmbedding(),
|
340
396
|
memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
|
341
397
|
) -> None:
|
@@ -344,6 +400,8 @@ class AgentSimulation:
|
|
344
400
|
Args:
|
345
401
|
agent_count: 要创建的总智能体数量, 如果为列表,则每个元素表示一个智能体类创建的智能体数量
|
346
402
|
group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
|
403
|
+
pg_sql_writers: 独立的PgSQL writer数量
|
404
|
+
message_interceptors: message拦截器数量
|
347
405
|
memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 如果为列表,则每个元素表示一个智能体类创建的Memory配置函数
|
348
406
|
"""
|
349
407
|
if not isinstance(agent_count, list):
|
@@ -353,12 +411,12 @@ class AgentSimulation:
|
|
353
411
|
raise ValueError("agent_class和agent_count的长度不一致")
|
354
412
|
|
355
413
|
if memory_config_func is None:
|
356
|
-
logger.warning(
|
357
|
-
"memory_config_func is None, using default memory config function"
|
358
|
-
)
|
359
414
|
memory_config_func = self.default_memory_config_func
|
360
415
|
|
361
416
|
elif not isinstance(memory_config_func, list):
|
417
|
+
logger.warning(
|
418
|
+
"memory_config_func is not a list, using specific memory config function"
|
419
|
+
)
|
362
420
|
memory_config_func = [memory_config_func]
|
363
421
|
|
364
422
|
if len(memory_config_func) != len(agent_count):
|
@@ -464,13 +522,12 @@ class AgentSimulation:
|
|
464
522
|
config_files,
|
465
523
|
)
|
466
524
|
)
|
467
|
-
|
468
525
|
# 初始化mlflow连接
|
469
526
|
_mlflow_config = self.config.get("metric_request", {}).get("mlflow")
|
470
527
|
if _mlflow_config:
|
471
528
|
mlflow_run_id, _ = init_mlflow_connection(
|
472
529
|
config=_mlflow_config,
|
473
|
-
mlflow_run_name=f"
|
530
|
+
mlflow_run_name=f"{self.exp_name}_{1000*int(time.time())}",
|
474
531
|
experiment_name=self.exp_name,
|
475
532
|
)
|
476
533
|
else:
|
@@ -489,7 +546,31 @@ class AgentSimulation:
|
|
489
546
|
else:
|
490
547
|
_num_workers = 1
|
491
548
|
self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
|
492
|
-
|
549
|
+
# message interceptor
|
550
|
+
if message_listener is not None:
|
551
|
+
self._message_abort_listening_queue = _queue = ray.util.queue.Queue() # type: ignore
|
552
|
+
await message_listener.set_queue(_queue)
|
553
|
+
else:
|
554
|
+
self._message_abort_listening_queue = _queue = None
|
555
|
+
_interceptor_blocks = message_interceptor_blocks
|
556
|
+
_black_list = [] if social_black_list is None else social_black_list
|
557
|
+
_llm_config = self.config.get("llm_request", {})
|
558
|
+
if message_interceptor_blocks is not None:
|
559
|
+
_num_interceptors = min(1, message_interceptors)
|
560
|
+
self._message_interceptors = _interceptors = [
|
561
|
+
MessageInterceptor.remote(
|
562
|
+
_interceptor_blocks, # type:ignore
|
563
|
+
_black_list,
|
564
|
+
_llm_config,
|
565
|
+
_queue,
|
566
|
+
)
|
567
|
+
for _ in range(_num_interceptors)
|
568
|
+
]
|
569
|
+
else:
|
570
|
+
_num_interceptors = 1
|
571
|
+
self._message_interceptors = _interceptors = [
|
572
|
+
None for _ in range(_num_interceptors)
|
573
|
+
]
|
493
574
|
creation_tasks = []
|
494
575
|
for i, (
|
495
576
|
agent_class,
|
@@ -504,13 +585,14 @@ class AgentSimulation:
|
|
504
585
|
number_of_agents,
|
505
586
|
memory_config_function_group,
|
506
587
|
self.config,
|
507
|
-
self.exp_id,
|
508
588
|
self.exp_name,
|
589
|
+
self.exp_id,
|
509
590
|
self.enable_avro,
|
510
591
|
self.avro_path,
|
511
592
|
self.enable_pgsql,
|
512
593
|
_workers[i % _num_workers], # type:ignore
|
513
|
-
|
594
|
+
self.message_interceptor,
|
595
|
+
mlflow_run_id,
|
514
596
|
embedding_model,
|
515
597
|
self.logging_level,
|
516
598
|
config_file,
|
@@ -536,16 +618,24 @@ class AgentSimulation:
|
|
536
618
|
self._type2group[agent_type].append(group)
|
537
619
|
|
538
620
|
# 并行初始化所有组的agents
|
621
|
+
await self.resume_simulator()
|
539
622
|
init_tasks = []
|
540
623
|
for group in self._groups.values():
|
541
624
|
init_tasks.append(group.init_agents.remote())
|
542
625
|
ray.get(init_tasks)
|
626
|
+
await self.messager.connect.remote() # type:ignore
|
627
|
+
await self.messager.subscribe.remote( # type:ignore
|
628
|
+
[(f"exps/{self.exp_id}/user_payback", 1)], [self.exp_id]
|
629
|
+
)
|
630
|
+
await self.messager.start_listening.remote() # type:ignore
|
543
631
|
|
544
|
-
async def gather(
|
632
|
+
async def gather(
|
633
|
+
self, content: str, target_agent_uuids: Optional[list[str]] = None
|
634
|
+
):
|
545
635
|
"""收集智能体的特定信息"""
|
546
636
|
gather_tasks = []
|
547
637
|
for group in self._groups.values():
|
548
|
-
gather_tasks.append(group.gather.remote(content))
|
638
|
+
gather_tasks.append(group.gather.remote(content, target_agent_uuids))
|
549
639
|
return await asyncio.gather(*gather_tasks)
|
550
640
|
|
551
641
|
async def filter(
|
@@ -585,13 +675,13 @@ class AgentSimulation:
|
|
585
675
|
self, survey: Survey, agent_uuids: Optional[list[str]] = None
|
586
676
|
):
|
587
677
|
"""发送问卷"""
|
588
|
-
await self.messager.connect()
|
678
|
+
await self.messager.connect.remote() # type:ignore
|
589
679
|
survey_dict = survey.to_dict()
|
590
680
|
if agent_uuids is None:
|
591
681
|
agent_uuids = self._agent_uuids
|
592
682
|
_date_time = datetime.now(timezone.utc)
|
593
683
|
payload = {
|
594
|
-
"from":
|
684
|
+
"from": SURVEY_SENDER_UUID,
|
595
685
|
"survey_id": survey_dict["id"],
|
596
686
|
"timestamp": int(_date_time.timestamp() * 1000),
|
597
687
|
"data": survey_dict,
|
@@ -599,16 +689,23 @@ class AgentSimulation:
|
|
599
689
|
}
|
600
690
|
for uuid in agent_uuids:
|
601
691
|
topic = self._user_survey_topics[uuid]
|
602
|
-
await self.messager.send_message(topic, payload)
|
692
|
+
await self.messager.send_message.remote(topic, payload) # type:ignore
|
693
|
+
remain_payback = len(agent_uuids)
|
694
|
+
while True:
|
695
|
+
messages = await self.messager.fetch_messages.remote() # type:ignore
|
696
|
+
logger.info(f"Received {len(messages)} payback messages [survey]")
|
697
|
+
remain_payback -= len(messages)
|
698
|
+
if remain_payback <= 0:
|
699
|
+
break
|
700
|
+
await asyncio.sleep(3)
|
603
701
|
|
604
702
|
async def send_interview_message(
|
605
703
|
self, content: str, agent_uuids: Union[str, list[str]]
|
606
704
|
):
|
607
705
|
"""发送采访消息"""
|
608
|
-
await self.messager.connect()
|
609
706
|
_date_time = datetime.now(timezone.utc)
|
610
707
|
payload = {
|
611
|
-
"from":
|
708
|
+
"from": SURVEY_SENDER_UUID,
|
612
709
|
"content": content,
|
613
710
|
"timestamp": int(_date_time.timestamp() * 1000),
|
614
711
|
"_date_time": _date_time,
|
@@ -617,24 +714,72 @@ class AgentSimulation:
|
|
617
714
|
agent_uuids = [agent_uuids]
|
618
715
|
for uuid in agent_uuids:
|
619
716
|
topic = self._user_chat_topics[uuid]
|
620
|
-
await self.messager.send_message(topic, payload)
|
717
|
+
await self.messager.send_message.remote(topic, payload) # type:ignore
|
718
|
+
remain_payback = len(agent_uuids)
|
719
|
+
while True:
|
720
|
+
messages = await self.messager.fetch_messages.remote() # type:ignore
|
721
|
+
logger.info(f"Received {len(messages)} payback messages [interview]")
|
722
|
+
remain_payback -= len(messages)
|
723
|
+
if remain_payback <= 0:
|
724
|
+
break
|
725
|
+
await asyncio.sleep(3)
|
726
|
+
|
727
|
+
async def extract_metric(self, metric_extractors: list[Callable]):
|
728
|
+
"""提取指标"""
|
729
|
+
for metric_extractor in metric_extractors:
|
730
|
+
await metric_extractor(self)
|
621
731
|
|
622
732
|
async def step(self):
|
623
|
-
"""
|
733
|
+
"""Run one step, each agent execute one forward"""
|
624
734
|
try:
|
735
|
+
# check whether insert agents
|
736
|
+
simulator_day = await self._simulator.get_simulator_day()
|
737
|
+
print(
|
738
|
+
f"simulator_day: {simulator_day}, self._simulator_day: {self._simulator_day}"
|
739
|
+
)
|
740
|
+
need_insert_agents = False
|
741
|
+
if simulator_day > self._simulator_day:
|
742
|
+
need_insert_agents = True
|
743
|
+
self._simulator_day = simulator_day
|
744
|
+
if need_insert_agents:
|
745
|
+
await self.resume_simulator()
|
746
|
+
insert_tasks = []
|
747
|
+
for group in self._groups.values():
|
748
|
+
insert_tasks.append(group.insert_agents.remote())
|
749
|
+
await asyncio.gather(*insert_tasks)
|
750
|
+
|
751
|
+
# step
|
625
752
|
tasks = []
|
626
753
|
for group in self._groups.values():
|
627
754
|
tasks.append(group.step.remote())
|
628
755
|
await asyncio.gather(*tasks)
|
756
|
+
# save
|
757
|
+
simulator_day = await self._simulator.get_simulator_day()
|
758
|
+
simulator_time = int(await self._simulator.get_time())
|
759
|
+
save_tasks = []
|
760
|
+
for group in self._groups.values():
|
761
|
+
save_tasks.append(group.save.remote(simulator_day, simulator_time))
|
762
|
+
await asyncio.gather(*save_tasks)
|
763
|
+
self._total_steps += 1
|
764
|
+
if self.metric_extractor is not None:
|
765
|
+
print(f"total_steps: {self._total_steps}, excute metric")
|
766
|
+
to_excute_metric = [
|
767
|
+
metric[1]
|
768
|
+
for metric in self.metric_extractor
|
769
|
+
if self._total_steps % metric[0] == 0
|
770
|
+
]
|
771
|
+
await self.extract_metric(to_excute_metric)
|
629
772
|
except Exception as e:
|
630
|
-
|
631
|
-
|
773
|
+
import traceback
|
774
|
+
|
775
|
+
logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
|
776
|
+
raise RuntimeError(str(e)) from e
|
632
777
|
|
633
778
|
async def run(
|
634
779
|
self,
|
635
780
|
day: int = 1,
|
636
781
|
):
|
637
|
-
"""
|
782
|
+
"""Run the simulation by days"""
|
638
783
|
try:
|
639
784
|
self._exp_info["num_day"] += day
|
640
785
|
await self._update_exp_status(1) # 更新状态为运行中
|
@@ -645,13 +790,14 @@ class AgentSimulation:
|
|
645
790
|
monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
|
646
791
|
|
647
792
|
try:
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
793
|
+
end_time = (
|
794
|
+
await self._simulator.get_time() + day * 24 * 3600
|
795
|
+
) # type:ignore
|
796
|
+
while True:
|
797
|
+
current_time = await self._simulator.get_time()
|
798
|
+
if current_time >= end_time: # type:ignore
|
799
|
+
break
|
800
|
+
await self.step()
|
655
801
|
finally:
|
656
802
|
# 设置停止事件
|
657
803
|
stop_event.set()
|
@@ -71,11 +71,11 @@ class PgWriter:
|
|
71
71
|
|
72
72
|
@lock_decorator
|
73
73
|
async def async_write_status(self, rows: list[tuple]):
|
74
|
-
_tuple_types = [str, int, float, float, float, int, str, str, None]
|
74
|
+
_tuple_types = [str, int, float, float, float, int, list, str, str, None]
|
75
75
|
table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_status"
|
76
76
|
async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
|
77
77
|
copy_sql = psycopg.sql.SQL(
|
78
|
-
"COPY {} (id, day, t, lng, lat, parent_id, action, status, created_at) FROM STDIN"
|
78
|
+
"COPY {} (id, day, t, lng, lat, parent_id, friend_ids, action, status, created_at) FROM STDIN"
|
79
79
|
).format(psycopg.sql.Identifier(table_name))
|
80
80
|
_rows: list[Any] = []
|
81
81
|
async with aconn.cursor() as cur:
|
pycityagent/tools/tool.py
CHANGED
@@ -119,7 +119,7 @@ class SencePOI(Tool):
|
|
119
119
|
if agent.memory is None or agent.simulator is None:
|
120
120
|
raise ValueError("Memory or Simulator is not set.")
|
121
121
|
if radius is None and category_prefix is None:
|
122
|
-
position = await agent.
|
122
|
+
position = await agent.status.get("position")
|
123
123
|
resp = []
|
124
124
|
for prefix in self.category_prefix:
|
125
125
|
resp += agent.simulator.map.query_pois(
|
@@ -146,17 +146,15 @@ class UpdateWithSimulator(Tool):
|
|
146
146
|
agent = self.agent
|
147
147
|
if agent._simulator is None:
|
148
148
|
return
|
149
|
-
if not agent._has_bound_to_simulator:
|
150
|
-
await agent._bind_to_simulator() # type: ignore
|
151
149
|
simulator = agent.simulator
|
152
|
-
|
153
|
-
person_id = await
|
150
|
+
status = agent.status
|
151
|
+
person_id = await status.get("id")
|
154
152
|
resp = await simulator.get_person(person_id)
|
155
153
|
resp_dict = resp["person"]
|
156
154
|
for k, v in resp_dict.get("motion", {}).items():
|
157
155
|
try:
|
158
|
-
await
|
159
|
-
await
|
156
|
+
await status.get(k)
|
157
|
+
await status.update(
|
160
158
|
k, v, mode="replace", protect_llm_read_only_fields=False
|
161
159
|
)
|
162
160
|
except KeyError as e:
|
@@ -183,9 +181,9 @@ class ResetAgentPosition(Tool):
|
|
183
181
|
s: Optional[float] = None,
|
184
182
|
):
|
185
183
|
agent = self.agent
|
186
|
-
|
184
|
+
status = agent.status
|
187
185
|
await agent.simulator.reset_person_position(
|
188
|
-
person_id=await
|
186
|
+
person_id=await status.get("id"),
|
189
187
|
aoi_id=aoi_id,
|
190
188
|
poi_id=poi_id,
|
191
189
|
lane_id=lane_id,
|
pycityagent/utils/__init__.py
CHANGED
@@ -1,11 +1,16 @@
|
|
1
1
|
from .avro_schema import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA,
|
2
2
|
PROFILE_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA)
|
3
3
|
from .pg_query import PGSQL_DICT, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
4
|
-
from .survey_util import process_survey_for_llm
|
4
|
+
from .survey_util import SURVEY_SENDER_UUID, process_survey_for_llm
|
5
5
|
|
6
6
|
__all__ = [
|
7
|
-
"PROFILE_SCHEMA",
|
7
|
+
"PROFILE_SCHEMA",
|
8
|
+
"DIALOG_SCHEMA",
|
9
|
+
"STATUS_SCHEMA",
|
10
|
+
"SURVEY_SCHEMA",
|
11
|
+
"INSTITUTION_STATUS_SCHEMA",
|
8
12
|
"process_survey_for_llm",
|
9
13
|
"TO_UPDATE_EXP_INFO_KEYS_AND_TYPES",
|
10
14
|
"PGSQL_DICT",
|
15
|
+
"SURVEY_SENDER_UUID",
|
11
16
|
]
|
pycityagent/utils/pg_query.py
CHANGED
pycityagent/utils/survey_util.py
CHANGED
@@ -8,40 +8,40 @@ Survey Description: {survey_dict['description']}
|
|
8
8
|
Please answer each question in the following format:
|
9
9
|
|
10
10
|
"""
|
11
|
-
|
11
|
+
|
12
12
|
question_count = 1
|
13
|
-
for page in survey_dict[
|
14
|
-
for question in page[
|
13
|
+
for page in survey_dict["pages"]:
|
14
|
+
for question in page["elements"]:
|
15
15
|
prompt += f"Question {question_count}: {question['title']}\n"
|
16
|
-
|
16
|
+
|
17
17
|
# 根据不同类型的问题生成不同的提示
|
18
|
-
if question[
|
19
|
-
prompt += "Options: " + ", ".join(question[
|
18
|
+
if question["type"] == "radiogroup":
|
19
|
+
prompt += "Options: " + ", ".join(question["choices"]) + "\n"
|
20
20
|
prompt += "Please select ONE option\n"
|
21
|
-
|
22
|
-
elif question[
|
23
|
-
prompt += "Options: " + ", ".join(question[
|
21
|
+
|
22
|
+
elif question["type"] == "checkbox":
|
23
|
+
prompt += "Options: " + ", ".join(question["choices"]) + "\n"
|
24
24
|
prompt += "You can select MULTIPLE options\n"
|
25
|
-
|
26
|
-
elif question[
|
25
|
+
|
26
|
+
elif question["type"] == "rating":
|
27
27
|
prompt += f"Rating range: {question.get('min_rating', 1)} - {question.get('max_rating', 5)}\n"
|
28
28
|
prompt += "Please provide a rating within the range\n"
|
29
|
-
|
30
|
-
elif question[
|
31
|
-
prompt += "Rows: " + ", ".join(question[
|
32
|
-
prompt += "Columns: " + ", ".join(question[
|
29
|
+
|
30
|
+
elif question["type"] == "matrix":
|
31
|
+
prompt += "Rows: " + ", ".join(question["rows"]) + "\n"
|
32
|
+
prompt += "Columns: " + ", ".join(question["columns"]) + "\n"
|
33
33
|
prompt += "Please select ONE column option for EACH row\n"
|
34
|
-
|
35
|
-
elif question[
|
34
|
+
|
35
|
+
elif question["type"] == "text":
|
36
36
|
prompt += "Please provide a text response\n"
|
37
|
-
|
38
|
-
elif question[
|
37
|
+
|
38
|
+
elif question["type"] == "boolean":
|
39
39
|
prompt += "Options: Yes, No\n"
|
40
40
|
prompt += "Please select either Yes or No\n"
|
41
|
-
|
41
|
+
|
42
42
|
prompt += "\nAnswer: [Your response here]\n\n---\n\n"
|
43
43
|
question_count += 1
|
44
|
-
|
44
|
+
|
45
45
|
# 添加总结提示
|
46
46
|
prompt += """Please ensure:
|
47
47
|
1. All required questions are answered
|
@@ -49,5 +49,8 @@ Please answer each question in the following format:
|
|
49
49
|
3. Answers are clear and specific
|
50
50
|
|
51
51
|
Format your responses exactly as requested above."""
|
52
|
-
|
53
|
-
return prompt
|
52
|
+
|
53
|
+
return prompt
|
54
|
+
|
55
|
+
|
56
|
+
SURVEY_SENDER_UUID = "none"
|
pycityagent/workflow/block.py
CHANGED
@@ -146,6 +146,7 @@ def trigger_class():
|
|
146
146
|
class Block:
|
147
147
|
configurable_fields: list[str] = []
|
148
148
|
default_values: dict[str, Any] = {}
|
149
|
+
fields_description: dict[str, str] = {}
|
149
150
|
|
150
151
|
def __init__(
|
151
152
|
self,
|
@@ -173,14 +174,20 @@ class Block:
|
|
173
174
|
|
174
175
|
@classmethod
|
175
176
|
def export_class_config(cls) -> dict[str, str]:
|
176
|
-
return
|
177
|
-
|
178
|
-
|
179
|
-
|
177
|
+
return (
|
178
|
+
{
|
179
|
+
field: cls.default_values.get(field, "default_value")
|
180
|
+
for field in cls.configurable_fields
|
181
|
+
},
|
182
|
+
{
|
183
|
+
field: cls.fields_description.get(field, "")
|
184
|
+
for field in cls.configurable_fields
|
185
|
+
}
|
186
|
+
)
|
180
187
|
|
181
188
|
@classmethod
|
182
189
|
def import_config(cls, config: dict[str, Union[str, dict]]) -> Block:
|
183
|
-
instance = cls(name=config["name"])
|
190
|
+
instance = cls(name=config["name"]) # type: ignore
|
184
191
|
assert isinstance(config["config"], dict)
|
185
192
|
for field, value in config["config"].items():
|
186
193
|
if field in cls.configurable_fields:
|
@@ -188,12 +195,12 @@ class Block:
|
|
188
195
|
|
189
196
|
# 递归创建子Block
|
190
197
|
for child_config in config.get("children", []):
|
191
|
-
child_block = Block.import_config(child_config)
|
198
|
+
child_block = Block.import_config(child_config) # type: ignore
|
192
199
|
setattr(instance, child_block.name.lower(), child_block)
|
193
200
|
|
194
201
|
return instance
|
195
202
|
|
196
|
-
def load_from_config(self, config: dict[str, list[
|
203
|
+
def load_from_config(self, config: dict[str, list[dict]]) -> None:
|
197
204
|
"""
|
198
205
|
使用配置更新当前Block实例的参数,并递归更新子Block。
|
199
206
|
"""
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: pycityagent
|
3
|
-
Version: 2.0.
|
3
|
+
Version: 2.0.0a54
|
4
4
|
Summary: LLM-based city environment agent building library
|
5
5
|
Author-email: Yuwei Yan <pinkgranite86@gmail.com>, Junbo Yan <yanjb20thu@gmali.com>, Jun Zhang <zhangjun990222@gmali.com>
|
6
6
|
License: MIT License
|
@@ -45,7 +45,7 @@ Requires-Dist: openai>=1.58.1
|
|
45
45
|
Requires-Dist: Pillow<12.0.0,>=11.0.0
|
46
46
|
Requires-Dist: protobuf<5.0.0,<=4.24.0
|
47
47
|
Requires-Dist: pycitydata>=1.0.3
|
48
|
-
Requires-Dist: pycityproto>=2.
|
48
|
+
Requires-Dist: pycityproto>=2.2.0
|
49
49
|
Requires-Dist: requests>=2.32.3
|
50
50
|
Requires-Dist: Shapely>=2.0.6
|
51
51
|
Requires-Dist: PyYAML>=6.0.2
|