pycityagent 2.0.0a15__py3-none-any.whl → 2.0.0a17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pycityagent/agent.py CHANGED
@@ -9,7 +9,7 @@ from enum import Enum
9
9
  import logging
10
10
  import random
11
11
  import uuid
12
- from typing import Dict, List, Optional
12
+ from typing import Dict, List, Optional,Any
13
13
 
14
14
  import fastavro
15
15
 
@@ -80,6 +80,7 @@ class Agent(ABC):
80
80
  self._simulator = simulator
81
81
  self._memory = memory
82
82
  self._exp_id = -1
83
+ self._agent_id = -1
83
84
  self._has_bound_to_simulator = False
84
85
  self._has_bound_to_economy = False
85
86
  self._blocked = False
@@ -224,8 +225,8 @@ class Agent(ABC):
224
225
  return
225
226
  response_to_avro = [{
226
227
  "id": self._uuid,
227
- "day": await self._simulator.get_simulator_day(),
228
- "t": await self._simulator.get_simulator_second_from_start_of_day(),
228
+ "day": await self.simulator.get_simulator_day(),
229
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
229
230
  "survey_id": survey["id"],
230
231
  "result": survey_response,
231
232
  "created_at": int(datetime.now().timestamp() * 1000),
@@ -273,8 +274,8 @@ class Agent(ABC):
273
274
  async def _process_interview(self, payload: dict):
274
275
  auros = [{
275
276
  "id": self._uuid,
276
- "day": await self._simulator.get_simulator_day(),
277
- "t": await self._simulator.get_simulator_second_from_start_of_day(),
277
+ "day": await self.simulator.get_simulator_day(),
278
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
278
279
  "type": 2,
279
280
  "speaker": "user",
280
281
  "content": payload["content"],
@@ -284,8 +285,8 @@ class Agent(ABC):
284
285
  response = await self.generate_user_chat_response(question)
285
286
  auros.append({
286
287
  "id": self._uuid,
287
- "day": await self._simulator.get_simulator_day(),
288
- "t": await self._simulator.get_simulator_second_from_start_of_day(),
288
+ "day": await self.simulator.get_simulator_day(),
289
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
289
290
  "type": 2,
290
291
  "speaker": "",
291
292
  "content": response,
@@ -297,7 +298,9 @@ class Agent(ABC):
297
298
  fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
298
299
 
299
300
  async def process_agent_chat_response(self, payload: dict) -> str:
300
- logger.info(f"Agent {self._uuid} received agent chat response: {payload}")
301
+ resp = f"Agent {self._uuid} received agent chat response: {payload}"
302
+ logger.info(resp)
303
+ return resp
301
304
 
302
305
  async def _process_agent_chat(self, payload: dict):
303
306
  auros = [{
@@ -334,7 +337,7 @@ class Agent(ABC):
334
337
  logger.info(f"Agent {self._uuid} received user survey message: {payload}")
335
338
  asyncio.create_task(self._process_survey(payload["data"]))
336
339
 
337
- async def handle_gather_message(self, payload: str):
340
+ async def handle_gather_message(self, payload: Any):
338
341
  raise NotImplementedError
339
342
 
340
343
  # MQTT send message
@@ -357,14 +360,14 @@ class Agent(ABC):
357
360
  "from": self._uuid,
358
361
  "content": content,
359
362
  "timestamp": int(datetime.now().timestamp() * 1000),
360
- "day": await self._simulator.get_simulator_day(),
361
- "t": await self._simulator.get_simulator_second_from_start_of_day(),
363
+ "day": await self.simulator.get_simulator_day(),
364
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
362
365
  }
363
366
  await self._send_message(to_agent_uuid, payload, "agent-chat")
364
367
  auros = [{
365
368
  "id": self._uuid,
366
- "day": await self._simulator.get_simulator_day(),
367
- "t": await self._simulator.get_simulator_second_from_start_of_day(),
369
+ "day": await self.simulator.get_simulator_day(),
370
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
368
371
  "type": 1,
369
372
  "speaker": self._uuid,
370
373
  "content": content,
@@ -440,8 +443,8 @@ class CitizenAgent(Agent):
440
443
  "pedestrian_attribute",
441
444
  "bike_attribute",
442
445
  }
443
- simulator = self._simulator
444
- memory = self._memory
446
+ simulator = self.simulator
447
+ memory = self.memory
445
448
  person_id = await memory.get("id")
446
449
  # ATTENTION:模拟器分配的id从0开始
447
450
  if person_id >= 0:
@@ -478,11 +481,11 @@ class CitizenAgent(Agent):
478
481
  await self._economy_client.remove_agents([self._agent_id])
479
482
  except:
480
483
  pass
481
- person_id = await self._memory.get("id")
484
+ person_id = await self.memory.get("id")
482
485
  await self._economy_client.add_agents(
483
486
  {
484
487
  "id": person_id,
485
- "currency": await self._memory.get("currency"),
488
+ "currency": await self.memory.get("currency"),
486
489
  }
487
490
  )
488
491
  self._has_bound_to_economy = True
@@ -543,62 +546,63 @@ class InstitutionAgent(Agent):
543
546
  # TODO: More general id generation
544
547
  _id = random.randint(100000, 999999)
545
548
  self._agent_id = _id
546
- await self._memory.update("id", _id, protect_llm_read_only_fields=False)
549
+ await self.memory.update("id", _id, protect_llm_read_only_fields=False)
547
550
  try:
548
551
  await self._economy_client.remove_orgs([self._agent_id])
549
552
  except:
550
553
  pass
551
554
  try:
552
- id = await self._memory.get("id")
553
- type = await self._memory.get("type")
555
+ _memory = self.memory
556
+ _id = await _memory.get("id")
557
+ _type = await _memory.get("type")
554
558
  try:
555
- nominal_gdp = await self._memory.get("nominal_gdp")
559
+ nominal_gdp = await _memory.get("nominal_gdp")
556
560
  except:
557
561
  nominal_gdp = []
558
562
  try:
559
- real_gdp = await self._memory.get("real_gdp")
563
+ real_gdp = await _memory.get("real_gdp")
560
564
  except:
561
565
  real_gdp = []
562
566
  try:
563
- unemployment = await self._memory.get("unemployment")
567
+ unemployment = await _memory.get("unemployment")
564
568
  except:
565
569
  unemployment = []
566
570
  try:
567
- wages = await self._memory.get("wages")
571
+ wages = await _memory.get("wages")
568
572
  except:
569
573
  wages = []
570
574
  try:
571
- prices = await self._memory.get("prices")
575
+ prices = await _memory.get("prices")
572
576
  except:
573
577
  prices = []
574
578
  try:
575
- inventory = await self._memory.get("inventory")
579
+ inventory = await _memory.get("inventory")
576
580
  except:
577
581
  inventory = 0
578
582
  try:
579
- price = await self._memory.get("price")
583
+ price = await _memory.get("price")
580
584
  except:
581
585
  price = 0
582
586
  try:
583
- currency = await self._memory.get("currency")
587
+ currency = await _memory.get("currency")
584
588
  except:
585
589
  currency = 0.0
586
590
  try:
587
- interest_rate = await self._memory.get("interest_rate")
591
+ interest_rate = await _memory.get("interest_rate")
588
592
  except:
589
593
  interest_rate = 0.0
590
594
  try:
591
- bracket_cutoffs = await self._memory.get("bracket_cutoffs")
595
+ bracket_cutoffs = await _memory.get("bracket_cutoffs")
592
596
  except:
593
597
  bracket_cutoffs = []
594
598
  try:
595
- bracket_rates = await self._memory.get("bracket_rates")
599
+ bracket_rates = await _memory.get("bracket_rates")
596
600
  except:
597
601
  bracket_rates = []
598
602
  await self._economy_client.add_orgs(
599
603
  {
600
- "id": id,
601
- "type": type,
604
+ "id": _id,
605
+ "type": _type,
602
606
  "nominal_gdp": nominal_gdp,
603
607
  "real_gdp": real_gdp,
604
608
  "unemployment": unemployment,
@@ -307,3 +307,12 @@ class EconomyClient:
307
307
  )
308
308
  # current agent ids and org ids
309
309
  return (list(response.agent_ids), list(response.org_ids))
310
+
311
+ async def get_org_entity_ids(self, org_type: economyv2.OrgType)->list[int]:
312
+ request = org_service.GetOrgEntityIdsRequest(
313
+ type=org_type,
314
+ )
315
+ response: org_service.GetOrgEntityIdsResponse = (
316
+ await self._aio_stub.GetOrgEntityIds(request)
317
+ )
318
+ return list(response.org_ids)
@@ -20,7 +20,7 @@ logger = logging.getLogger("pycityagent")
20
20
 
21
21
  @ray.remote
22
22
  class AgentGroup:
23
- def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int = logging.WARNING):
23
+ def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int):
24
24
  logger.setLevel(logging_level)
25
25
  self._uuid = str(uuid.uuid4())
26
26
  self.agents = agents
@@ -64,20 +64,20 @@ class AgentGroup:
64
64
  self.economy_client = None
65
65
 
66
66
  for agent in self.agents:
67
- agent.set_exp_id(self.exp_id)
67
+ agent.set_exp_id(self.exp_id) # type: ignore
68
68
  agent.set_llm_client(self.llm)
69
69
  agent.set_simulator(self.simulator)
70
70
  if self.economy_client is not None:
71
71
  agent.set_economy_client(self.economy_client)
72
72
  agent.set_messager(self.messager)
73
73
  if self.enable_avro:
74
- agent.set_avro_file(self.avro_file)
74
+ agent.set_avro_file(self.avro_file) # type: ignore
75
75
 
76
76
  async def init_agents(self):
77
77
  logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
78
78
  logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
79
79
  for agent in self.agents:
80
- await agent.bind_to_simulator()
80
+ await agent.bind_to_simulator() # type: ignore
81
81
  self.id2agent = {agent._uuid: agent for agent in self.agents}
82
82
  logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
83
83
  await self.messager.connect()
@@ -2,19 +2,22 @@ import asyncio
2
2
  import json
3
3
  import logging
4
4
  import os
5
- from pathlib import Path
5
+ import random
6
6
  import uuid
7
+ from collections.abc import Sequence
8
+ from concurrent.futures import ThreadPoolExecutor
7
9
  from datetime import datetime, timezone
8
- import random
9
- from typing import Dict, List, Optional, Callable, Union,Any
10
- from mosstool.map._map_util.const import AOI_START_ID
10
+ from pathlib import Path
11
+ from typing import Any, Callable, Dict, List, Optional, Union
12
+
11
13
  import pycityproto.city.economy.v2.economy_pb2 as economyv2
14
+ import yaml
15
+ from mosstool.map._map_util.const import AOI_START_ID
16
+
12
17
  from pycityagent.environment.simulator import Simulator
13
18
  from pycityagent.memory.memory import Memory
14
19
  from pycityagent.message.messager import Messager
15
20
  from pycityagent.survey import Survey
16
- import yaml
17
- from concurrent.futures import ThreadPoolExecutor
18
21
 
19
22
  from ..agent import Agent, InstitutionAgent
20
23
  from .agentgroup import AgentGroup
@@ -31,7 +34,7 @@ class AgentSimulation:
31
34
  config: dict,
32
35
  agent_prefix: str = "agent_",
33
36
  exp_name: str = "default_experiment",
34
- logging_level: int = logging.WARNING
37
+ logging_level: int = logging.WARNING,
35
38
  ):
36
39
  """
37
40
  Args:
@@ -50,8 +53,8 @@ class AgentSimulation:
50
53
  self._simulator = Simulator(config["simulator_request"])
51
54
  self.agent_prefix = agent_prefix
52
55
  self._agents: Dict[uuid.UUID, Agent] = {}
53
- self._groups: Dict[str, AgentGroup] = {}
54
- self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {}
56
+ self._groups: Dict[str, AgentGroup] = {} # type:ignore
57
+ self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {} # type:ignore
55
58
  self._agent_uuids: List[uuid.UUID] = []
56
59
  self._user_chat_topics: Dict[uuid.UUID, str] = {}
57
60
  self._user_survey_topics: Dict[uuid.UUID, str] = {}
@@ -89,33 +92,44 @@ class AgentSimulation:
89
92
  "cur_t": 0.0,
90
93
  "config": json.dumps(config),
91
94
  "error": "",
92
- "created_at": datetime.now(timezone.utc).isoformat()
95
+ "created_at": datetime.now(timezone.utc).isoformat(),
93
96
  }
94
-
97
+
95
98
  # 创建异步任务保存实验信息
96
99
  self._exp_info_file = self._avro_path / "experiment_info.yaml"
97
- with open(self._exp_info_file, 'w') as f:
100
+ with open(self._exp_info_file, "w") as f:
98
101
  yaml.dump(self._exp_info, f)
99
102
 
100
103
  @property
101
104
  def agents(self):
102
105
  return self._agents
103
-
106
+
104
107
  @property
105
108
  def groups(self):
106
109
  return self._groups
107
-
110
+
108
111
  @property
109
112
  def agent_uuids(self):
110
113
  return self._agent_uuids
111
-
114
+
112
115
  @property
113
116
  def agent_uuid2group(self):
114
117
  return self._agent_uuid2group
115
-
116
- def create_remote_group(self, group_name: str, agents: list[Agent], config: dict, exp_id: str, enable_avro: bool, avro_path: Path, logging_level: int = logging.WARNING):
118
+
119
+ def create_remote_group(
120
+ self,
121
+ group_name: str,
122
+ agents: list[Agent],
123
+ config: dict,
124
+ exp_id: str,
125
+ enable_avro: bool,
126
+ avro_path: Path,
127
+ logging_level: int = logging.WARNING,
128
+ ):
117
129
  """创建远程组"""
118
- group = AgentGroup.remote(agents, config, exp_id, enable_avro, avro_path, logging_level)
130
+ group = AgentGroup.remote(
131
+ agents, config, exp_id, enable_avro, avro_path, logging_level
132
+ )
119
133
  return group_name, group, agents
120
134
 
121
135
  async def init_agents(
@@ -164,7 +178,7 @@ class AgentSimulation:
164
178
  # 使用线程池并行创建 AgentGroup
165
179
  group_creation_params = []
166
180
  class_init_index = 0
167
-
181
+
168
182
  # 首先收集所有需要创建的组的参数
169
183
  for i in range(len(self.agent_class)):
170
184
  agent_class = self.agent_class[i]
@@ -175,9 +189,7 @@ class AgentSimulation:
175
189
 
176
190
  # 获取Memory配置
177
191
  extra_attributes, profile, base = memory_config_func_i()
178
- memory = Memory(
179
- config=extra_attributes, profile=profile, base=base
180
- )
192
+ memory = Memory(config=extra_attributes, profile=profile, base=base)
181
193
 
182
194
  # 创建智能体时传入Memory配置
183
195
  agent = agent_class(
@@ -190,34 +202,36 @@ class AgentSimulation:
190
202
 
191
203
  # 计算需要的组数,向上取整以处理不足一组的情况
192
204
  num_group = (agent_count_i + group_size - 1) // group_size
193
-
205
+
194
206
  for k in range(num_group):
195
207
  start_idx = class_init_index + k * group_size
196
208
  end_idx = min(
197
209
  class_init_index + (k + 1) * group_size, # 修正了索引计算
198
- class_init_index + agent_count_i
210
+ class_init_index + agent_count_i,
199
211
  )
200
-
212
+
201
213
  agents = list(self._agents.values())[start_idx:end_idx]
202
214
  group_name = f"AgentType_{i}_Group_{k}"
203
-
215
+
204
216
  # 收集创建参数
205
- group_creation_params.append((
206
- group_name,
207
- agents
208
- ))
209
-
217
+ group_creation_params.append((group_name, agents))
218
+
210
219
  class_init_index += agent_count_i
211
220
 
212
221
  # 收集所有创建组的参数
213
222
  creation_tasks = []
214
223
  for group_name, agents in group_creation_params:
215
224
  # 直接创建异步任务
216
- group = AgentGroup.remote(agents, self.config, self.exp_id,
217
- self._enable_avro, self._avro_path,
218
- self.logging_level)
225
+ group = AgentGroup.remote(
226
+ agents,
227
+ self.config,
228
+ self.exp_id,
229
+ self._enable_avro,
230
+ self._avro_path,
231
+ self.logging_level,
232
+ )
219
233
  creation_tasks.append((group_name, group, agents))
220
-
234
+
221
235
  # 更新数据结构
222
236
  for group_name, group, agents in creation_tasks:
223
237
  self._groups[group_name] = group
@@ -233,7 +247,9 @@ class AgentSimulation:
233
247
  # 设置用户主题
234
248
  for uuid, agent in self._agents.items():
235
249
  self._user_chat_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-chat"
236
- self._user_survey_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-survey"
250
+ self._user_survey_topics[uuid] = (
251
+ f"exps/{self.exp_id}/agents/{uuid}/user-survey"
252
+ )
237
253
 
238
254
  async def gather(self, content: str):
239
255
  """收集智能体的特定信息"""
@@ -250,7 +266,18 @@ class AgentSimulation:
250
266
  def default_memory_config_institution(self):
251
267
  """默认的Memory配置函数"""
252
268
  EXTRA_ATTRIBUTES = {
253
- "type": (int, random.choice([economyv2.ORG_TYPE_BANK, economyv2.ORG_TYPE_GOVERNMENT, economyv2.ORG_TYPE_FIRM, economyv2.ORG_TYPE_NBS, economyv2.ORG_TYPE_UNSPECIFIED])),
269
+ "type": (
270
+ int,
271
+ random.choice(
272
+ [
273
+ economyv2.ORG_TYPE_BANK,
274
+ economyv2.ORG_TYPE_GOVERNMENT,
275
+ economyv2.ORG_TYPE_FIRM,
276
+ economyv2.ORG_TYPE_NBS,
277
+ economyv2.ORG_TYPE_UNSPECIFIED,
278
+ ]
279
+ ),
280
+ ),
254
281
  "nominal_gdp": (list, [], True),
255
282
  "real_gdp": (list, [], True),
256
283
  "unemployment": (list, [], True),
@@ -364,29 +391,35 @@ class AgentSimulation:
364
391
  }
365
392
 
366
393
  return EXTRA_ATTRIBUTES, PROFILE, BASE
367
-
368
- async def send_survey(self, survey: Survey, agent_uuids: Optional[List[uuid.UUID]] = None):
394
+
395
+ async def send_survey(
396
+ self, survey: Survey, agent_uuids: Optional[List[uuid.UUID]] = None
397
+ ):
369
398
  """发送问卷"""
370
- survey = survey.to_dict()
399
+ survey_dict = survey.to_dict()
371
400
  if agent_uuids is None:
372
401
  agent_uuids = self._agent_uuids
373
402
  payload = {
374
403
  "from": "none",
375
- "survey_id": survey["id"],
404
+ "survey_id": survey_dict["id"],
376
405
  "timestamp": int(datetime.now().timestamp() * 1000),
377
- "data": survey,
406
+ "data": survey_dict,
378
407
  }
379
408
  for uuid in agent_uuids:
380
409
  topic = self._user_survey_topics[uuid]
381
410
  await self._messager.send_message(topic, payload)
382
411
 
383
- async def send_interview_message(self, content: str, agent_uuids: Union[uuid.UUID, List[uuid.UUID]]):
412
+ async def send_interview_message(
413
+ self, content: str, agent_uuids: Union[uuid.UUID, List[uuid.UUID]]
414
+ ):
384
415
  """发送面试消息"""
385
416
  payload = {
386
417
  "from": "none",
387
418
  "content": content,
388
419
  "timestamp": int(datetime.now().timestamp() * 1000),
389
420
  }
421
+ if not isinstance(agent_uuids, Sequence):
422
+ agent_uuids = [agent_uuids]
390
423
  for uuid in agent_uuids:
391
424
  topic = self._user_chat_topics[uuid]
392
425
  await self._messager.send_message(topic, payload)
@@ -405,7 +438,7 @@ class AgentSimulation:
405
438
  async def _save_exp_info(self) -> None:
406
439
  """异步保存实验信息到YAML文件"""
407
440
  try:
408
- with open(self._exp_info_file, 'w') as f:
441
+ with open(self._exp_info_file, "w") as f:
409
442
  yaml.dump(self._exp_info, f)
410
443
  except Exception as e:
411
444
  logger.error(f"保存实验信息失败: {str(e)}")
@@ -418,20 +451,22 @@ class AgentSimulation:
418
451
 
419
452
  async def _monitor_exp_status(self, stop_event: asyncio.Event):
420
453
  """监控实验状态并更新
421
-
454
+
422
455
  Args:
423
456
  stop_event: 用于通知监控任务停止的事件
424
457
  """
425
458
  try:
426
- while not stop_event.is_set():
459
+ while not stop_event.is_set():
427
460
  # 更新实验状态
428
461
  # 假设所有group的cur_day和cur_t是同步的,取第一个即可
429
462
  self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
430
- self._exp_info["cur_t"] = await self._simulator.get_simulator_second_from_start_of_day()
463
+ self._exp_info["cur_t"] = (
464
+ await self._simulator.get_simulator_second_from_start_of_day()
465
+ )
431
466
  await self._save_exp_info()
432
-
467
+
433
468
  await asyncio.sleep(1) # 避免过于频繁的更新
434
- except asyncio.CancelError:
469
+ except asyncio.CancelledError:
435
470
  # 正常取消,不需要特殊处理
436
471
  pass
437
472
  except Exception as e:
@@ -446,12 +481,12 @@ class AgentSimulation:
446
481
  try:
447
482
  self._exp_info["num_day"] += day
448
483
  await self._update_exp_status(1) # 更新状态为运行中
449
-
484
+
450
485
  # 创建停止事件
451
486
  stop_event = asyncio.Event()
452
487
  # 创建监控任务
453
488
  monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
454
-
489
+
455
490
  try:
456
491
  tasks = []
457
492
  for group in self._groups.values():
@@ -459,13 +494,13 @@ class AgentSimulation:
459
494
 
460
495
  # 等待所有group运行完成
461
496
  await asyncio.gather(*tasks)
462
-
497
+
463
498
  finally:
464
499
  # 设置停止事件
465
500
  stop_event.set()
466
501
  # 等待监控任务结束
467
502
  await monitor_task
468
-
503
+
469
504
  # 运行成功后更新状态
470
505
  await self._update_exp_status(2)
471
506
 
@@ -1,17 +1,16 @@
1
- from typing import List, Dict, Optional
2
- from datetime import datetime
3
- import uuid
4
1
  import json
5
- from .models import Survey, Question, QuestionType, Page
2
+ import uuid
3
+ from datetime import datetime
4
+ from typing import Optional
5
+
6
+ from .models import Page, Question, QuestionType, Survey
6
7
 
7
8
 
8
9
  class SurveyManager:
9
10
  def __init__(self):
10
- self._surveys: Dict[str, Survey] = {}
11
+ self._surveys: dict[str, Survey] = {}
11
12
 
12
- def create_survey(
13
- self, title: str, description: str, pages: List[dict]
14
- ) -> Survey:
13
+ def create_survey(self, title: str, description: str, pages: list[dict]) -> Survey:
15
14
  """创建新问卷"""
16
15
  survey_id = uuid.uuid4()
17
16
 
@@ -32,11 +31,8 @@ class SurveyManager:
32
31
  max_rating=q.get("max_rating", 5),
33
32
  )
34
33
  questions.append(question)
35
-
36
- page = Page(
37
- name=page_data["name"],
38
- elements=questions
39
- )
34
+
35
+ page = Page(name=page_data["name"], elements=questions)
40
36
  survey_pages.append(page)
41
37
 
42
38
  survey = Survey(
@@ -53,6 +49,6 @@ class SurveyManager:
53
49
  """获取指定问卷"""
54
50
  return self._surveys.get(survey_id)
55
51
 
56
- def get_all_surveys(self) -> List[Survey]:
52
+ def get_all_surveys(self) -> list[Survey]:
57
53
  """获取所有问卷"""
58
54
  return list(self._surveys.values())
@@ -1,9 +1,9 @@
1
+ import json
2
+ import uuid
1
3
  from dataclasses import dataclass, field
2
- from typing import List, Dict, Optional
3
4
  from datetime import datetime
4
5
  from enum import Enum
5
- import uuid
6
- import json
6
+ from typing import Any
7
7
 
8
8
 
9
9
  class QuestionType(Enum):
@@ -20,19 +20,20 @@ class Question:
20
20
  name: str
21
21
  title: str
22
22
  type: QuestionType
23
- choices: List[str] = field(default_factory=list)
24
- columns: List[str] = field(default_factory=list)
25
- rows: List[str] = field(default_factory=list)
23
+ choices: list[str] = field(default_factory=list)
24
+ columns: list[str] = field(default_factory=list)
25
+ rows: list[str] = field(default_factory=list)
26
+ required: bool = True
26
27
  min_rating: int = 1
27
28
  max_rating: int = 5
28
29
 
29
30
  def to_dict(self) -> dict:
30
- base_dict = {
31
+ base_dict: dict[str, Any] = {
31
32
  "type": self.type.value,
32
33
  "name": self.name,
33
34
  "title": self.title,
34
35
  }
35
-
36
+
36
37
  if self.type in [QuestionType.RADIO, QuestionType.CHECKBOX]:
37
38
  base_dict["choices"] = self.choices
38
39
  elif self.type == QuestionType.MATRIX:
@@ -41,20 +42,17 @@ class Question:
41
42
  elif self.type == QuestionType.RATING:
42
43
  base_dict["min_rating"] = self.min_rating
43
44
  base_dict["max_rating"] = self.max_rating
44
-
45
+
45
46
  return base_dict
46
47
 
47
48
 
48
49
  @dataclass
49
50
  class Page:
50
51
  name: str
51
- elements: List[Question]
52
+ elements: list[Question]
52
53
 
53
54
  def to_dict(self) -> dict:
54
- return {
55
- "name": self.name,
56
- "elements": [q.to_dict() for q in self.elements]
57
- }
55
+ return {"name": self.name, "elements": [q.to_dict() for q in self.elements]}
58
56
 
59
57
 
60
58
  @dataclass
@@ -62,8 +60,8 @@ class Survey:
62
60
  id: uuid.UUID
63
61
  title: str
64
62
  description: str
65
- pages: List[Page]
66
- responses: Dict[str, dict] = field(default_factory=dict)
63
+ pages: list[Page]
64
+ responses: dict[str, dict] = field(default_factory=dict)
67
65
  created_at: datetime = field(default_factory=datetime.now)
68
66
 
69
67
  def to_dict(self) -> dict:
@@ -83,12 +81,12 @@ class Survey:
83
81
  "description": self.description,
84
82
  "pages": [p.to_dict() for p in self.pages],
85
83
  "responses": self.responses,
86
- "created_at": self.created_at.isoformat()
84
+ "created_at": self.created_at.isoformat(),
87
85
  }
88
86
  return json.dumps(survey_dict)
89
87
 
90
88
  @classmethod
91
- def from_json(cls, json_str: str) -> 'Survey':
89
+ def from_json(cls, json_str: str) -> "Survey":
92
90
  """Create a Survey instance from a JSON string"""
93
91
  data = json.loads(json_str)
94
92
  pages = [
@@ -104,17 +102,19 @@ class Survey:
104
102
  columns=q.get("columns", []),
105
103
  rows=q.get("rows", []),
106
104
  min_rating=q.get("min_rating", 1),
107
- max_rating=q.get("max_rating", 5)
108
- ) for q in p["elements"]
109
- ]
110
- ) for p in data["pages"]
105
+ max_rating=q.get("max_rating", 5),
106
+ )
107
+ for q in p["elements"]
108
+ ],
109
+ )
110
+ for p in data["pages"]
111
111
  ]
112
-
112
+
113
113
  return cls(
114
114
  id=uuid.UUID(data["id"]),
115
115
  title=data["title"],
116
116
  description=data["description"],
117
117
  pages=pages,
118
118
  responses=data.get("responses", {}),
119
- created_at=datetime.fromisoformat(data["created_at"])
119
+ created_at=datetime.fromisoformat(data["created_at"]),
120
120
  )
@@ -139,7 +139,7 @@ class UpdateWithSimulator(Tool):
139
139
  if agent._simulator is None:
140
140
  return
141
141
  if not agent._has_bound_to_simulator:
142
- await self._bind_to_simulator()
142
+ await agent._bind_to_simulator() # type: ignore
143
143
  simulator = agent.simulator
144
144
  memory = agent.memory
145
145
  person_id = await memory.get("id")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pycityagent
3
- Version: 2.0.0a15
3
+ Version: 2.0.0a17
4
4
  Summary: LLM-based城市环境agent构建库
5
5
  License: MIT
6
6
  Author: Yuwei Yan
@@ -24,6 +24,7 @@ Requires-Dist: fastavro (>=1.10.0,<2.0.0)
24
24
  Requires-Dist: geojson (==3.1.0)
25
25
  Requires-Dist: gradio (>=5.7.1,<6.0.0)
26
26
  Requires-Dist: grpcio (==1.67.1)
27
+ Requires-Dist: langchain-core (>=0.3.28,<0.4.0)
27
28
  Requires-Dist: matplotlib (==3.8.3)
28
29
  Requires-Dist: mosstool (==1.0.24)
29
30
  Requires-Dist: networkx (==3.2.1)
@@ -33,7 +34,7 @@ Requires-Dist: pandavro (>=1.8.0,<2.0.0)
33
34
  Requires-Dist: poetry (>=1.2.2)
34
35
  Requires-Dist: protobuf (<=4.24.0)
35
36
  Requires-Dist: pycitydata (==1.0.0)
36
- Requires-Dist: pycityproto (==2.0.7)
37
+ Requires-Dist: pycityproto (>=2.1.4,<3.0.0)
37
38
  Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
38
39
  Requires-Dist: ray (>=2.40.0,<3.0.0)
39
40
  Requires-Dist: sidecar (==0.7.0)
@@ -1,7 +1,7 @@
1
1
  pycityagent/__init__.py,sha256=EDxt3Su3lH1IMh9suNw7GeGL7UrXeWiZTw5KWNznDzc,637
2
- pycityagent/agent.py,sha256=fcuKX6FtMzjNP8lVep9pG-9KHzHQwJ8IymJbmLKMfu0,23109
2
+ pycityagent/agent.py,sha256=t9W9sKxtQ0EkMxL78kAjAu-rXigEK6gyLY0IEA4DbnQ,23143
3
3
  pycityagent/economy/__init__.py,sha256=aonY4WHnx-6EGJ4WKrx4S-2jAkYNLtqUA04jp6q8B7w,75
4
- pycityagent/economy/econ_client.py,sha256=wcuNtcpkSijJwNkt2mXw3SshYy4SBy6qbvJ0VQ7Aovo,10854
4
+ pycityagent/economy/econ_client.py,sha256=DE11Ng_NO_foW65A-LxFW0VED-HLrnn4GwUf_Xn-Tlg,11189
5
5
  pycityagent/environment/__init__.py,sha256=awHxlOud-btWbk0FCS4RmGJ13W84oVCkbGfcrhKqihA,240
6
6
  pycityagent/environment/interact/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  pycityagent/environment/interact/interact.py,sha256=ifxPPzuHeqLHIZ_6zvfXMoBOnBsXNIP4bYp7OJ7pnEQ,6588
@@ -46,11 +46,11 @@ pycityagent/memory/utils.py,sha256=wLNlNlZ-AY9VB8kbUIy0UQSYh26FOQABbhmKQkit5o8,8
46
46
  pycityagent/message/__init__.py,sha256=TCjazxqb5DVwbTu1fF0sNvaH_EPXVuj2XQ0p6W-QCLU,55
47
47
  pycityagent/message/messager.py,sha256=W_OVlNGcreHSBf6v-DrEnfNCXExB78ySr0w26MSncfU,2541
48
48
  pycityagent/simulation/__init__.py,sha256=jYaqaNpzM5M_e_ykISS_M-mIyYdzJXJWhgpfBpA6l5k,111
49
- pycityagent/simulation/agentgroup.py,sha256=JwfssUtVrOgSnJCan4jcIcSHLjWBCwYxqOPT-AXA2sE,12514
50
- pycityagent/simulation/simulation.py,sha256=G68P1EJ3JceA3zID2O6AGd_KdhhYy5XVZVUgkfJHypc,18897
49
+ pycityagent/simulation/agentgroup.py,sha256=M19XWJRWyjMAYS0_RIOBQ2C7I1MuVYIaX3DgehGZL2Y,12541
50
+ pycityagent/simulation/simulation.py,sha256=TndrMZSm6qe_wgfv9h6mL9oAiAEHbq6KHBwdwUGG_3k,19261
51
51
  pycityagent/survey/__init__.py,sha256=rxwou8U9KeFSP7rMzXtmtp2fVFZxK4Trzi-psx9LPIs,153
52
- pycityagent/survey/manager.py,sha256=N4-Q8vve4L-PLaFikAgVB4Z8BNyFDEd6WjEk3AyMTgs,1764
53
- pycityagent/survey/models.py,sha256=Z2gHEazQRj0TkTz5qbh4Uy_JrU_FZGWpOLwjN0RoUrY,3547
52
+ pycityagent/survey/manager.py,sha256=S5IkwTdelsdtZETChRcfCEczzwSrry_Fly9MY4s3rbk,1681
53
+ pycityagent/survey/models.py,sha256=YE50UUt5qJ0O_lIUsSY6XFCGUTkJVNu_L1gAhaCJ2fs,3546
54
54
  pycityagent/utils/__init__.py,sha256=xXEMhVfFeOJUXjczaHv9DJqYNp57rc6FibtS7CfrVbA,305
55
55
  pycityagent/utils/avro_schema.py,sha256=DHM3bOo8m0dJf8oSwyOWzVeXrH6OERmzA_a5vS4So4M,4255
56
56
  pycityagent/utils/decorators.py,sha256=Gk3r41hfk6awui40tbwpq3C7wC7jHaRmLRlcJFlLQCE,3160
@@ -62,8 +62,8 @@ pycityagent/utils/survey_util.py,sha256=Be9nptmu2JtesFNemPgORh_2GsN7rcDYGQS9Zfvc
62
62
  pycityagent/workflow/__init__.py,sha256=EyCcjB6LyBim-5iAOPe4m2qfvghEPqu1ZdGfy4KPeZ8,551
63
63
  pycityagent/workflow/block.py,sha256=6EmiRMLdOZC1wMlmLMIjfrp9TuiI7Gw4s3nnXVMbrnw,6031
64
64
  pycityagent/workflow/prompt.py,sha256=tY69nDO8fgYfF_dOA-iceR8pAhkYmCqoox8uRPqEuGY,2956
65
- pycityagent/workflow/tool.py,sha256=zMvz3BV4QBs5TqyQ3ziJxj4pCfL2uqUI3A1FbT1gd3Q,6626
65
+ pycityagent/workflow/tool.py,sha256=_bCluIX8HTC8ZW6a-wrMB3Uhx2yzD8sM8XFDI3vd0MM,6642
66
66
  pycityagent/workflow/trigger.py,sha256=t5X_i0WtL32bipZSsq_E3UUyYYudYLxQUpvxbgClp2s,5683
67
- pycityagent-2.0.0a15.dist-info/METADATA,sha256=8ONKHaTIPOGja6FKJDE9tZ-pLZL2aeXQufB-LWkElR8,7705
68
- pycityagent-2.0.0a15.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
69
- pycityagent-2.0.0a15.dist-info/RECORD,,
67
+ pycityagent-2.0.0a17.dist-info/METADATA,sha256=fxmJlP11NGtkwMyZr8QB1U2_0lgynLasG6z7GOEZCus,7760
68
+ pycityagent-2.0.0a17.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
69
+ pycityagent-2.0.0a17.dist-info/RECORD,,