pycityagent 2.0.0a48__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a50__cp39-cp39-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,9 +1,9 @@
1
1
  import asyncio
2
- from collections.abc import Callable
3
2
  import json
4
3
  import logging
5
4
  import time
6
5
  import uuid
6
+ from collections.abc import Callable
7
7
  from datetime import datetime, timezone
8
8
  from pathlib import Path
9
9
  from typing import Any, Optional, Type, Union
@@ -34,7 +34,10 @@ class AgentGroup:
34
34
  self,
35
35
  agent_class: Union[type[Agent], list[type[Agent]]],
36
36
  number_of_agents: Union[int, list[int]],
37
- memory_config_function_group: Union[Callable[[], tuple[dict, dict, dict]], list[Callable[[], tuple[dict, dict, dict]]]],
37
+ memory_config_function_group: Union[
38
+ Callable[[], tuple[dict, dict, dict]],
39
+ list[Callable[[], tuple[dict, dict, dict]]],
40
+ ],
38
41
  config: dict,
39
42
  exp_id: str | UUID,
40
43
  exp_name: str,
@@ -45,7 +48,7 @@ class AgentGroup:
45
48
  mlflow_run_id: str,
46
49
  embedding_model: Embeddings,
47
50
  logging_level: int,
48
- agent_config_file: Union[str, list[str]] = None,
51
+ agent_config_file: Optional[Union[str, list[str]]] = None,
49
52
  ):
50
53
  logger.setLevel(logging_level)
51
54
  self._uuid = str(uuid.uuid4())
@@ -81,14 +84,14 @@ class AgentGroup:
81
84
  # prepare Messager
82
85
  if "mqtt" in config["simulator_request"]:
83
86
  self.messager = Messager.remote(
84
- hostname=config["simulator_request"]["mqtt"]["server"],
87
+ hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
85
88
  port=config["simulator_request"]["mqtt"]["port"],
86
89
  username=config["simulator_request"]["mqtt"].get("username", None),
87
90
  password=config["simulator_request"]["mqtt"].get("password", None),
88
91
  )
89
92
  else:
90
93
  self.messager = None
91
-
94
+
92
95
  self.message_dispatch_task = None
93
96
  self._pgsql_writer = pgsql_writer
94
97
  self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
@@ -160,7 +163,7 @@ class AgentGroup:
160
163
  agent.memory.set_faiss_query(self.faiss_query)
161
164
  if self.embedding_model is not None:
162
165
  agent.memory.set_embedding_model(self.embedding_model)
163
- if self.agent_config_file[i]:
166
+ if self.agent_config_file is not None and self.agent_config_file[i]:
164
167
  agent.load_from_file(self.agent_config_file[i])
165
168
  self.agents.append(agent)
166
169
  self.id2agent[agent._uuid] = agent
@@ -168,21 +171,21 @@ class AgentGroup:
168
171
  @property
169
172
  def agent_count(self):
170
173
  return self.number_of_agents
171
-
174
+
172
175
  @property
173
176
  def agent_uuids(self):
174
177
  return list(self.id2agent.keys())
175
-
178
+
176
179
  @property
177
180
  def agent_type(self):
178
181
  return self.agent_class
179
-
182
+
180
183
  def get_agent_count(self):
181
184
  return self.agent_count
182
-
185
+
183
186
  def get_agent_uuids(self):
184
187
  return self.agent_uuids
185
-
188
+
186
189
  def get_agent_type(self):
187
190
  return self.agent_type
188
191
 
@@ -190,10 +193,6 @@ class AgentGroup:
190
193
  self.message_dispatch_task.cancel() # type: ignore
191
194
  await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore
192
195
 
193
- async def __aexit__(self, exc_type, exc_value, traceback):
194
- self.message_dispatch_task.cancel() # type: ignore
195
- await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore
196
-
197
196
  async def init_agents(self):
198
197
  logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
199
198
  logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
@@ -201,6 +200,7 @@ class AgentGroup:
201
200
  await agent.bind_to_simulator() # type: ignore
202
201
  self.id2agent = {agent._uuid: agent for agent in self.agents}
203
202
  logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
203
+ assert self.messager is not None
204
204
  await self.messager.connect.remote()
205
205
  if await self.messager.is_connected.remote():
206
206
  await self.messager.start_listening.remote()
@@ -293,10 +293,12 @@ class AgentGroup:
293
293
  self.initialized = True
294
294
  logger.debug(f"-----AgentGroup {self._uuid} initialized")
295
295
 
296
- async def filter(self,
297
- types: Optional[list[Type[Agent]]] = None,
298
- keys: Optional[list[str]] = None,
299
- values: Optional[list[Any]] = None) -> list[str]:
296
+ async def filter(
297
+ self,
298
+ types: Optional[list[Type[Agent]]] = None,
299
+ keys: Optional[list[str]] = None,
300
+ values: Optional[list[Any]] = None,
301
+ ) -> list[str]:
300
302
  filtered_uuids = []
301
303
  for agent in self.agents:
302
304
  add = True
@@ -304,6 +306,7 @@ class AgentGroup:
304
306
  if agent.__class__ in types:
305
307
  if keys:
306
308
  for key in keys:
309
+ assert values is not None
307
310
  if not agent.memory.get(key) == values[keys.index(key)]:
308
311
  add = False
309
312
  break
@@ -311,6 +314,7 @@ class AgentGroup:
311
314
  filtered_uuids.append(agent._uuid)
312
315
  elif keys:
313
316
  for key in keys:
317
+ assert values is not None
314
318
  if not agent.memory.get(key) == values[keys.index(key)]:
315
319
  add = False
316
320
  break
@@ -335,6 +339,7 @@ class AgentGroup:
335
339
  async def message_dispatch(self):
336
340
  logger.debug(f"-----Starting message dispatch for group {self._uuid}")
337
341
  while True:
342
+ assert self.messager is not None
338
343
  if not await self.messager.is_connected.remote():
339
344
  logger.warning(
340
345
  "Messager is not connected. Skipping message processing."
@@ -13,6 +13,11 @@ import yaml
13
13
  from langchain_core.embeddings import Embeddings
14
14
 
15
15
  from ..agent import Agent, InstitutionAgent
16
+ from ..cityagent import (BankAgent, FirmAgent, GovernmentAgent, NBSAgent,
17
+ SocietyAgent, memory_config_bank, memory_config_firm,
18
+ memory_config_government, memory_config_nbs,
19
+ memory_config_societyagent)
20
+ from ..cityagent.initial import bind_agent_info, initialize_social_network
16
21
  from ..environment.simulator import Simulator
17
22
  from ..llm import SimpleEmbedding
18
23
  from ..memory import Memory
@@ -22,11 +27,10 @@ from ..survey import Survey
22
27
  from ..utils import TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
23
28
  from .agentgroup import AgentGroup
24
29
  from .storage.pg import PgWriter, create_pg_tables
25
- from ..cityagent import SocietyAgent, FirmAgent, BankAgent, NBSAgent, GovernmentAgent, memory_config_societyagent, memory_config_government, memory_config_firm, memory_config_bank, memory_config_nbs
26
- from ..cityagent.initial import bind_agent_info, initialize_social_network
27
30
 
28
31
  logger = logging.getLogger("pycityagent")
29
32
 
33
+
30
34
  class AgentSimulation:
31
35
  """城市智能体模拟器"""
32
36
 
@@ -52,7 +56,13 @@ class AgentSimulation:
52
56
  self.agent_class = agent_class
53
57
  elif agent_class is None:
54
58
  if enable_economy:
55
- self.agent_class = [SocietyAgent, FirmAgent, BankAgent, NBSAgent, GovernmentAgent]
59
+ self.agent_class = [
60
+ SocietyAgent,
61
+ FirmAgent,
62
+ BankAgent,
63
+ NBSAgent,
64
+ GovernmentAgent,
65
+ ]
56
66
  self.default_memory_config_func = [
57
67
  memory_config_societyagent,
58
68
  memory_config_firm,
@@ -82,7 +92,7 @@ class AgentSimulation:
82
92
  # self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
83
93
 
84
94
  self._messager = Messager.remote(
85
- hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
95
+ hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
86
96
  port=config["simulator_request"]["mqtt"]["port"],
87
97
  username=config["simulator_request"]["mqtt"].get("username", None),
88
98
  password=config["simulator_request"]["mqtt"].get("password", None),
@@ -143,7 +153,7 @@ class AgentSimulation:
143
153
  """Directly run from config file
144
154
  Basic config file should contain:
145
155
  - simulation_config: file_path
146
- - agent_config:
156
+ - agent_config:
147
157
  - agent_config_file: Optional[dict]
148
158
  - memory_config_func: Optional[Union[Callable, list[Callable]]]
149
159
  - init_func: Optional[list[Callable[AgentSimulation, None]]]
@@ -173,13 +183,14 @@ class AgentSimulation:
173
183
  if "workflow" not in config:
174
184
  raise ValueError("workflow is required")
175
185
  import yaml
186
+
176
187
  logger.info("Loading config file...")
177
188
  with open(config["simulation_config"], "r") as f:
178
189
  simulation_config = yaml.safe_load(f)
179
190
  logger.info("Creating AgentSimulation Task...")
180
191
  simulation = cls(
181
- config=simulation_config,
182
- agent_config_file=config["agent_config"].get("agent_config_file", None),
192
+ config=simulation_config,
193
+ agent_config_file=config["agent_config"].get("agent_config_file", None),
183
194
  exp_name=config.get("exp_name", "default_experiment"),
184
195
  logging_level=config.get("logging_level", logging.WARNING),
185
196
  )
@@ -193,21 +204,28 @@ class AgentSimulation:
193
204
  await simulation.init_agents(
194
205
  agent_count=agent_count,
195
206
  group_size=config["agent_config"].get("group_size", 10000),
196
- embedding_model=config["agent_config"].get("embedding_model", SimpleEmbedding()),
207
+ embedding_model=config["agent_config"].get(
208
+ "embedding_model", SimpleEmbedding()
209
+ ),
197
210
  memory_config_func=config["agent_config"].get("memory_config_func", None),
198
211
  )
199
212
  logger.info("Running Init Functions...")
200
- for init_func in config["agent_config"].get("init_func", [bind_agent_info, initialize_social_network]):
213
+ for init_func in config["agent_config"].get(
214
+ "init_func", [bind_agent_info, initialize_social_network]
215
+ ):
201
216
  await init_func(simulation)
202
217
  logger.info("Starting Simulation...")
203
218
  for step in config["workflow"]:
204
- logger.info(f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}")
219
+ logger.info(
220
+ f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}"
221
+ )
205
222
  if step["type"] not in ["run", "step", "interview", "survey", "intervene"]:
206
223
  raise ValueError(f"Invalid step type: {step['type']}")
207
224
  if step["type"] == "run":
208
225
  await simulation.run(step.get("day", 1))
209
226
  elif step["type"] == "step":
210
- await simulation.step(step.get("time", 1))
227
+ # await simulation.step(step.get("time", 1))
228
+ await simulation.step()
211
229
  else:
212
230
  await step["step_func"](simulation)
213
231
  logger.info("Simulation finished")
@@ -241,11 +259,11 @@ class AgentSimulation:
241
259
  @property
242
260
  def agent_uuid2group(self):
243
261
  return self._agent_uuid2group
244
-
262
+
245
263
  @property
246
264
  def messager(self):
247
265
  return self._messager
248
-
266
+
249
267
  async def _save_exp_info(self) -> None:
250
268
  """异步保存实验信息到YAML文件"""
251
269
  try:
@@ -354,38 +372,44 @@ class AgentSimulation:
354
372
  # 分别处理机构智能体和普通智能体
355
373
  institution_params = []
356
374
  citizen_params = []
357
-
375
+
358
376
  # 收集所有参数
359
377
  for i in range(len(self.agent_class)):
360
378
  agent_class = self.agent_class[i]
361
379
  agent_count_i = agent_count[i]
362
380
  memory_config_func_i = memory_config_func[i]
363
-
381
+
364
382
  if self.agent_config_file is not None:
365
- config_file = self.agent_config_file.get(agent_class, None)
383
+ config_file = self.agent_config_file.get(agent_class, None)
366
384
  else:
367
385
  config_file = None
368
-
386
+
369
387
  if issubclass(agent_class, InstitutionAgent):
370
- institution_params.append((agent_class, agent_count_i, memory_config_func_i, config_file))
388
+ institution_params.append(
389
+ (agent_class, agent_count_i, memory_config_func_i, config_file)
390
+ )
371
391
  else:
372
- citizen_params.append((agent_class, agent_count_i, memory_config_func_i, config_file))
392
+ citizen_params.append(
393
+ (agent_class, agent_count_i, memory_config_func_i, config_file)
394
+ )
373
395
 
374
396
  # 处理机构智能体组
375
397
  if institution_params:
376
398
  total_institution_count = sum(p[1] for p in institution_params)
377
- num_institution_groups = (total_institution_count + group_size - 1) // group_size
378
-
399
+ num_institution_groups = (
400
+ total_institution_count + group_size - 1
401
+ ) // group_size
402
+
379
403
  for k in range(num_institution_groups):
380
404
  start_idx = k * group_size
381
405
  remaining = total_institution_count - start_idx
382
406
  number_of_agents = min(remaining, group_size)
383
-
407
+
384
408
  agent_classes = []
385
409
  agent_counts = []
386
410
  memory_config_funcs = []
387
411
  config_files = []
388
-
412
+
389
413
  # 分配每种类型的机构智能体到当前组
390
414
  curr_start = start_idx
391
415
  for agent_class, count, mem_func, conf_file in institution_params:
@@ -395,30 +419,32 @@ class AgentSimulation:
395
419
  memory_config_funcs.append(mem_func)
396
420
  config_files.append(conf_file)
397
421
  curr_start = max(0, curr_start - count)
398
-
399
- group_creation_params.append((
400
- agent_classes,
401
- agent_counts,
402
- memory_config_funcs,
403
- f"InstitutionGroup_{k}",
404
- config_files
405
- ))
422
+
423
+ group_creation_params.append(
424
+ (
425
+ agent_classes,
426
+ agent_counts,
427
+ memory_config_funcs,
428
+ f"InstitutionGroup_{k}",
429
+ config_files,
430
+ )
431
+ )
406
432
 
407
433
  # 处理普通智能体组
408
434
  if citizen_params:
409
435
  total_citizen_count = sum(p[1] for p in citizen_params)
410
436
  num_citizen_groups = (total_citizen_count + group_size - 1) // group_size
411
-
437
+
412
438
  for k in range(num_citizen_groups):
413
439
  start_idx = k * group_size
414
440
  remaining = total_citizen_count - start_idx
415
441
  number_of_agents = min(remaining, group_size)
416
-
442
+
417
443
  agent_classes = []
418
444
  agent_counts = []
419
445
  memory_config_funcs = []
420
446
  config_files = []
421
-
447
+
422
448
  # 分配每种类型的普通智能体到当前组
423
449
  curr_start = start_idx
424
450
  for agent_class, count, mem_func, conf_file in citizen_params:
@@ -428,14 +454,16 @@ class AgentSimulation:
428
454
  memory_config_funcs.append(mem_func)
429
455
  config_files.append(conf_file)
430
456
  curr_start = max(0, curr_start - count)
431
-
432
- group_creation_params.append((
433
- agent_classes,
434
- agent_counts,
435
- memory_config_funcs,
436
- f"CitizenGroup_{k}",
437
- config_files
438
- ))
457
+
458
+ group_creation_params.append(
459
+ (
460
+ agent_classes,
461
+ agent_counts,
462
+ memory_config_funcs,
463
+ f"CitizenGroup_{k}",
464
+ config_files,
465
+ )
466
+ )
439
467
 
440
468
  # 初始化mlflow连接
441
469
  _mlflow_config = self.config.get("metric_request", {}).get("mlflow")
@@ -463,7 +491,13 @@ class AgentSimulation:
463
491
  self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
464
492
 
465
493
  creation_tasks = []
466
- for i, (agent_class, number_of_agents, memory_config_function_group, group_name, config_file) in enumerate(group_creation_params):
494
+ for i, (
495
+ agent_class,
496
+ number_of_agents,
497
+ memory_config_function_group,
498
+ group_name,
499
+ config_file,
500
+ ) in enumerate(group_creation_params):
467
501
  # 直接创建异步任务
468
502
  group = AgentGroup.remote(
469
503
  agent_class,
@@ -489,7 +523,9 @@ class AgentSimulation:
489
523
  group_agent_uuids = ray.get(group.get_agent_uuids.remote())
490
524
  for agent_uuid in group_agent_uuids:
491
525
  self._agent_uuid2group[agent_uuid] = group
492
- self._user_chat_topics[agent_uuid] = f"exps/{self.exp_id}/agents/{agent_uuid}/user-chat"
526
+ self._user_chat_topics[agent_uuid] = (
527
+ f"exps/{self.exp_id}/agents/{agent_uuid}/user-chat"
528
+ )
493
529
  self._user_survey_topics[agent_uuid] = (
494
530
  f"exps/{self.exp_id}/agents/{agent_uuid}/user-survey"
495
531
  )
@@ -511,23 +547,26 @@ class AgentSimulation:
511
547
  for group in self._groups.values():
512
548
  gather_tasks.append(group.gather.remote(content))
513
549
  return await asyncio.gather(*gather_tasks)
514
-
515
- async def filter(self,
516
- types: Optional[list[Type[Agent]]] = None,
517
- keys: Optional[list[str]] = None,
518
- values: Optional[list[Any]] = None) -> list[str]:
550
+
551
+ async def filter(
552
+ self,
553
+ types: Optional[list[Type[Agent]]] = None,
554
+ keys: Optional[list[str]] = None,
555
+ values: Optional[list[Any]] = None,
556
+ ) -> list[str]:
519
557
  """过滤出指定类型的智能体"""
520
558
  if not types and not keys and not values:
521
559
  return self._agent_uuids
522
560
  group_to_filter = []
523
- for t in types:
524
- if t in self._type2group:
525
- group_to_filter.extend(self._type2group[t])
526
- else:
527
- raise ValueError(f"type {t} not found in simulation")
561
+ if types is not None:
562
+ for t in types:
563
+ if t in self._type2group:
564
+ group_to_filter.extend(self._type2group[t])
565
+ else:
566
+ raise ValueError(f"type {t} not found in simulation")
528
567
  filtered_uuids = []
529
568
  if keys:
530
- if len(keys) != len(values):
569
+ if values is None or len(keys) != len(values):
531
570
  raise ValueError("the length of key and value does not match")
532
571
  for group in group_to_filter:
533
572
  filtered_uuids.extend(await group.filter.remote(types, keys, values))
@@ -0,0 +1,9 @@
1
+ from .tool import ExportMlflowMetrics, GetMap, SencePOI, Tool, UpdateWithSimulator
2
+
3
+ __all__ = [
4
+ "SencePOI",
5
+ "Tool",
6
+ "ExportMlflowMetrics",
7
+ "GetMap",
8
+ "UpdateWithSimulator",
9
+ ]
@@ -1,10 +1,16 @@
1
- from typing import Any, Optional, Union
1
+ import asyncio
2
+ import time
2
3
  from collections import defaultdict
3
4
  from collections.abc import Callable, Sequence
5
+ from typing import Any, Optional, Union
6
+
4
7
  from mlflow.entities import Metric
5
- import time
6
8
 
7
- from ..environment import LEVEL_ONE_PRE, POI_TYPE_DICT
9
+ from ..agent import Agent
10
+ from ..environment import (LEVEL_ONE_PRE, POI_TYPE_DICT, AoiService,
11
+ PersonService)
12
+ from ..utils.decorators import lock_decorator
13
+ from ..workflow import Block
8
14
 
9
15
 
10
16
  class Tool:
@@ -34,31 +40,23 @@ class Tool:
34
40
  raise NotImplementedError
35
41
 
36
42
  @property
37
- def agent(self):
43
+ def agent(self) -> Agent:
38
44
  instance = self._instance # type:ignore
39
- if not isinstance(instance, self._get_agent_class()):
45
+ if not isinstance(instance, Agent):
40
46
  raise RuntimeError(
41
47
  f"Tool bind to object `{type(instance).__name__}`, not an `Agent` object!"
42
48
  )
43
49
  return instance
44
50
 
45
51
  @property
46
- def block(self):
52
+ def block(self) -> Block:
47
53
  instance = self._instance # type:ignore
48
- if not isinstance(instance, self._get_block_class()):
54
+ if not isinstance(instance, Block):
49
55
  raise RuntimeError(
50
56
  f"Tool bind to object `{type(instance).__name__}`, not an `Block` object!"
51
57
  )
52
58
  return instance
53
59
 
54
- def _get_agent_class(self):
55
- from ..agent import Agent
56
- return Agent
57
-
58
- def _get_block_class(self):
59
- from ..workflow import Block
60
- return Block
61
-
62
60
 
63
61
  class GetMap(Tool):
64
62
  """Retrieve the map from the simulator. Can be bound only to an `Agent` instance."""
@@ -140,7 +138,7 @@ class SencePOI(Tool):
140
138
 
141
139
  class UpdateWithSimulator(Tool):
142
140
  def __init__(self) -> None:
143
- pass
141
+ self._lock = asyncio.Lock()
144
142
 
145
143
  async def _update_motion_with_sim(
146
144
  self,
@@ -164,6 +162,7 @@ class UpdateWithSimulator(Tool):
164
162
  except KeyError as e:
165
163
  continue
166
164
 
165
+ @lock_decorator
167
166
  async def __call__(
168
167
  self,
169
168
  ):
@@ -173,8 +172,9 @@ class UpdateWithSimulator(Tool):
173
172
 
174
173
  class ResetAgentPosition(Tool):
175
174
  def __init__(self) -> None:
176
- pass
175
+ self._lock = asyncio.Lock()
177
176
 
177
+ @lock_decorator
178
178
  async def __call__(
179
179
  self,
180
180
  aoi_id: Optional[int] = None,
@@ -198,7 +198,9 @@ class ExportMlflowMetrics(Tool):
198
198
  self._log_batch_size = log_batch_size
199
199
  # TODO: support other log types
200
200
  self.metric_log_cache: dict[str, list[Metric]] = defaultdict(list)
201
+ self._lock = asyncio.Lock()
201
202
 
203
+ @lock_decorator
202
204
  async def __call__(
203
205
  self,
204
206
  metric: Union[Sequence[Union[Metric, dict]], Union[Metric, dict]],
@@ -231,6 +233,7 @@ class ExportMlflowMetrics(Tool):
231
233
  if clear_cache:
232
234
  await self._clear_cache()
233
235
 
236
+ @lock_decorator
234
237
  async def _clear_cache(
235
238
  self,
236
239
  ):
@@ -7,14 +7,9 @@ This module contains classes for creating blocks and running workflows.
7
7
  from .block import (Block, log_and_check, log_and_check_with_memory,
8
8
  trigger_class)
9
9
  from .prompt import FormatPrompt
10
- from .tool import ExportMlflowMetrics, GetMap, SencePOI, Tool
11
10
  from .trigger import EventTrigger, MemoryChangeTrigger, TimeTrigger
12
11
 
13
12
  __all__ = [
14
- "SencePOI",
15
- "Tool",
16
- "ExportMlflowMetrics",
17
- "GetMap",
18
13
  "MemoryChangeTrigger",
19
14
  "TimeTrigger",
20
15
  "EventTrigger",
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
+
2
3
  import asyncio
3
4
  import functools
4
5
  import inspect
5
- from collections.abc import Awaitable, Callable, Coroutine
6
6
  import json
7
- from typing import Any, List, Optional, Union
7
+ from collections.abc import Awaitable, Callable, Coroutine
8
+ from typing import Any, Optional, Union
8
9
 
9
10
  from pyparsing import Dict
10
11
 
@@ -143,7 +144,7 @@ def trigger_class():
143
144
 
144
145
  # Define a Block, similar to a layer in PyTorch
145
146
  class Block:
146
- configurable_fields: List[str] = []
147
+ configurable_fields: list[str] = []
147
148
  default_values: dict[str, Any] = {}
148
149
 
149
150
  def __init__(
@@ -164,22 +165,23 @@ class Block:
164
165
  trigger.initialize() # 立即初始化trigger
165
166
  self.trigger = trigger
166
167
 
167
- def export_config(self) -> Dict[str, Optional[str]]:
168
+ def export_config(self) -> dict[str, Optional[str]]:
168
169
  return {
169
170
  field: self.default_values.get(field, "default_value")
170
171
  for field in self.configurable_fields
171
172
  }
172
173
 
173
174
  @classmethod
174
- def export_class_config(cls) -> Dict[str, str]:
175
+ def export_class_config(cls) -> dict[str, str]:
175
176
  return {
176
177
  field: cls.default_values.get(field, "default_value")
177
178
  for field in cls.configurable_fields
178
179
  }
179
180
 
180
181
  @classmethod
181
- def import_config(cls, config: Dict[str, str]) -> "Block":
182
+ def import_config(cls, config: dict[str, Union[str, dict]]) -> Block:
182
183
  instance = cls(name=config["name"])
184
+ assert isinstance(config["config"], dict)
183
185
  for field, value in config["config"].items():
184
186
  if field in cls.configurable_fields:
185
187
  setattr(instance, field, value)
@@ -190,8 +192,8 @@ class Block:
190
192
  setattr(instance, child_block.name.lower(), child_block)
191
193
 
192
194
  return instance
193
-
194
- def load_from_config(self, config: Dict[str, List[Dict]]) -> None:
195
+
196
+ def load_from_config(self, config: dict[str, list[Dict]]) -> None:
195
197
  """
196
198
  使用配置更新当前Block实例的参数,并递归更新子Block。
197
199
  """
@@ -201,8 +203,8 @@ class Block:
201
203
  if config["config"][field] != "default_value":
202
204
  setattr(self, field, config["config"][field])
203
205
 
204
- def build_or_update_block(block_data: Dict) -> Block:
205
- block_name = block_data["name"].lower()
206
+ def build_or_update_block(block_data: dict) -> Block:
207
+ block_name = block_data["name"].lower() # type:ignore
206
208
  existing_block = getattr(self, block_name, None)
207
209
 
208
210
  if existing_block:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pycityagent
3
- Version: 2.0.0a48
3
+ Version: 2.0.0a50
4
4
  Summary: LLM-based city environment agent building library
5
5
  Author-email: Yuwei Yan <pinkgranite86@gmail.com>, Junbo Yan <yanjb20thu@gmali.com>, Jun Zhang <zhangjun990222@gmali.com>
6
6
  License: MIT License
@@ -50,7 +50,6 @@ Requires-Dist: requests>=2.32.3
50
50
  Requires-Dist: Shapely>=2.0.6
51
51
  Requires-Dist: PyYAML>=6.0.2
52
52
  Requires-Dist: zhipuai>=2.1.5.20230904
53
- Requires-Dist: gradio>=5.7.1
54
53
  Requires-Dist: mosstool>=1.3.0
55
54
  Requires-Dist: ray>=2.40.0
56
55
  Requires-Dist: aiomqtt>=2.3.0