pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/agent/agent.py +83 -62
- pycityagent/agent/agent_base.py +81 -54
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +45 -57
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/message_intercept.py +99 -0
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +325 -127
- pycityagent/cli/wrapper.py +4 -0
- pycityagent/economy/econ_client.py +0 -2
- pycityagent/environment/__init__.py +7 -1
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/sim/sim_env.py +34 -46
- pycityagent/environment/simulator.py +18 -14
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/llm/llm.py +18 -10
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +733 -247
- pycityagent/message/__init__.py +8 -1
- pycityagent/message/message_interceptor.py +322 -0
- pycityagent/message/messager.py +42 -11
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +137 -96
- pycityagent/simulation/simulation.py +184 -38
- pycityagent/simulation/storage/pg.py +2 -2
- pycityagent/tools/tool.py +7 -9
- pycityagent/utils/__init__.py +7 -2
- pycityagent/utils/pg_query.py +1 -0
- pycityagent/utils/survey_util.py +26 -23
- pycityagent/workflow/block.py +14 -7
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
@@ -18,13 +18,15 @@ from ..cityagent import (BankAgent, FirmAgent, GovernmentAgent, NBSAgent,
|
|
18
18
|
memory_config_government, memory_config_nbs,
|
19
19
|
memory_config_societyagent)
|
20
20
|
from ..cityagent.initial import bind_agent_info, initialize_social_network
|
21
|
-
from ..environment
|
22
|
-
from ..llm import SimpleEmbedding
|
21
|
+
from ..environment import Simulator
|
22
|
+
from ..llm import LLM, LLMConfig, SimpleEmbedding
|
23
23
|
from ..memory import Memory
|
24
|
-
from ..message
|
24
|
+
from ..message import (MessageBlockBase, MessageBlockListenerBase,
|
25
|
+
MessageInterceptor, Messager)
|
25
26
|
from ..metrics import init_mlflow_connection
|
27
|
+
from ..metrics.mlflow_client import MlflowClient
|
26
28
|
from ..survey import Survey
|
27
|
-
from ..utils import TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
29
|
+
from ..utils import SURVEY_SENDER_UUID, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
28
30
|
from .agentgroup import AgentGroup
|
29
31
|
from .storage.pg import PgWriter, create_pg_tables
|
30
32
|
|
@@ -39,6 +41,7 @@ class AgentSimulation:
|
|
39
41
|
config: dict,
|
40
42
|
agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
|
41
43
|
agent_config_file: Optional[dict] = None,
|
44
|
+
metric_extractor: Optional[list[tuple[int, Callable]]] = None,
|
42
45
|
enable_economy: bool = True,
|
43
46
|
agent_prefix: str = "agent_",
|
44
47
|
exp_name: str = "default_experiment",
|
@@ -80,6 +83,15 @@ class AgentSimulation:
|
|
80
83
|
self.config = config
|
81
84
|
self.exp_name = exp_name
|
82
85
|
self._simulator = Simulator(config["simulator_request"])
|
86
|
+
if enable_economy:
|
87
|
+
self._economy_env = self._simulator._sim_env
|
88
|
+
_req_dict = self.config["simulator_request"]
|
89
|
+
if "economy" in _req_dict:
|
90
|
+
_req_dict["economy"]["server"] = self._economy_env.sim_addr
|
91
|
+
else:
|
92
|
+
_req_dict["economy"] = {
|
93
|
+
"server": self._economy_env.sim_addr,
|
94
|
+
}
|
83
95
|
self.agent_prefix = agent_prefix
|
84
96
|
self._groups: dict[str, AgentGroup] = {} # type:ignore
|
85
97
|
self._agent_uuid2group: dict[str, AgentGroup] = {} # type:ignore
|
@@ -89,6 +101,8 @@ class AgentSimulation:
|
|
89
101
|
self._user_survey_topics: dict[str, str] = {}
|
90
102
|
self._user_interview_topics: dict[str, str] = {}
|
91
103
|
self._loop = asyncio.get_event_loop()
|
104
|
+
self._total_steps = 0
|
105
|
+
self._simulator_day = 0
|
92
106
|
# self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
93
107
|
|
94
108
|
self._messager = Messager.remote(
|
@@ -102,6 +116,7 @@ class AgentSimulation:
|
|
102
116
|
_storage_config: dict[str, Any] = config.get("storage", {})
|
103
117
|
if _storage_config is None:
|
104
118
|
_storage_config = {}
|
119
|
+
|
105
120
|
# avro
|
106
121
|
_avro_config: dict[str, Any] = _storage_config.get("avro", {})
|
107
122
|
self._enable_avro = _avro_config.get("enabled", False)
|
@@ -112,6 +127,27 @@ class AgentSimulation:
|
|
112
127
|
self._avro_path = Path(_avro_config["path"]) / f"{self.exp_id}"
|
113
128
|
self._avro_path.mkdir(parents=True, exist_ok=True)
|
114
129
|
|
130
|
+
# mlflow
|
131
|
+
_mlflow_config: dict[str, Any] = config.get("metric_request", {}).get("mlflow")
|
132
|
+
mlflow_run_id, _ = init_mlflow_connection(
|
133
|
+
config=_mlflow_config,
|
134
|
+
mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
|
135
|
+
experiment_name=self.exp_name,
|
136
|
+
)
|
137
|
+
if _mlflow_config:
|
138
|
+
logger.info(f"-----Creating Mlflow client...")
|
139
|
+
self.mlflow_client = MlflowClient(
|
140
|
+
config=_mlflow_config,
|
141
|
+
mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
|
142
|
+
experiment_name=exp_name,
|
143
|
+
run_id=mlflow_run_id,
|
144
|
+
)
|
145
|
+
self.metric_extractor = metric_extractor
|
146
|
+
else:
|
147
|
+
logger.warning("Mlflow is not enabled, NO MLFLOW STORAGE")
|
148
|
+
self.mlflow_client = None
|
149
|
+
self.metric_extractor = None
|
150
|
+
|
115
151
|
# pg
|
116
152
|
_pgsql_config: dict[str, Any] = _storage_config.get("pgsql", {})
|
117
153
|
self._enable_pgsql = _pgsql_config.get("enabled", False)
|
@@ -167,11 +203,11 @@ class AgentSimulation:
|
|
167
203
|
- workflow:
|
168
204
|
- list[Step]
|
169
205
|
- Step:
|
170
|
-
- type: str, "step", "run", "interview", "survey", "intervene"
|
206
|
+
- type: str, "step", "run", "interview", "survey", "intervene", "pause", "resume"
|
171
207
|
- day: int if type is "run", else None
|
172
|
-
-
|
208
|
+
- times: int if type is "step", else None
|
173
209
|
- description: Optional[str], description of the step
|
174
|
-
-
|
210
|
+
- func: Optional[Callable[AgentSimulation, None]], only used when type is "interview", "survey" and "intervene"
|
175
211
|
- logging_level: Optional[int]
|
176
212
|
- exp_name: Optional[str]
|
177
213
|
"""
|
@@ -201,6 +237,7 @@ class AgentSimulation:
|
|
201
237
|
agent_count.append(config["agent_config"]["number_of_government"])
|
202
238
|
agent_count.append(config["agent_config"]["number_of_bank"])
|
203
239
|
agent_count.append(config["agent_config"]["number_of_nbs"])
|
240
|
+
# TODO(yanjunbo): support MessageInterceptor
|
204
241
|
await simulation.init_agents(
|
205
242
|
agent_count=agent_count,
|
206
243
|
group_size=config["agent_config"].get("group_size", 10000),
|
@@ -224,10 +261,15 @@ class AgentSimulation:
|
|
224
261
|
if step["type"] == "run":
|
225
262
|
await simulation.run(step.get("day", 1))
|
226
263
|
elif step["type"] == "step":
|
227
|
-
|
228
|
-
|
264
|
+
times = step.get("times", 1)
|
265
|
+
for _ in range(times):
|
266
|
+
await simulation.step()
|
267
|
+
elif step["type"] == "pause":
|
268
|
+
await simulation.pause_simulator()
|
269
|
+
elif step["type"] == "resume":
|
270
|
+
await simulation.resume_simulator()
|
229
271
|
else:
|
230
|
-
await step["
|
272
|
+
await step["func"](simulation)
|
231
273
|
logger.info("Simulation finished")
|
232
274
|
|
233
275
|
@property
|
@@ -261,9 +303,13 @@ class AgentSimulation:
|
|
261
303
|
return self._agent_uuid2group
|
262
304
|
|
263
305
|
@property
|
264
|
-
def messager(self):
|
306
|
+
def messager(self) -> ray.ObjectRef:
|
265
307
|
return self._messager
|
266
308
|
|
309
|
+
@property
|
310
|
+
def message_interceptor(self) -> ray.ObjectRef:
|
311
|
+
return self._message_interceptors[0] # type:ignore
|
312
|
+
|
267
313
|
async def _save_exp_info(self) -> None:
|
268
314
|
"""异步保存实验信息到YAML文件"""
|
269
315
|
try:
|
@@ -331,11 +377,21 @@ class AgentSimulation:
|
|
331
377
|
# 如果没有发生异常且状态不是错误,则更新为完成
|
332
378
|
await self._update_exp_status(2)
|
333
379
|
|
380
|
+
async def pause_simulator(self):
|
381
|
+
await self._simulator.pause()
|
382
|
+
|
383
|
+
async def resume_simulator(self):
|
384
|
+
await self._simulator.resume()
|
385
|
+
|
334
386
|
async def init_agents(
|
335
387
|
self,
|
336
388
|
agent_count: Union[int, list[int]],
|
337
389
|
group_size: int = 10000,
|
338
390
|
pg_sql_writers: int = 32,
|
391
|
+
message_interceptors: int = 1,
|
392
|
+
message_interceptor_blocks: Optional[list[MessageBlockBase]] = None,
|
393
|
+
social_black_list: Optional[list[tuple[str, str]]] = None,
|
394
|
+
message_listener: Optional[MessageBlockListenerBase] = None,
|
339
395
|
embedding_model: Embeddings = SimpleEmbedding(),
|
340
396
|
memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
|
341
397
|
) -> None:
|
@@ -344,6 +400,8 @@ class AgentSimulation:
|
|
344
400
|
Args:
|
345
401
|
agent_count: 要创建的总智能体数量, 如果为列表,则每个元素表示一个智能体类创建的智能体数量
|
346
402
|
group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
|
403
|
+
pg_sql_writers: 独立的PgSQL writer数量
|
404
|
+
message_interceptors: message拦截器数量
|
347
405
|
memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 如果为列表,则每个元素表示一个智能体类创建的Memory配置函数
|
348
406
|
"""
|
349
407
|
if not isinstance(agent_count, list):
|
@@ -353,12 +411,12 @@ class AgentSimulation:
|
|
353
411
|
raise ValueError("agent_class和agent_count的长度不一致")
|
354
412
|
|
355
413
|
if memory_config_func is None:
|
356
|
-
logger.warning(
|
357
|
-
"memory_config_func is None, using default memory config function"
|
358
|
-
)
|
359
414
|
memory_config_func = self.default_memory_config_func
|
360
415
|
|
361
416
|
elif not isinstance(memory_config_func, list):
|
417
|
+
logger.warning(
|
418
|
+
"memory_config_func is not a list, using specific memory config function"
|
419
|
+
)
|
362
420
|
memory_config_func = [memory_config_func]
|
363
421
|
|
364
422
|
if len(memory_config_func) != len(agent_count):
|
@@ -464,13 +522,12 @@ class AgentSimulation:
|
|
464
522
|
config_files,
|
465
523
|
)
|
466
524
|
)
|
467
|
-
|
468
525
|
# 初始化mlflow连接
|
469
526
|
_mlflow_config = self.config.get("metric_request", {}).get("mlflow")
|
470
527
|
if _mlflow_config:
|
471
528
|
mlflow_run_id, _ = init_mlflow_connection(
|
472
529
|
config=_mlflow_config,
|
473
|
-
mlflow_run_name=f"
|
530
|
+
mlflow_run_name=f"{self.exp_name}_{1000*int(time.time())}",
|
474
531
|
experiment_name=self.exp_name,
|
475
532
|
)
|
476
533
|
else:
|
@@ -489,7 +546,31 @@ class AgentSimulation:
|
|
489
546
|
else:
|
490
547
|
_num_workers = 1
|
491
548
|
self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
|
492
|
-
|
549
|
+
# message interceptor
|
550
|
+
if message_listener is not None:
|
551
|
+
self._message_abort_listening_queue = _queue = ray.util.queue.Queue() # type: ignore
|
552
|
+
await message_listener.set_queue(_queue)
|
553
|
+
else:
|
554
|
+
self._message_abort_listening_queue = _queue = None
|
555
|
+
_interceptor_blocks = message_interceptor_blocks
|
556
|
+
_black_list = [] if social_black_list is None else social_black_list
|
557
|
+
_llm_config = self.config.get("llm_request", {})
|
558
|
+
if message_interceptor_blocks is not None:
|
559
|
+
_num_interceptors = min(1, message_interceptors)
|
560
|
+
self._message_interceptors = _interceptors = [
|
561
|
+
MessageInterceptor.remote(
|
562
|
+
_interceptor_blocks, # type:ignore
|
563
|
+
_black_list,
|
564
|
+
_llm_config,
|
565
|
+
_queue,
|
566
|
+
)
|
567
|
+
for _ in range(_num_interceptors)
|
568
|
+
]
|
569
|
+
else:
|
570
|
+
_num_interceptors = 1
|
571
|
+
self._message_interceptors = _interceptors = [
|
572
|
+
None for _ in range(_num_interceptors)
|
573
|
+
]
|
493
574
|
creation_tasks = []
|
494
575
|
for i, (
|
495
576
|
agent_class,
|
@@ -504,13 +585,14 @@ class AgentSimulation:
|
|
504
585
|
number_of_agents,
|
505
586
|
memory_config_function_group,
|
506
587
|
self.config,
|
507
|
-
self.exp_id,
|
508
588
|
self.exp_name,
|
589
|
+
self.exp_id,
|
509
590
|
self.enable_avro,
|
510
591
|
self.avro_path,
|
511
592
|
self.enable_pgsql,
|
512
593
|
_workers[i % _num_workers], # type:ignore
|
513
|
-
|
594
|
+
self.message_interceptor,
|
595
|
+
mlflow_run_id,
|
514
596
|
embedding_model,
|
515
597
|
self.logging_level,
|
516
598
|
config_file,
|
@@ -536,16 +618,24 @@ class AgentSimulation:
|
|
536
618
|
self._type2group[agent_type].append(group)
|
537
619
|
|
538
620
|
# 并行初始化所有组的agents
|
621
|
+
await self.resume_simulator()
|
539
622
|
init_tasks = []
|
540
623
|
for group in self._groups.values():
|
541
624
|
init_tasks.append(group.init_agents.remote())
|
542
625
|
ray.get(init_tasks)
|
626
|
+
await self.messager.connect.remote() # type:ignore
|
627
|
+
await self.messager.subscribe.remote( # type:ignore
|
628
|
+
[(f"exps/{self.exp_id}/user_payback", 1)], [self.exp_id]
|
629
|
+
)
|
630
|
+
await self.messager.start_listening.remote() # type:ignore
|
543
631
|
|
544
|
-
async def gather(
|
632
|
+
async def gather(
|
633
|
+
self, content: str, target_agent_uuids: Optional[list[str]] = None
|
634
|
+
):
|
545
635
|
"""收集智能体的特定信息"""
|
546
636
|
gather_tasks = []
|
547
637
|
for group in self._groups.values():
|
548
|
-
gather_tasks.append(group.gather.remote(content))
|
638
|
+
gather_tasks.append(group.gather.remote(content, target_agent_uuids))
|
549
639
|
return await asyncio.gather(*gather_tasks)
|
550
640
|
|
551
641
|
async def filter(
|
@@ -585,13 +675,13 @@ class AgentSimulation:
|
|
585
675
|
self, survey: Survey, agent_uuids: Optional[list[str]] = None
|
586
676
|
):
|
587
677
|
"""发送问卷"""
|
588
|
-
await self.messager.connect()
|
678
|
+
await self.messager.connect.remote() # type:ignore
|
589
679
|
survey_dict = survey.to_dict()
|
590
680
|
if agent_uuids is None:
|
591
681
|
agent_uuids = self._agent_uuids
|
592
682
|
_date_time = datetime.now(timezone.utc)
|
593
683
|
payload = {
|
594
|
-
"from":
|
684
|
+
"from": SURVEY_SENDER_UUID,
|
595
685
|
"survey_id": survey_dict["id"],
|
596
686
|
"timestamp": int(_date_time.timestamp() * 1000),
|
597
687
|
"data": survey_dict,
|
@@ -599,16 +689,23 @@ class AgentSimulation:
|
|
599
689
|
}
|
600
690
|
for uuid in agent_uuids:
|
601
691
|
topic = self._user_survey_topics[uuid]
|
602
|
-
await self.messager.send_message(topic, payload)
|
692
|
+
await self.messager.send_message.remote(topic, payload) # type:ignore
|
693
|
+
remain_payback = len(agent_uuids)
|
694
|
+
while True:
|
695
|
+
messages = await self.messager.fetch_messages.remote() # type:ignore
|
696
|
+
logger.info(f"Received {len(messages)} payback messages [survey]")
|
697
|
+
remain_payback -= len(messages)
|
698
|
+
if remain_payback <= 0:
|
699
|
+
break
|
700
|
+
await asyncio.sleep(3)
|
603
701
|
|
604
702
|
async def send_interview_message(
|
605
703
|
self, content: str, agent_uuids: Union[str, list[str]]
|
606
704
|
):
|
607
705
|
"""发送采访消息"""
|
608
|
-
await self.messager.connect()
|
609
706
|
_date_time = datetime.now(timezone.utc)
|
610
707
|
payload = {
|
611
|
-
"from":
|
708
|
+
"from": SURVEY_SENDER_UUID,
|
612
709
|
"content": content,
|
613
710
|
"timestamp": int(_date_time.timestamp() * 1000),
|
614
711
|
"_date_time": _date_time,
|
@@ -617,24 +714,72 @@ class AgentSimulation:
|
|
617
714
|
agent_uuids = [agent_uuids]
|
618
715
|
for uuid in agent_uuids:
|
619
716
|
topic = self._user_chat_topics[uuid]
|
620
|
-
await self.messager.send_message(topic, payload)
|
717
|
+
await self.messager.send_message.remote(topic, payload) # type:ignore
|
718
|
+
remain_payback = len(agent_uuids)
|
719
|
+
while True:
|
720
|
+
messages = await self.messager.fetch_messages.remote() # type:ignore
|
721
|
+
logger.info(f"Received {len(messages)} payback messages [interview]")
|
722
|
+
remain_payback -= len(messages)
|
723
|
+
if remain_payback <= 0:
|
724
|
+
break
|
725
|
+
await asyncio.sleep(3)
|
726
|
+
|
727
|
+
async def extract_metric(self, metric_extractors: list[Callable]):
|
728
|
+
"""提取指标"""
|
729
|
+
for metric_extractor in metric_extractors:
|
730
|
+
await metric_extractor(self)
|
621
731
|
|
622
732
|
async def step(self):
|
623
|
-
"""
|
733
|
+
"""Run one step, each agent execute one forward"""
|
624
734
|
try:
|
735
|
+
# check whether insert agents
|
736
|
+
simulator_day = await self._simulator.get_simulator_day()
|
737
|
+
print(
|
738
|
+
f"simulator_day: {simulator_day}, self._simulator_day: {self._simulator_day}"
|
739
|
+
)
|
740
|
+
need_insert_agents = False
|
741
|
+
if simulator_day > self._simulator_day:
|
742
|
+
need_insert_agents = True
|
743
|
+
self._simulator_day = simulator_day
|
744
|
+
if need_insert_agents:
|
745
|
+
await self.resume_simulator()
|
746
|
+
insert_tasks = []
|
747
|
+
for group in self._groups.values():
|
748
|
+
insert_tasks.append(group.insert_agents.remote())
|
749
|
+
await asyncio.gather(*insert_tasks)
|
750
|
+
|
751
|
+
# step
|
625
752
|
tasks = []
|
626
753
|
for group in self._groups.values():
|
627
754
|
tasks.append(group.step.remote())
|
628
755
|
await asyncio.gather(*tasks)
|
756
|
+
# save
|
757
|
+
simulator_day = await self._simulator.get_simulator_day()
|
758
|
+
simulator_time = int(await self._simulator.get_time())
|
759
|
+
save_tasks = []
|
760
|
+
for group in self._groups.values():
|
761
|
+
save_tasks.append(group.save.remote(simulator_day, simulator_time))
|
762
|
+
await asyncio.gather(*save_tasks)
|
763
|
+
self._total_steps += 1
|
764
|
+
if self.metric_extractor is not None:
|
765
|
+
print(f"total_steps: {self._total_steps}, excute metric")
|
766
|
+
to_excute_metric = [
|
767
|
+
metric[1]
|
768
|
+
for metric in self.metric_extractor
|
769
|
+
if self._total_steps % metric[0] == 0
|
770
|
+
]
|
771
|
+
await self.extract_metric(to_excute_metric)
|
629
772
|
except Exception as e:
|
630
|
-
|
631
|
-
|
773
|
+
import traceback
|
774
|
+
|
775
|
+
logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
|
776
|
+
raise RuntimeError(str(e)) from e
|
632
777
|
|
633
778
|
async def run(
|
634
779
|
self,
|
635
780
|
day: int = 1,
|
636
781
|
):
|
637
|
-
"""
|
782
|
+
"""Run the simulation by days"""
|
638
783
|
try:
|
639
784
|
self._exp_info["num_day"] += day
|
640
785
|
await self._update_exp_status(1) # 更新状态为运行中
|
@@ -645,13 +790,14 @@ class AgentSimulation:
|
|
645
790
|
monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
|
646
791
|
|
647
792
|
try:
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
793
|
+
end_time = (
|
794
|
+
await self._simulator.get_time() + day * 24 * 3600
|
795
|
+
) # type:ignore
|
796
|
+
while True:
|
797
|
+
current_time = await self._simulator.get_time()
|
798
|
+
if current_time >= end_time: # type:ignore
|
799
|
+
break
|
800
|
+
await self.step()
|
655
801
|
finally:
|
656
802
|
# 设置停止事件
|
657
803
|
stop_event.set()
|
@@ -71,11 +71,11 @@ class PgWriter:
|
|
71
71
|
|
72
72
|
@lock_decorator
|
73
73
|
async def async_write_status(self, rows: list[tuple]):
|
74
|
-
_tuple_types = [str, int, float, float, float, int, str, str, None]
|
74
|
+
_tuple_types = [str, int, float, float, float, int, list, str, str, None]
|
75
75
|
table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_status"
|
76
76
|
async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
|
77
77
|
copy_sql = psycopg.sql.SQL(
|
78
|
-
"COPY {} (id, day, t, lng, lat, parent_id, action, status, created_at) FROM STDIN"
|
78
|
+
"COPY {} (id, day, t, lng, lat, parent_id, friend_ids, action, status, created_at) FROM STDIN"
|
79
79
|
).format(psycopg.sql.Identifier(table_name))
|
80
80
|
_rows: list[Any] = []
|
81
81
|
async with aconn.cursor() as cur:
|
pycityagent/tools/tool.py
CHANGED
@@ -119,7 +119,7 @@ class SencePOI(Tool):
|
|
119
119
|
if agent.memory is None or agent.simulator is None:
|
120
120
|
raise ValueError("Memory or Simulator is not set.")
|
121
121
|
if radius is None and category_prefix is None:
|
122
|
-
position = await agent.
|
122
|
+
position = await agent.status.get("position")
|
123
123
|
resp = []
|
124
124
|
for prefix in self.category_prefix:
|
125
125
|
resp += agent.simulator.map.query_pois(
|
@@ -146,17 +146,15 @@ class UpdateWithSimulator(Tool):
|
|
146
146
|
agent = self.agent
|
147
147
|
if agent._simulator is None:
|
148
148
|
return
|
149
|
-
if not agent._has_bound_to_simulator:
|
150
|
-
await agent._bind_to_simulator() # type: ignore
|
151
149
|
simulator = agent.simulator
|
152
|
-
|
153
|
-
person_id = await
|
150
|
+
status = agent.status
|
151
|
+
person_id = await status.get("id")
|
154
152
|
resp = await simulator.get_person(person_id)
|
155
153
|
resp_dict = resp["person"]
|
156
154
|
for k, v in resp_dict.get("motion", {}).items():
|
157
155
|
try:
|
158
|
-
await
|
159
|
-
await
|
156
|
+
await status.get(k)
|
157
|
+
await status.update(
|
160
158
|
k, v, mode="replace", protect_llm_read_only_fields=False
|
161
159
|
)
|
162
160
|
except KeyError as e:
|
@@ -183,9 +181,9 @@ class ResetAgentPosition(Tool):
|
|
183
181
|
s: Optional[float] = None,
|
184
182
|
):
|
185
183
|
agent = self.agent
|
186
|
-
|
184
|
+
status = agent.status
|
187
185
|
await agent.simulator.reset_person_position(
|
188
|
-
person_id=await
|
186
|
+
person_id=await status.get("id"),
|
189
187
|
aoi_id=aoi_id,
|
190
188
|
poi_id=poi_id,
|
191
189
|
lane_id=lane_id,
|
pycityagent/utils/__init__.py
CHANGED
@@ -1,11 +1,16 @@
|
|
1
1
|
from .avro_schema import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA,
|
2
2
|
PROFILE_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA)
|
3
3
|
from .pg_query import PGSQL_DICT, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
4
|
-
from .survey_util import process_survey_for_llm
|
4
|
+
from .survey_util import SURVEY_SENDER_UUID, process_survey_for_llm
|
5
5
|
|
6
6
|
__all__ = [
|
7
|
-
"PROFILE_SCHEMA",
|
7
|
+
"PROFILE_SCHEMA",
|
8
|
+
"DIALOG_SCHEMA",
|
9
|
+
"STATUS_SCHEMA",
|
10
|
+
"SURVEY_SCHEMA",
|
11
|
+
"INSTITUTION_STATUS_SCHEMA",
|
8
12
|
"process_survey_for_llm",
|
9
13
|
"TO_UPDATE_EXP_INFO_KEYS_AND_TYPES",
|
10
14
|
"PGSQL_DICT",
|
15
|
+
"SURVEY_SENDER_UUID",
|
11
16
|
]
|
pycityagent/utils/pg_query.py
CHANGED
pycityagent/utils/survey_util.py
CHANGED
@@ -8,40 +8,40 @@ Survey Description: {survey_dict['description']}
|
|
8
8
|
Please answer each question in the following format:
|
9
9
|
|
10
10
|
"""
|
11
|
-
|
11
|
+
|
12
12
|
question_count = 1
|
13
|
-
for page in survey_dict[
|
14
|
-
for question in page[
|
13
|
+
for page in survey_dict["pages"]:
|
14
|
+
for question in page["elements"]:
|
15
15
|
prompt += f"Question {question_count}: {question['title']}\n"
|
16
|
-
|
16
|
+
|
17
17
|
# 根据不同类型的问题生成不同的提示
|
18
|
-
if question[
|
19
|
-
prompt += "Options: " + ", ".join(question[
|
18
|
+
if question["type"] == "radiogroup":
|
19
|
+
prompt += "Options: " + ", ".join(question["choices"]) + "\n"
|
20
20
|
prompt += "Please select ONE option\n"
|
21
|
-
|
22
|
-
elif question[
|
23
|
-
prompt += "Options: " + ", ".join(question[
|
21
|
+
|
22
|
+
elif question["type"] == "checkbox":
|
23
|
+
prompt += "Options: " + ", ".join(question["choices"]) + "\n"
|
24
24
|
prompt += "You can select MULTIPLE options\n"
|
25
|
-
|
26
|
-
elif question[
|
25
|
+
|
26
|
+
elif question["type"] == "rating":
|
27
27
|
prompt += f"Rating range: {question.get('min_rating', 1)} - {question.get('max_rating', 5)}\n"
|
28
28
|
prompt += "Please provide a rating within the range\n"
|
29
|
-
|
30
|
-
elif question[
|
31
|
-
prompt += "Rows: " + ", ".join(question[
|
32
|
-
prompt += "Columns: " + ", ".join(question[
|
29
|
+
|
30
|
+
elif question["type"] == "matrix":
|
31
|
+
prompt += "Rows: " + ", ".join(question["rows"]) + "\n"
|
32
|
+
prompt += "Columns: " + ", ".join(question["columns"]) + "\n"
|
33
33
|
prompt += "Please select ONE column option for EACH row\n"
|
34
|
-
|
35
|
-
elif question[
|
34
|
+
|
35
|
+
elif question["type"] == "text":
|
36
36
|
prompt += "Please provide a text response\n"
|
37
|
-
|
38
|
-
elif question[
|
37
|
+
|
38
|
+
elif question["type"] == "boolean":
|
39
39
|
prompt += "Options: Yes, No\n"
|
40
40
|
prompt += "Please select either Yes or No\n"
|
41
|
-
|
41
|
+
|
42
42
|
prompt += "\nAnswer: [Your response here]\n\n---\n\n"
|
43
43
|
question_count += 1
|
44
|
-
|
44
|
+
|
45
45
|
# 添加总结提示
|
46
46
|
prompt += """Please ensure:
|
47
47
|
1. All required questions are answered
|
@@ -49,5 +49,8 @@ Please answer each question in the following format:
|
|
49
49
|
3. Answers are clear and specific
|
50
50
|
|
51
51
|
Format your responses exactly as requested above."""
|
52
|
-
|
53
|
-
return prompt
|
52
|
+
|
53
|
+
return prompt
|
54
|
+
|
55
|
+
|
56
|
+
SURVEY_SENDER_UUID = "none"
|
pycityagent/workflow/block.py
CHANGED
@@ -146,6 +146,7 @@ def trigger_class():
|
|
146
146
|
class Block:
|
147
147
|
configurable_fields: list[str] = []
|
148
148
|
default_values: dict[str, Any] = {}
|
149
|
+
fields_description: dict[str, str] = {}
|
149
150
|
|
150
151
|
def __init__(
|
151
152
|
self,
|
@@ -173,14 +174,20 @@ class Block:
|
|
173
174
|
|
174
175
|
@classmethod
|
175
176
|
def export_class_config(cls) -> dict[str, str]:
|
176
|
-
return
|
177
|
-
|
178
|
-
|
179
|
-
|
177
|
+
return (
|
178
|
+
{
|
179
|
+
field: cls.default_values.get(field, "default_value")
|
180
|
+
for field in cls.configurable_fields
|
181
|
+
},
|
182
|
+
{
|
183
|
+
field: cls.fields_description.get(field, "")
|
184
|
+
for field in cls.configurable_fields
|
185
|
+
}
|
186
|
+
)
|
180
187
|
|
181
188
|
@classmethod
|
182
189
|
def import_config(cls, config: dict[str, Union[str, dict]]) -> Block:
|
183
|
-
instance = cls(name=config["name"])
|
190
|
+
instance = cls(name=config["name"]) # type: ignore
|
184
191
|
assert isinstance(config["config"], dict)
|
185
192
|
for field, value in config["config"].items():
|
186
193
|
if field in cls.configurable_fields:
|
@@ -188,12 +195,12 @@ class Block:
|
|
188
195
|
|
189
196
|
# 递归创建子Block
|
190
197
|
for child_config in config.get("children", []):
|
191
|
-
child_block = Block.import_config(child_config)
|
198
|
+
child_block = Block.import_config(child_config) # type: ignore
|
192
199
|
setattr(instance, child_block.name.lower(), child_block)
|
193
200
|
|
194
201
|
return instance
|
195
202
|
|
196
|
-
def load_from_config(self, config: dict[str, list[
|
203
|
+
def load_from_config(self, config: dict[str, list[dict]]) -> None:
|
197
204
|
"""
|
198
205
|
使用配置更新当前Block实例的参数,并递归更新子Block。
|
199
206
|
"""
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: pycityagent
|
3
|
-
Version: 2.0.
|
3
|
+
Version: 2.0.0a54
|
4
4
|
Summary: LLM-based city environment agent building library
|
5
5
|
Author-email: Yuwei Yan <pinkgranite86@gmail.com>, Junbo Yan <yanjb20thu@gmali.com>, Jun Zhang <zhangjun990222@gmali.com>
|
6
6
|
License: MIT License
|
@@ -45,7 +45,7 @@ Requires-Dist: openai>=1.58.1
|
|
45
45
|
Requires-Dist: Pillow<12.0.0,>=11.0.0
|
46
46
|
Requires-Dist: protobuf<5.0.0,<=4.24.0
|
47
47
|
Requires-Dist: pycitydata>=1.0.3
|
48
|
-
Requires-Dist: pycityproto>=2.
|
48
|
+
Requires-Dist: pycityproto>=2.2.0
|
49
49
|
Requires-Dist: requests>=2.32.3
|
50
50
|
Requires-Dist: Shapely>=2.0.6
|
51
51
|
Requires-Dist: PyYAML>=6.0.2
|