pycityagent 2.0.0a48__cp310-cp310-macosx_11_0_arm64.whl → 2.0.0a50__cp310-cp310-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/__init__.py +12 -3
- pycityagent/agent/__init__.py +9 -0
- pycityagent/agent/agent.py +324 -0
- pycityagent/{agent.py → agent/agent_base.py} +43 -347
- pycityagent/cityagent/bankagent.py +28 -16
- pycityagent/cityagent/firmagent.py +63 -25
- pycityagent/cityagent/governmentagent.py +35 -19
- pycityagent/cityagent/initial.py +38 -28
- pycityagent/cityagent/memory_config.py +240 -128
- pycityagent/cityagent/nbsagent.py +81 -35
- pycityagent/cityagent/societyagent.py +155 -72
- pycityagent/simulation/agentgroup.py +24 -19
- pycityagent/simulation/simulation.py +94 -55
- pycityagent/tools/__init__.py +9 -0
- pycityagent/{workflow → tools}/tool.py +20 -17
- pycityagent/workflow/__init__.py +0 -5
- pycityagent/workflow/block.py +12 -10
- {pycityagent-2.0.0a48.dist-info → pycityagent-2.0.0a50.dist-info}/METADATA +1 -2
- {pycityagent-2.0.0a48.dist-info → pycityagent-2.0.0a50.dist-info}/RECORD +23 -20
- {pycityagent-2.0.0a48.dist-info → pycityagent-2.0.0a50.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a48.dist-info → pycityagent-2.0.0a50.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a48.dist-info → pycityagent-2.0.0a50.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a48.dist-info → pycityagent-2.0.0a50.dist-info}/top_level.txt +0 -0
pycityagent/simulation/agentgroup.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
import asyncio
|
2
|
-
from collections.abc import Callable
|
3
2
|
import json
|
4
3
|
import logging
|
5
4
|
import time
|
6
5
|
import uuid
|
6
|
+
from collections.abc import Callable
|
7
7
|
from datetime import datetime, timezone
|
8
8
|
from pathlib import Path
|
9
9
|
from typing import Any, Optional, Type, Union
|
@@ -34,7 +34,10 @@ class AgentGroup:
|
|
34
34
|
self,
|
35
35
|
agent_class: Union[type[Agent], list[type[Agent]]],
|
36
36
|
number_of_agents: Union[int, list[int]],
|
37
|
-
memory_config_function_group: Union[
|
37
|
+
memory_config_function_group: Union[
|
38
|
+
Callable[[], tuple[dict, dict, dict]],
|
39
|
+
list[Callable[[], tuple[dict, dict, dict]]],
|
40
|
+
],
|
38
41
|
config: dict,
|
39
42
|
exp_id: str | UUID,
|
40
43
|
exp_name: str,
|
@@ -45,7 +48,7 @@ class AgentGroup:
|
|
45
48
|
mlflow_run_id: str,
|
46
49
|
embedding_model: Embeddings,
|
47
50
|
logging_level: int,
|
48
|
-
agent_config_file: Union[str, list[str]] = None,
|
51
|
+
agent_config_file: Optional[Union[str, list[str]]] = None,
|
49
52
|
):
|
50
53
|
logger.setLevel(logging_level)
|
51
54
|
self._uuid = str(uuid.uuid4())
|
@@ -81,14 +84,14 @@ class AgentGroup:
|
|
81
84
|
# prepare Messager
|
82
85
|
if "mqtt" in config["simulator_request"]:
|
83
86
|
self.messager = Messager.remote(
|
84
|
-
hostname=config["simulator_request"]["mqtt"]["server"],
|
87
|
+
hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
|
85
88
|
port=config["simulator_request"]["mqtt"]["port"],
|
86
89
|
username=config["simulator_request"]["mqtt"].get("username", None),
|
87
90
|
password=config["simulator_request"]["mqtt"].get("password", None),
|
88
91
|
)
|
89
92
|
else:
|
90
93
|
self.messager = None
|
91
|
-
|
94
|
+
|
92
95
|
self.message_dispatch_task = None
|
93
96
|
self._pgsql_writer = pgsql_writer
|
94
97
|
self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
@@ -160,7 +163,7 @@ class AgentGroup:
|
|
160
163
|
agent.memory.set_faiss_query(self.faiss_query)
|
161
164
|
if self.embedding_model is not None:
|
162
165
|
agent.memory.set_embedding_model(self.embedding_model)
|
163
|
-
if self.agent_config_file[i]:
|
166
|
+
if self.agent_config_file is not None and self.agent_config_file[i]:
|
164
167
|
agent.load_from_file(self.agent_config_file[i])
|
165
168
|
self.agents.append(agent)
|
166
169
|
self.id2agent[agent._uuid] = agent
|
@@ -168,21 +171,21 @@ class AgentGroup:
|
|
168
171
|
@property
|
169
172
|
def agent_count(self):
|
170
173
|
return self.number_of_agents
|
171
|
-
|
174
|
+
|
172
175
|
@property
|
173
176
|
def agent_uuids(self):
|
174
177
|
return list(self.id2agent.keys())
|
175
|
-
|
178
|
+
|
176
179
|
@property
|
177
180
|
def agent_type(self):
|
178
181
|
return self.agent_class
|
179
|
-
|
182
|
+
|
180
183
|
def get_agent_count(self):
|
181
184
|
return self.agent_count
|
182
|
-
|
185
|
+
|
183
186
|
def get_agent_uuids(self):
|
184
187
|
return self.agent_uuids
|
185
|
-
|
188
|
+
|
186
189
|
def get_agent_type(self):
|
187
190
|
return self.agent_type
|
188
191
|
|
@@ -190,10 +193,6 @@ class AgentGroup:
|
|
190
193
|
self.message_dispatch_task.cancel() # type: ignore
|
191
194
|
await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore
|
192
195
|
|
193
|
-
async def __aexit__(self, exc_type, exc_value, traceback):
|
194
|
-
self.message_dispatch_task.cancel() # type: ignore
|
195
|
-
await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore
|
196
|
-
|
197
196
|
async def init_agents(self):
|
198
197
|
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
|
199
198
|
logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
|
@@ -201,6 +200,7 @@ class AgentGroup:
|
|
201
200
|
await agent.bind_to_simulator() # type: ignore
|
202
201
|
self.id2agent = {agent._uuid: agent for agent in self.agents}
|
203
202
|
logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
|
203
|
+
assert self.messager is not None
|
204
204
|
await self.messager.connect.remote()
|
205
205
|
if await self.messager.is_connected.remote():
|
206
206
|
await self.messager.start_listening.remote()
|
@@ -293,10 +293,12 @@ class AgentGroup:
|
|
293
293
|
self.initialized = True
|
294
294
|
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
295
295
|
|
296
|
-
async def filter(
|
297
|
-
|
298
|
-
|
299
|
-
|
296
|
+
async def filter(
|
297
|
+
self,
|
298
|
+
types: Optional[list[Type[Agent]]] = None,
|
299
|
+
keys: Optional[list[str]] = None,
|
300
|
+
values: Optional[list[Any]] = None,
|
301
|
+
) -> list[str]:
|
300
302
|
filtered_uuids = []
|
301
303
|
for agent in self.agents:
|
302
304
|
add = True
|
@@ -304,6 +306,7 @@ class AgentGroup:
|
|
304
306
|
if agent.__class__ in types:
|
305
307
|
if keys:
|
306
308
|
for key in keys:
|
309
|
+
assert values is not None
|
307
310
|
if not agent.memory.get(key) == values[keys.index(key)]:
|
308
311
|
add = False
|
309
312
|
break
|
@@ -311,6 +314,7 @@ class AgentGroup:
|
|
311
314
|
filtered_uuids.append(agent._uuid)
|
312
315
|
elif keys:
|
313
316
|
for key in keys:
|
317
|
+
assert values is not None
|
314
318
|
if not agent.memory.get(key) == values[keys.index(key)]:
|
315
319
|
add = False
|
316
320
|
break
|
@@ -335,6 +339,7 @@ class AgentGroup:
|
|
335
339
|
async def message_dispatch(self):
|
336
340
|
logger.debug(f"-----Starting message dispatch for group {self._uuid}")
|
337
341
|
while True:
|
342
|
+
assert self.messager is not None
|
338
343
|
if not await self.messager.is_connected.remote():
|
339
344
|
logger.warning(
|
340
345
|
"Messager is not connected. Skipping message processing."
|
pycityagent/simulation/simulation.py
CHANGED
@@ -13,6 +13,11 @@ import yaml
|
|
13
13
|
from langchain_core.embeddings import Embeddings
|
14
14
|
|
15
15
|
from ..agent import Agent, InstitutionAgent
|
16
|
+
from ..cityagent import (BankAgent, FirmAgent, GovernmentAgent, NBSAgent,
|
17
|
+
SocietyAgent, memory_config_bank, memory_config_firm,
|
18
|
+
memory_config_government, memory_config_nbs,
|
19
|
+
memory_config_societyagent)
|
20
|
+
from ..cityagent.initial import bind_agent_info, initialize_social_network
|
16
21
|
from ..environment.simulator import Simulator
|
17
22
|
from ..llm import SimpleEmbedding
|
18
23
|
from ..memory import Memory
|
@@ -22,11 +27,10 @@ from ..survey import Survey
|
|
22
27
|
from ..utils import TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
23
28
|
from .agentgroup import AgentGroup
|
24
29
|
from .storage.pg import PgWriter, create_pg_tables
|
25
|
-
from ..cityagent import SocietyAgent, FirmAgent, BankAgent, NBSAgent, GovernmentAgent, memory_config_societyagent, memory_config_government, memory_config_firm, memory_config_bank, memory_config_nbs
|
26
|
-
from ..cityagent.initial import bind_agent_info, initialize_social_network
|
27
30
|
|
28
31
|
logger = logging.getLogger("pycityagent")
|
29
32
|
|
33
|
+
|
30
34
|
class AgentSimulation:
|
31
35
|
"""城市智能体模拟器"""
|
32
36
|
|
@@ -52,7 +56,13 @@ class AgentSimulation:
|
|
52
56
|
self.agent_class = agent_class
|
53
57
|
elif agent_class is None:
|
54
58
|
if enable_economy:
|
55
|
-
self.agent_class = [
|
59
|
+
self.agent_class = [
|
60
|
+
SocietyAgent,
|
61
|
+
FirmAgent,
|
62
|
+
BankAgent,
|
63
|
+
NBSAgent,
|
64
|
+
GovernmentAgent,
|
65
|
+
]
|
56
66
|
self.default_memory_config_func = [
|
57
67
|
memory_config_societyagent,
|
58
68
|
memory_config_firm,
|
@@ -82,7 +92,7 @@ class AgentSimulation:
|
|
82
92
|
# self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
83
93
|
|
84
94
|
self._messager = Messager.remote(
|
85
|
-
hostname=config["simulator_request"]["mqtt"]["server"],
|
95
|
+
hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
|
86
96
|
port=config["simulator_request"]["mqtt"]["port"],
|
87
97
|
username=config["simulator_request"]["mqtt"].get("username", None),
|
88
98
|
password=config["simulator_request"]["mqtt"].get("password", None),
|
@@ -143,7 +153,7 @@ class AgentSimulation:
|
|
143
153
|
"""Directly run from config file
|
144
154
|
Basic config file should contain:
|
145
155
|
- simulation_config: file_path
|
146
|
-
- agent_config:
|
156
|
+
- agent_config:
|
147
157
|
- agent_config_file: Optional[dict]
|
148
158
|
- memory_config_func: Optional[Union[Callable, list[Callable]]]
|
149
159
|
- init_func: Optional[list[Callable[AgentSimulation, None]]]
|
@@ -173,13 +183,14 @@ class AgentSimulation:
|
|
173
183
|
if "workflow" not in config:
|
174
184
|
raise ValueError("workflow is required")
|
175
185
|
import yaml
|
186
|
+
|
176
187
|
logger.info("Loading config file...")
|
177
188
|
with open(config["simulation_config"], "r") as f:
|
178
189
|
simulation_config = yaml.safe_load(f)
|
179
190
|
logger.info("Creating AgentSimulation Task...")
|
180
191
|
simulation = cls(
|
181
|
-
config=simulation_config,
|
182
|
-
agent_config_file=config["agent_config"].get("agent_config_file", None),
|
192
|
+
config=simulation_config,
|
193
|
+
agent_config_file=config["agent_config"].get("agent_config_file", None),
|
183
194
|
exp_name=config.get("exp_name", "default_experiment"),
|
184
195
|
logging_level=config.get("logging_level", logging.WARNING),
|
185
196
|
)
|
@@ -193,21 +204,28 @@ class AgentSimulation:
|
|
193
204
|
await simulation.init_agents(
|
194
205
|
agent_count=agent_count,
|
195
206
|
group_size=config["agent_config"].get("group_size", 10000),
|
196
|
-
embedding_model=config["agent_config"].get(
|
207
|
+
embedding_model=config["agent_config"].get(
|
208
|
+
"embedding_model", SimpleEmbedding()
|
209
|
+
),
|
197
210
|
memory_config_func=config["agent_config"].get("memory_config_func", None),
|
198
211
|
)
|
199
212
|
logger.info("Running Init Functions...")
|
200
|
-
for init_func in config["agent_config"].get(
|
213
|
+
for init_func in config["agent_config"].get(
|
214
|
+
"init_func", [bind_agent_info, initialize_social_network]
|
215
|
+
):
|
201
216
|
await init_func(simulation)
|
202
217
|
logger.info("Starting Simulation...")
|
203
218
|
for step in config["workflow"]:
|
204
|
-
logger.info(
|
219
|
+
logger.info(
|
220
|
+
f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}"
|
221
|
+
)
|
205
222
|
if step["type"] not in ["run", "step", "interview", "survey", "intervene"]:
|
206
223
|
raise ValueError(f"Invalid step type: {step['type']}")
|
207
224
|
if step["type"] == "run":
|
208
225
|
await simulation.run(step.get("day", 1))
|
209
226
|
elif step["type"] == "step":
|
210
|
-
await simulation.step(step.get("time", 1))
|
227
|
+
# await simulation.step(step.get("time", 1))
|
228
|
+
await simulation.step()
|
211
229
|
else:
|
212
230
|
await step["step_func"](simulation)
|
213
231
|
logger.info("Simulation finished")
|
@@ -241,11 +259,11 @@ class AgentSimulation:
|
|
241
259
|
@property
|
242
260
|
def agent_uuid2group(self):
|
243
261
|
return self._agent_uuid2group
|
244
|
-
|
262
|
+
|
245
263
|
@property
|
246
264
|
def messager(self):
|
247
265
|
return self._messager
|
248
|
-
|
266
|
+
|
249
267
|
async def _save_exp_info(self) -> None:
|
250
268
|
"""异步保存实验信息到YAML文件"""
|
251
269
|
try:
|
@@ -354,38 +372,44 @@ class AgentSimulation:
|
|
354
372
|
# 分别处理机构智能体和普通智能体
|
355
373
|
institution_params = []
|
356
374
|
citizen_params = []
|
357
|
-
|
375
|
+
|
358
376
|
# 收集所有参数
|
359
377
|
for i in range(len(self.agent_class)):
|
360
378
|
agent_class = self.agent_class[i]
|
361
379
|
agent_count_i = agent_count[i]
|
362
380
|
memory_config_func_i = memory_config_func[i]
|
363
|
-
|
381
|
+
|
364
382
|
if self.agent_config_file is not None:
|
365
|
-
config_file = self.agent_config_file.get(agent_class, None)
|
383
|
+
config_file = self.agent_config_file.get(agent_class, None)
|
366
384
|
else:
|
367
385
|
config_file = None
|
368
|
-
|
386
|
+
|
369
387
|
if issubclass(agent_class, InstitutionAgent):
|
370
|
-
institution_params.append(
|
388
|
+
institution_params.append(
|
389
|
+
(agent_class, agent_count_i, memory_config_func_i, config_file)
|
390
|
+
)
|
371
391
|
else:
|
372
|
-
citizen_params.append(
|
392
|
+
citizen_params.append(
|
393
|
+
(agent_class, agent_count_i, memory_config_func_i, config_file)
|
394
|
+
)
|
373
395
|
|
374
396
|
# 处理机构智能体组
|
375
397
|
if institution_params:
|
376
398
|
total_institution_count = sum(p[1] for p in institution_params)
|
377
|
-
num_institution_groups = (
|
378
|
-
|
399
|
+
num_institution_groups = (
|
400
|
+
total_institution_count + group_size - 1
|
401
|
+
) // group_size
|
402
|
+
|
379
403
|
for k in range(num_institution_groups):
|
380
404
|
start_idx = k * group_size
|
381
405
|
remaining = total_institution_count - start_idx
|
382
406
|
number_of_agents = min(remaining, group_size)
|
383
|
-
|
407
|
+
|
384
408
|
agent_classes = []
|
385
409
|
agent_counts = []
|
386
410
|
memory_config_funcs = []
|
387
411
|
config_files = []
|
388
|
-
|
412
|
+
|
389
413
|
# 分配每种类型的机构智能体到当前组
|
390
414
|
curr_start = start_idx
|
391
415
|
for agent_class, count, mem_func, conf_file in institution_params:
|
@@ -395,30 +419,32 @@ class AgentSimulation:
|
|
395
419
|
memory_config_funcs.append(mem_func)
|
396
420
|
config_files.append(conf_file)
|
397
421
|
curr_start = max(0, curr_start - count)
|
398
|
-
|
399
|
-
group_creation_params.append(
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
422
|
+
|
423
|
+
group_creation_params.append(
|
424
|
+
(
|
425
|
+
agent_classes,
|
426
|
+
agent_counts,
|
427
|
+
memory_config_funcs,
|
428
|
+
f"InstitutionGroup_{k}",
|
429
|
+
config_files,
|
430
|
+
)
|
431
|
+
)
|
406
432
|
|
407
433
|
# 处理普通智能体组
|
408
434
|
if citizen_params:
|
409
435
|
total_citizen_count = sum(p[1] for p in citizen_params)
|
410
436
|
num_citizen_groups = (total_citizen_count + group_size - 1) // group_size
|
411
|
-
|
437
|
+
|
412
438
|
for k in range(num_citizen_groups):
|
413
439
|
start_idx = k * group_size
|
414
440
|
remaining = total_citizen_count - start_idx
|
415
441
|
number_of_agents = min(remaining, group_size)
|
416
|
-
|
442
|
+
|
417
443
|
agent_classes = []
|
418
444
|
agent_counts = []
|
419
445
|
memory_config_funcs = []
|
420
446
|
config_files = []
|
421
|
-
|
447
|
+
|
422
448
|
# 分配每种类型的普通智能体到当前组
|
423
449
|
curr_start = start_idx
|
424
450
|
for agent_class, count, mem_func, conf_file in citizen_params:
|
@@ -428,14 +454,16 @@ class AgentSimulation:
|
|
428
454
|
memory_config_funcs.append(mem_func)
|
429
455
|
config_files.append(conf_file)
|
430
456
|
curr_start = max(0, curr_start - count)
|
431
|
-
|
432
|
-
group_creation_params.append(
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
457
|
+
|
458
|
+
group_creation_params.append(
|
459
|
+
(
|
460
|
+
agent_classes,
|
461
|
+
agent_counts,
|
462
|
+
memory_config_funcs,
|
463
|
+
f"CitizenGroup_{k}",
|
464
|
+
config_files,
|
465
|
+
)
|
466
|
+
)
|
439
467
|
|
440
468
|
# 初始化mlflow连接
|
441
469
|
_mlflow_config = self.config.get("metric_request", {}).get("mlflow")
|
@@ -463,7 +491,13 @@ class AgentSimulation:
|
|
463
491
|
self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
|
464
492
|
|
465
493
|
creation_tasks = []
|
466
|
-
for i, (
|
494
|
+
for i, (
|
495
|
+
agent_class,
|
496
|
+
number_of_agents,
|
497
|
+
memory_config_function_group,
|
498
|
+
group_name,
|
499
|
+
config_file,
|
500
|
+
) in enumerate(group_creation_params):
|
467
501
|
# 直接创建异步任务
|
468
502
|
group = AgentGroup.remote(
|
469
503
|
agent_class,
|
@@ -489,7 +523,9 @@ class AgentSimulation:
|
|
489
523
|
group_agent_uuids = ray.get(group.get_agent_uuids.remote())
|
490
524
|
for agent_uuid in group_agent_uuids:
|
491
525
|
self._agent_uuid2group[agent_uuid] = group
|
492
|
-
self._user_chat_topics[agent_uuid] =
|
526
|
+
self._user_chat_topics[agent_uuid] = (
|
527
|
+
f"exps/{self.exp_id}/agents/{agent_uuid}/user-chat"
|
528
|
+
)
|
493
529
|
self._user_survey_topics[agent_uuid] = (
|
494
530
|
f"exps/{self.exp_id}/agents/{agent_uuid}/user-survey"
|
495
531
|
)
|
@@ -511,23 +547,26 @@ class AgentSimulation:
|
|
511
547
|
for group in self._groups.values():
|
512
548
|
gather_tasks.append(group.gather.remote(content))
|
513
549
|
return await asyncio.gather(*gather_tasks)
|
514
|
-
|
515
|
-
async def filter(
|
516
|
-
|
517
|
-
|
518
|
-
|
550
|
+
|
551
|
+
async def filter(
|
552
|
+
self,
|
553
|
+
types: Optional[list[Type[Agent]]] = None,
|
554
|
+
keys: Optional[list[str]] = None,
|
555
|
+
values: Optional[list[Any]] = None,
|
556
|
+
) -> list[str]:
|
519
557
|
"""过滤出指定类型的智能体"""
|
520
558
|
if not types and not keys and not values:
|
521
559
|
return self._agent_uuids
|
522
560
|
group_to_filter = []
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
561
|
+
if types is not None:
|
562
|
+
for t in types:
|
563
|
+
if t in self._type2group:
|
564
|
+
group_to_filter.extend(self._type2group[t])
|
565
|
+
else:
|
566
|
+
raise ValueError(f"type {t} not found in simulation")
|
528
567
|
filtered_uuids = []
|
529
568
|
if keys:
|
530
|
-
if len(keys) != len(values):
|
569
|
+
if values is None or len(keys) != len(values):
|
531
570
|
raise ValueError("the length of key and value does not match")
|
532
571
|
for group in group_to_filter:
|
533
572
|
filtered_uuids.extend(await group.filter.remote(types, keys, values))
|
pycityagent/{workflow → tools}/tool.py
CHANGED
@@ -1,10 +1,16 @@
|
|
1
|
-
|
1
|
+
import asyncio
|
2
|
+
import time
|
2
3
|
from collections import defaultdict
|
3
4
|
from collections.abc import Callable, Sequence
|
5
|
+
from typing import Any, Optional, Union
|
6
|
+
|
4
7
|
from mlflow.entities import Metric
|
5
|
-
import time
|
6
8
|
|
7
|
-
from ..
|
9
|
+
from ..agent import Agent
|
10
|
+
from ..environment import (LEVEL_ONE_PRE, POI_TYPE_DICT, AoiService,
|
11
|
+
PersonService)
|
12
|
+
from ..utils.decorators import lock_decorator
|
13
|
+
from ..workflow import Block
|
8
14
|
|
9
15
|
|
10
16
|
class Tool:
|
@@ -34,31 +40,23 @@ class Tool:
|
|
34
40
|
raise NotImplementedError
|
35
41
|
|
36
42
|
@property
|
37
|
-
def agent(self):
|
43
|
+
def agent(self) -> Agent:
|
38
44
|
instance = self._instance # type:ignore
|
39
|
-
if not isinstance(instance,
|
45
|
+
if not isinstance(instance, Agent):
|
40
46
|
raise RuntimeError(
|
41
47
|
f"Tool bind to object `{type(instance).__name__}`, not an `Agent` object!"
|
42
48
|
)
|
43
49
|
return instance
|
44
50
|
|
45
51
|
@property
|
46
|
-
def block(self):
|
52
|
+
def block(self) -> Block:
|
47
53
|
instance = self._instance # type:ignore
|
48
|
-
if not isinstance(instance,
|
54
|
+
if not isinstance(instance, Block):
|
49
55
|
raise RuntimeError(
|
50
56
|
f"Tool bind to object `{type(instance).__name__}`, not an `Block` object!"
|
51
57
|
)
|
52
58
|
return instance
|
53
59
|
|
54
|
-
def _get_agent_class(self):
|
55
|
-
from ..agent import Agent
|
56
|
-
return Agent
|
57
|
-
|
58
|
-
def _get_block_class(self):
|
59
|
-
from ..workflow import Block
|
60
|
-
return Block
|
61
|
-
|
62
60
|
|
63
61
|
class GetMap(Tool):
|
64
62
|
"""Retrieve the map from the simulator. Can be bound only to an `Agent` instance."""
|
@@ -140,7 +138,7 @@ class SencePOI(Tool):
|
|
140
138
|
|
141
139
|
class UpdateWithSimulator(Tool):
|
142
140
|
def __init__(self) -> None:
|
143
|
-
|
141
|
+
self._lock = asyncio.Lock()
|
144
142
|
|
145
143
|
async def _update_motion_with_sim(
|
146
144
|
self,
|
@@ -164,6 +162,7 @@ class UpdateWithSimulator(Tool):
|
|
164
162
|
except KeyError as e:
|
165
163
|
continue
|
166
164
|
|
165
|
+
@lock_decorator
|
167
166
|
async def __call__(
|
168
167
|
self,
|
169
168
|
):
|
@@ -173,8 +172,9 @@ class UpdateWithSimulator(Tool):
|
|
173
172
|
|
174
173
|
class ResetAgentPosition(Tool):
|
175
174
|
def __init__(self) -> None:
|
176
|
-
|
175
|
+
self._lock = asyncio.Lock()
|
177
176
|
|
177
|
+
@lock_decorator
|
178
178
|
async def __call__(
|
179
179
|
self,
|
180
180
|
aoi_id: Optional[int] = None,
|
@@ -198,7 +198,9 @@ class ExportMlflowMetrics(Tool):
|
|
198
198
|
self._log_batch_size = log_batch_size
|
199
199
|
# TODO: support other log types
|
200
200
|
self.metric_log_cache: dict[str, list[Metric]] = defaultdict(list)
|
201
|
+
self._lock = asyncio.Lock()
|
201
202
|
|
203
|
+
@lock_decorator
|
202
204
|
async def __call__(
|
203
205
|
self,
|
204
206
|
metric: Union[Sequence[Union[Metric, dict]], Union[Metric, dict]],
|
@@ -231,6 +233,7 @@ class ExportMlflowMetrics(Tool):
|
|
231
233
|
if clear_cache:
|
232
234
|
await self._clear_cache()
|
233
235
|
|
236
|
+
@lock_decorator
|
234
237
|
async def _clear_cache(
|
235
238
|
self,
|
236
239
|
):
|
pycityagent/workflow/__init__.py
CHANGED
@@ -7,14 +7,9 @@ This module contains classes for creating blocks and running workflows.
|
|
7
7
|
from .block import (Block, log_and_check, log_and_check_with_memory,
|
8
8
|
trigger_class)
|
9
9
|
from .prompt import FormatPrompt
|
10
|
-
from .tool import ExportMlflowMetrics, GetMap, SencePOI, Tool
|
11
10
|
from .trigger import EventTrigger, MemoryChangeTrigger, TimeTrigger
|
12
11
|
|
13
12
|
__all__ = [
|
14
|
-
"SencePOI",
|
15
|
-
"Tool",
|
16
|
-
"ExportMlflowMetrics",
|
17
|
-
"GetMap",
|
18
13
|
"MemoryChangeTrigger",
|
19
14
|
"TimeTrigger",
|
20
15
|
"EventTrigger",
|
pycityagent/workflow/block.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import functools
|
4
5
|
import inspect
|
5
|
-
from collections.abc import Awaitable, Callable, Coroutine
|
6
6
|
import json
|
7
|
-
from
|
7
|
+
from collections.abc import Awaitable, Callable, Coroutine
|
8
|
+
from typing import Any, Optional, Union
|
8
9
|
|
9
10
|
from pyparsing import Dict
|
10
11
|
|
@@ -143,7 +144,7 @@ def trigger_class():
|
|
143
144
|
|
144
145
|
# Define a Block, similar to a layer in PyTorch
|
145
146
|
class Block:
|
146
|
-
configurable_fields:
|
147
|
+
configurable_fields: list[str] = []
|
147
148
|
default_values: dict[str, Any] = {}
|
148
149
|
|
149
150
|
def __init__(
|
@@ -164,22 +165,23 @@ class Block:
|
|
164
165
|
trigger.initialize() # 立即初始化trigger
|
165
166
|
self.trigger = trigger
|
166
167
|
|
167
|
-
def export_config(self) ->
|
168
|
+
def export_config(self) -> dict[str, Optional[str]]:
|
168
169
|
return {
|
169
170
|
field: self.default_values.get(field, "default_value")
|
170
171
|
for field in self.configurable_fields
|
171
172
|
}
|
172
173
|
|
173
174
|
@classmethod
|
174
|
-
def export_class_config(cls) ->
|
175
|
+
def export_class_config(cls) -> dict[str, str]:
|
175
176
|
return {
|
176
177
|
field: cls.default_values.get(field, "default_value")
|
177
178
|
for field in cls.configurable_fields
|
178
179
|
}
|
179
180
|
|
180
181
|
@classmethod
|
181
|
-
def import_config(cls, config:
|
182
|
+
def import_config(cls, config: dict[str, Union[str, dict]]) -> Block:
|
182
183
|
instance = cls(name=config["name"])
|
184
|
+
assert isinstance(config["config"], dict)
|
183
185
|
for field, value in config["config"].items():
|
184
186
|
if field in cls.configurable_fields:
|
185
187
|
setattr(instance, field, value)
|
@@ -190,8 +192,8 @@ class Block:
|
|
190
192
|
setattr(instance, child_block.name.lower(), child_block)
|
191
193
|
|
192
194
|
return instance
|
193
|
-
|
194
|
-
def load_from_config(self, config:
|
195
|
+
|
196
|
+
def load_from_config(self, config: dict[str, list[Dict]]) -> None:
|
195
197
|
"""
|
196
198
|
使用配置更新当前Block实例的参数,并递归更新子Block。
|
197
199
|
"""
|
@@ -201,8 +203,8 @@ class Block:
|
|
201
203
|
if config["config"][field] != "default_value":
|
202
204
|
setattr(self, field, config["config"][field])
|
203
205
|
|
204
|
-
def build_or_update_block(block_data:
|
205
|
-
block_name = block_data["name"].lower()
|
206
|
+
def build_or_update_block(block_data: dict) -> Block:
|
207
|
+
block_name = block_data["name"].lower() # type:ignore
|
206
208
|
existing_block = getattr(self, block_name, None)
|
207
209
|
|
208
210
|
if existing_block:
|
{pycityagent-2.0.0a48.dist-info → pycityagent-2.0.0a50.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: pycityagent
|
3
|
-
Version: 2.0.0a48
|
3
|
+
Version: 2.0.0a50
|
4
4
|
Summary: LLM-based city environment agent building library
|
5
5
|
Author-email: Yuwei Yan <pinkgranite86@gmail.com>, Junbo Yan <yanjb20thu@gmali.com>, Jun Zhang <zhangjun990222@gmali.com>
|
6
6
|
License: MIT License
|
@@ -50,7 +50,6 @@ Requires-Dist: requests>=2.32.3
|
|
50
50
|
Requires-Dist: Shapely>=2.0.6
|
51
51
|
Requires-Dist: PyYAML>=6.0.2
|
52
52
|
Requires-Dist: zhipuai>=2.1.5.20230904
|
53
|
-
Requires-Dist: gradio>=5.7.1
|
54
53
|
Requires-Dist: mosstool>=1.3.0
|
55
54
|
Requires-Dist: ray>=2.40.0
|
56
55
|
Requires-Dist: aiomqtt>=2.3.0
|