pycityagent 2.0.0a10__py3-none-any.whl → 2.0.0a12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pycityagent/agent.py CHANGED
@@ -2,12 +2,15 @@
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  import asyncio
5
+ import json
5
6
  from uuid import UUID
6
7
  from copy import deepcopy
7
8
  from datetime import datetime
9
+ import time
8
10
  from enum import Enum
9
11
  import logging
10
12
  import random
13
+ import uuid
11
14
  from typing import Dict, List, Optional
12
15
 
13
16
  from pycityagent.environment.sim.person_service import PersonService
@@ -64,6 +67,7 @@ class Agent(ABC):
64
67
  """
65
68
  self._name = name
66
69
  self._type = type
70
+ self._uuid = uuid.uuid4()
67
71
  self._llm_client = llm_client
68
72
  self._economy_client = economy_client
69
73
  self._messager = messager
@@ -118,76 +122,15 @@ class Agent(ABC):
118
122
  """
119
123
  self._exp_id = exp_id
120
124
 
121
- async def _bind_to_simulator(self):
122
- """
123
- Bind Agent to Simulator
124
-
125
- Args:
126
- person_template (dict, optional): The person template in dict format. Defaults to PersonService.default_dict_person().
127
- """
128
- if self._simulator is None:
129
- logging.warning("Simulator is not set")
130
- return
131
- if not self._has_bound_to_simulator:
132
- FROM_MEMORY_KEYS = {
133
- "attribute",
134
- "home",
135
- "work",
136
- "vehicle_attribute",
137
- "bus_attribute",
138
- "pedestrian_attribute",
139
- "bike_attribute",
140
- }
141
- simulator = self._simulator
142
- memory = self.memory
143
- person_id = await memory.get("id")
144
- # ATTENTION:模拟器分配的id从0开始
145
- if person_id >= 0:
146
- await simulator.get_person(person_id)
147
- logging.debug(f"Binding to Person `{person_id}` already in Simulator")
148
- else:
149
- dict_person = deepcopy(self._person_template)
150
- for _key in FROM_MEMORY_KEYS:
151
- try:
152
- _value = await memory.get(_key)
153
- if _value:
154
- dict_person[_key] = _value
155
- except KeyError as e:
156
- continue
157
- resp = await simulator.add_person(
158
- dict2pb(dict_person, person_pb2.Person())
159
- )
160
- person_id = resp["person_id"]
161
- await memory.update("id", person_id, protect_llm_read_only_fields=False)
162
- logging.debug(
163
- f"Binding to Person `{person_id}` just added to Simulator"
164
- )
165
- # 防止模拟器还没有到prepare阶段导致get_person出错
166
- self._has_bound_to_simulator = True
167
- self._agent_id = person_id
125
+ @property
126
+ def uuid(self):
127
+ """The Agent's UUID"""
128
+ return self._uuid
168
129
 
169
- async def _bind_to_economy(self):
170
- if self._economy_client is None:
171
- logging.warning("Economy client is not set")
172
- return
173
- if not self._has_bound_to_economy:
174
- if self._has_bound_to_simulator:
175
- try:
176
- await self._economy_client.remove_agents([self._agent_id])
177
- except:
178
- pass
179
- person_id = await self.memory.get("id")
180
- await self._economy_client.add_agents(
181
- {
182
- "id": person_id,
183
- "currency": await self.memory.get("currency"),
184
- }
185
- )
186
- self._has_bound_to_economy = True
187
- else:
188
- logging.debug(
189
- f"Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
190
- )
130
+ @property
131
+ def sim_id(self):
132
+ """The Agent's Simulator ID"""
133
+ return self._agent_id
191
134
 
192
135
  @property
193
136
  def llm(self):
@@ -277,22 +220,60 @@ class Agent(ABC):
277
220
  """获取采访历史记录"""
278
221
  return self._interview_history
279
222
 
280
- async def handle_message(self, payload: str):
223
+ async def handle_agent_chat_message(self, payload: dict):
224
+ """处理收到的消息,识别发送者"""
225
+ # 从消息中解析发送者 ID 和消息内容
226
+ print(
227
+ f"Agent {self._uuid} received agent chat message: '{payload['content']}' from Agent {payload['from']}"
228
+ )
229
+
230
+ async def handle_user_chat_message(self, payload: dict):
231
+ """处理收到的消息,识别发送者"""
232
+ # 从消息中解析发送者 ID 和消息内容
233
+ print(
234
+ f"Agent {self._uuid} received user chat message: '{payload['content']}' from User"
235
+ )
236
+
237
+ async def handle_user_survey_message(self, payload: dict):
281
238
  """处理收到的消息,识别发送者"""
282
239
  # 从消息中解析发送者 ID 和消息内容
283
- message, sender_id = payload.split("|from:")
284
240
  print(
285
- f"Agent {self._agent_id} received message: '{message}' from Agent {sender_id}"
241
+ f"Agent {self._uuid} received user survey message: '{payload['content']}' from User"
286
242
  )
287
243
 
288
- async def send_message(
289
- self, to_agent_id: int, message: str, sub_topic: str = "chat"
244
+ async def _send_message(
245
+ self, to_agent_uuid: UUID, payload: dict, sub_topic: str
290
246
  ):
291
- """通过 Messager 发送消息,附带发送者的 ID"""
247
+ """通过 Messager 发送消息"""
292
248
  if self._messager is None:
293
249
  raise RuntimeError("Messager is not set")
294
- topic = f"/exps/{self._exp_id}/agents/{to_agent_id}/{sub_topic}"
295
- await self._messager.send_message(topic, message, self._agent_id)
250
+ topic = f"exps/{self._exp_id}/agents/{to_agent_uuid}/{sub_topic}"
251
+ await self._messager.send_message(topic, payload)
252
+
253
+ async def send_message_to_agent(
254
+ self, to_agent_uuid: UUID, content: dict
255
+ ):
256
+ """通过 Messager 发送消息"""
257
+ if self._messager is None:
258
+ raise RuntimeError("Messager is not set")
259
+ payload = {
260
+ "from": self._uuid,
261
+ "content": content,
262
+ "timestamp": int(time.time()),
263
+ "day": await self._simulator.get_simulator_day(),
264
+ "t": await self._simulator.get_simulator_second_from_start_of_day(),
265
+ }
266
+ await self._send_message(to_agent_uuid, payload, "agent-chat")
267
+
268
+ async def send_message_to_user(
269
+ self, content: dict
270
+ ):
271
+ pass
272
+
273
+ async def send_message_to_survey(
274
+ self, content: dict
275
+ ):
276
+ pass
296
277
 
297
278
  @abstractmethod
298
279
  async def forward(self) -> None:
@@ -410,12 +391,17 @@ class CitizenAgent(Agent):
410
391
  f"Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
411
392
  )
412
393
 
413
- async def handle_gather_message(self, payload: str):
394
+ async def handle_gather_message(self, payload: dict):
414
395
  """处理收到的消息,识别发送者"""
415
396
  # 从消息中解析发送者 ID 和消息内容
416
- target, sender_id = payload.split("|from:")
397
+ target = payload["target"]
398
+ sender_id = payload["from"]
417
399
  content = await self.memory.get(f"{target}")
418
- await self.send_message(int(sender_id), content, "gather")
400
+ payload = {
401
+ "from": self._uuid,
402
+ "content": content,
403
+ }
404
+ await self._send_message(sender_id, payload, "gather")
419
405
 
420
406
 
421
407
  class InstitutionAgent(Agent):
@@ -455,28 +441,60 @@ class InstitutionAgent(Agent):
455
441
  self._agent_id = _id
456
442
  await self._memory.update("id", _id, protect_llm_read_only_fields=False)
457
443
  try:
458
- await self._economy_client.remove_agents([self._agent_id])
444
+ await self._economy_client.remove_orgs([self._agent_id])
459
445
  except:
460
446
  pass
461
447
  try:
462
448
  id = await self._memory.get("id")
463
449
  type = await self._memory.get("type")
464
- nominal_gdp = await self._memory.get("nominal_gdp")
465
- real_gdp = await self._memory.get("real_gdp")
466
- unemployment = await self._memory.get("unemployment")
467
- wages = await self._memory.get("wages")
468
- prices = await self._memory.get("prices")
469
- inventory = await self._memory.get("inventory")
470
- price = await self._memory.get("price")
471
- currency = await self._memory.get("currency")
472
- interest_rate = await self._memory.get("interest_rate")
473
- bracket_cutoffs = await self._memory.get("bracket_cutoffs")
474
- bracket_rates = await self._memory.get("bracket_rates")
450
+ try:
451
+ nominal_gdp = await self._memory.get("nominal_gdp")
452
+ except:
453
+ nominal_gdp = []
454
+ try:
455
+ real_gdp = await self._memory.get("real_gdp")
456
+ except:
457
+ real_gdp = []
458
+ try:
459
+ unemployment = await self._memory.get("unemployment")
460
+ except:
461
+ unemployment = []
462
+ try:
463
+ wages = await self._memory.get("wages")
464
+ except:
465
+ wages = []
466
+ try:
467
+ prices = await self._memory.get("prices")
468
+ except:
469
+ prices = []
470
+ try:
471
+ inventory = await self._memory.get("inventory")
472
+ except:
473
+ inventory = 0
474
+ try:
475
+ price = await self._memory.get("price")
476
+ except:
477
+ price = 0
478
+ try:
479
+ currency = await self._memory.get("currency")
480
+ except:
481
+ currency = 0.0
482
+ try:
483
+ interest_rate = await self._memory.get("interest_rate")
484
+ except:
485
+ interest_rate = 0.0
486
+ try:
487
+ bracket_cutoffs = await self._memory.get("bracket_cutoffs")
488
+ except:
489
+ bracket_cutoffs = []
490
+ try:
491
+ bracket_rates = await self._memory.get("bracket_rates")
492
+ except:
493
+ bracket_rates = []
475
494
  await self._economy_client.add_orgs(
476
495
  {
477
496
  "id": id,
478
497
  "type": type,
479
- "currency": currency,
480
498
  "nominal_gdp": nominal_gdp,
481
499
  "real_gdp": real_gdp,
482
500
  "unemployment": unemployment,
@@ -484,6 +502,7 @@ class InstitutionAgent(Agent):
484
502
  "prices": prices,
485
503
  "inventory": inventory,
486
504
  "price": price,
505
+ "currency": currency,
487
506
  "interest_rate": interest_rate,
488
507
  "bracket_cutoffs": bracket_cutoffs,
489
508
  "bracket_rates": bracket_rates,
@@ -493,15 +512,20 @@ class InstitutionAgent(Agent):
493
512
  logging.error(f"Failed to bind to Economy: {e}")
494
513
  self._has_bound_to_economy = True
495
514
 
496
- async def handle_gather_message(self, payload: str):
515
+ async def handle_gather_message(self, payload: dict):
497
516
  """处理收到的消息,识别发送者"""
498
517
  # 从消息中解析发送者 ID 和消息内容
499
- content, sender_id = payload.split("|from:")
518
+ content = payload["content"]
519
+ sender_id = payload["from"]
500
520
  print(
501
- f"Agent {self._agent_id} received gather message: '{content}' from Agent {sender_id}"
521
+ f"Agent {self._uuid} received gather message: '{content}' from Agent {sender_id}"
502
522
  )
503
523
 
504
- async def gather_messages(self, agent_ids: list[int], content: str):
524
+ async def gather_messages(self, agent_ids: list[UUID], target: str):
505
525
  """从多个智能体收集消息"""
526
+ payload = {
527
+ "from": self._uuid,
528
+ "target": target,
529
+ }
506
530
  for agent_id in agent_ids:
507
- await self.send_message(agent_id, content, "gather")
531
+ await self._send_message(agent_id, payload, "gather")
@@ -190,7 +190,25 @@ class Simulator:
190
190
  formatted_time = current_time.strftime(format)
191
191
  return formatted_time
192
192
  else:
193
+ # BUG: 返回的time是float类型
193
194
  return t_sec["t"]
195
+
196
+ async def get_simulator_day(self) -> int:
197
+ """
198
+ 获取模拟器到第几日
199
+ """
200
+ t_sec = await self._client.clock_service.Now({})
201
+ t_sec = cast(dict[str, int], t_sec)
202
+ day = t_sec["t"] // 86400
203
+ return day
204
+
205
+ async def get_simulator_second_from_start_of_day(self) -> int:
206
+ """
207
+ 获取模拟器从00:00:00到当前的秒数
208
+ """
209
+ t_sec = await self._client.clock_service.Now({})
210
+ t_sec = cast(dict[str, int], t_sec)
211
+ return t_sec["t"] % 86400
194
212
 
195
213
  async def get_person(self, person_id: int) -> dict:
196
214
  return await self._client.person_service.GetPerson(
@@ -1,5 +1,6 @@
1
1
  import asyncio
2
2
  from collections import defaultdict
3
+ import json
3
4
  import logging
4
5
  import math
5
6
  from aiomqtt import Client
@@ -42,7 +43,7 @@ class Messager:
42
43
  return
43
44
  await self.client.subscribe(topic)
44
45
  self.subscribers[topic] = agent
45
- logging.info(f"Subscribed to {topic} for Agent {agent._agent_id}")
46
+ logging.info(f"Subscribed to {topic} for Agent {agent._uuid}")
46
47
 
47
48
  async def receive_messages(self):
48
49
  """监听并将消息存入队列"""
@@ -56,10 +57,9 @@ class Messager:
56
57
  messages.append(await self.message_queue.get())
57
58
  return messages
58
59
 
59
- async def send_message(self, topic: str, payload: str, sender_id: int):
60
- """通过 Messager 发送消息,包含发送者 ID"""
61
- # 构造消息,payload 中加入 sender_id 以便接收者识别
62
- message = f"{payload}|from:{sender_id}"
60
+ async def send_message(self, topic: str, payload: dict):
61
+ """通过 Messager 发送消息"""
62
+ message = json.dumps(payload, default=str)
63
63
  await self.client.publish(topic, message)
64
64
  logging.info(f"Message sent to {topic}: {message}")
65
65
 
@@ -1,5 +1,7 @@
1
1
  import asyncio
2
+ import json
2
3
  import logging
4
+ import uuid
3
5
  import ray
4
6
  from uuid import UUID
5
7
  from pycityagent.agent import Agent
@@ -16,6 +18,7 @@ class AgentGroup:
16
18
  self.agents = agents
17
19
  self.config = config
18
20
  self.exp_id = exp_id
21
+ self._uuid = uuid.uuid4()
19
22
  self.messager = Messager(
20
23
  hostname=config["simulator_request"]["mqtt"]["server"],
21
24
  port=config["simulator_request"]["mqtt"]["port"],
@@ -53,26 +56,30 @@ class AgentGroup:
53
56
  async def init_agents(self):
54
57
  for agent in self.agents:
55
58
  await agent.bind_to_simulator()
56
- self.id2agent = {agent._agent_id: agent for agent in self.agents}
59
+ self.id2agent = {agent._uuid: agent for agent in self.agents}
57
60
  await self.messager.connect()
58
61
  if self.messager.is_connected():
59
62
  await self.messager.start_listening()
60
63
  for agent in self.agents:
61
64
  agent.set_messager(self.messager)
62
- topic = f"/exps/{self.exp_id}/agents/{agent._agent_id}/chat"
65
+ topic = f"exps/{self.exp_id}/agents/{agent._uuid}/agent-chat"
63
66
  await self.messager.subscribe(topic, agent)
64
- topic = f"/exps/{self.exp_id}/agents/{agent._agent_id}/gather"
67
+ topic = f"exps/{self.exp_id}/agents/{agent._uuid}/user-chat"
68
+ await self.messager.subscribe(topic, agent)
69
+ topic = f"exps/{self.exp_id}/agents/{agent._uuid}/user-survey"
70
+ await self.messager.subscribe(topic, agent)
71
+ topic = f"exps/{self.exp_id}/agents/{agent._uuid}/gather"
65
72
  await self.messager.subscribe(topic, agent)
66
73
  self.initialized = True
67
74
 
68
75
  async def gather(self, content: str):
69
76
  results = {}
70
77
  for agent in self.agents:
71
- results[agent._agent_id] = await agent.memory.get(content)
78
+ results[agent._uuid] = await agent.memory.get(content)
72
79
  return results
73
80
 
74
- async def update(self, target_agent_id: str, target_key: str, content: Any):
75
- agent = self.id2agent[target_agent_id]
81
+ async def update(self, target_agent_uuid: str, target_key: str, content: Any):
82
+ agent = self.id2agent[target_agent_uuid]
76
83
  await agent.memory.update(target_key, content)
77
84
 
78
85
  async def step(self):
@@ -90,7 +97,7 @@ class AgentGroup:
90
97
  # Step 2: 从 Messager 获取消息
91
98
  messages = await self.messager.fetch_messages()
92
99
 
93
- print(f"Received {len(messages)} messages")
100
+ logging.info(f"Group {self._uuid} received {len(messages)} messages")
94
101
 
95
102
  # Step 3: 分发消息到对应的 Agent
96
103
  for message in messages:
@@ -100,15 +107,20 @@ class AgentGroup:
100
107
  # 添加解码步骤,将bytes转换为str
101
108
  if isinstance(payload, bytes):
102
109
  payload = payload.decode("utf-8")
103
-
104
- # 提取 agent_id(主题格式为 "/exps/{exp_id}/agents/{agent_id}/chat" 或 "/exps/{exp_id}/agents/{agent_id}/gather")
105
- _, _, _, agent_id, topic_type = topic.strip("/").split("/")
106
- agent_id = int(agent_id)
107
-
108
- if agent_id in self.id2agent:
109
- agent = self.id2agent[agent_id]
110
- if topic_type == "chat":
111
- await agent.handle_message(payload)
110
+ payload = json.loads(payload)
111
+
112
+ # 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}"
113
+ _, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
114
+
115
+ if uuid.UUID(agent_uuid) in self.id2agent:
116
+ agent = self.id2agent[uuid.UUID(agent_uuid)]
117
+ # topic_type: agent-chat, user-chat, user-survey, gather
118
+ if topic_type == "agent-chat":
119
+ await agent.handle_agent_chat_message(payload)
120
+ elif topic_type == "user-chat":
121
+ await agent.handle_user_chat_message(payload)
122
+ elif topic_type == "user-survey":
123
+ await agent.handle_user_survey_message(payload)
112
124
  elif topic_type == "gather":
113
125
  await agent.handle_gather_message(payload)
114
126
 
@@ -125,13 +137,13 @@ class AgentGroup:
125
137
  try:
126
138
  # 获取开始时间
127
139
  start_time = await self.simulator.get_time()
128
- assert type(start_time)==int
140
+ start_time = int(start_time)
129
141
  # 计算结束时间(秒)
130
142
  end_time = start_time + day * 24 * 3600 # 将天数转换为秒
131
143
 
132
144
  while True:
133
145
  current_time = await self.simulator.get_time()
134
- assert type(current_time)==int
146
+ current_time = int(current_time)
135
147
  if current_time >= end_time:
136
148
  break
137
149
  await self.step()
@@ -6,10 +6,10 @@ from datetime import datetime
6
6
  import random
7
7
  from typing import Dict, List, Optional, Callable, Union,Any
8
8
  from mosstool.map._map_util.const import AOI_START_ID
9
-
9
+ import pycityproto.city.economy.v2.economy_pb2 as economyv2
10
10
  from pycityagent.memory.memory import Memory
11
11
 
12
- from ..agent import Agent
12
+ from ..agent import Agent, InstitutionAgent
13
13
  from .interview import InterviewManager
14
14
  from .survey import QuestionType, SurveyManager
15
15
  from .ui import InterviewUI
@@ -40,18 +40,30 @@ class AgentSimulation:
40
40
  self.agent_class = [agent_class]
41
41
  self.config = config
42
42
  self.agent_prefix = agent_prefix
43
- self._agents: Dict[str, Agent] = {}
44
- self._groups: Dict[str, AgentGroup] = {} # type:ignore
45
- self._interview_manager = InterviewManager()
46
- self._interview_lock = asyncio.Lock()
47
- self._start_time = datetime.now()
48
- self._agent_run_times: Dict[str, datetime] = {} # 记录每个智能体的运行开始时间
49
- self._ui: Optional[InterviewUI] = None
43
+ self._agents: Dict[uuid.UUID, Agent] = {}
44
+ self._groups: Dict[str, AgentGroup] = {}
45
+ self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {}
46
+ self._agent_uuids: List[uuid.UUID] = []
47
+
50
48
  self._loop = asyncio.get_event_loop()
51
- self._blocked_agents: List[str] = [] # 新增:持续阻塞的智能体列表
49
+ self._interview_manager = InterviewManager()
52
50
  self._survey_manager = SurveyManager()
53
- self._agentid2group: Dict[str, AgentGroup] = {}# type:ignore
54
- self._agent_ids: List[str] = []
51
+
52
+ @property
53
+ def agents(self):
54
+ return self._agents
55
+
56
+ @property
57
+ def groups(self):
58
+ return self._groups
59
+
60
+ @property
61
+ def agent_uuids(self):
62
+ return self._agent_uuids
63
+
64
+ @property
65
+ def agent_uuid2group(self):
66
+ return self._agent_uuid2group
55
67
 
56
68
  async def init_agents(
57
69
  self,
@@ -76,15 +88,25 @@ class AgentSimulation:
76
88
  logging.warning(
77
89
  "memory_config_func is None, using default memory config function"
78
90
  )
79
- memory_config_func = [self.default_memory_config_func]
91
+ memory_config_func = []
92
+ for agent_class in self.agent_class:
93
+ if issubclass(agent_class, InstitutionAgent):
94
+ memory_config_func.append(self.default_memory_config_institution)
95
+ else:
96
+ memory_config_func.append(self.default_memory_config_citizen)
80
97
  elif not isinstance(memory_config_func, list):
81
98
  memory_config_func = [memory_config_func]
82
99
 
83
100
  if len(memory_config_func) != len(agent_count):
84
101
  logging.warning(
85
- "memory_config_func和agent_count的长度不一致,使用默认的memory_config_func"
102
+ "memory_config_func和agent_count的长度不一致,使用默认的memory_config"
86
103
  )
87
- memory_config_func = [self.default_memory_config_func] * len(agent_count)
104
+ memory_config_func = []
105
+ for agent_class in self.agent_class:
106
+ if agent_class == InstitutionAgent:
107
+ memory_config_func.append(self.default_memory_config_institution)
108
+ else:
109
+ memory_config_func.append(self.default_memory_config_citizen)
88
110
 
89
111
  class_init_index = 0
90
112
  for i in range(len(self.agent_class)):
@@ -97,7 +119,7 @@ class AgentSimulation:
97
119
  # 获取Memory配置
98
120
  extra_attributes, profile, base = memory_config_func_i()
99
121
  memory = Memory(
100
- config=extra_attributes, profile=profile.copy(), base=base.copy()
122
+ config=extra_attributes, profile=profile, base=base
101
123
  )
102
124
 
103
125
  # 创建智能体时传入Memory配置
@@ -106,7 +128,8 @@ class AgentSimulation:
106
128
  memory=memory,
107
129
  )
108
130
 
109
- self._agents[agent_name] = agent
131
+ self._agents[agent._uuid] = agent
132
+ self._agent_uuids.append(agent._uuid)
110
133
 
111
134
  # 计算需要的组数,向上取整以处理不足一组的情况
112
135
  num_group = (agent_count_i + group_size - 1) // group_size
@@ -121,9 +144,11 @@ class AgentSimulation:
121
144
 
122
145
  # 获取当前组的agents
123
146
  agents = list(self._agents.values())[start_idx:end_idx]
124
- group_name = f"{self.agent_prefix}_{i}_group_{k}"
147
+ group_name = f"AgentType_{i}_Group_{k}"
125
148
  group = AgentGroup.remote(agents, self.config, self.exp_id)
126
149
  self._groups[group_name] = group
150
+ for agent in agents:
151
+ self._agent_uuid2group[agent._uuid] = group
127
152
 
128
153
  class_init_index += agent_count_i # 更新类初始索引
129
154
 
@@ -132,12 +157,6 @@ class AgentSimulation:
132
157
  init_tasks.append(group.init_agents.remote())
133
158
  await asyncio.gather(*init_tasks)
134
159
 
135
- for group in self._groups.values():
136
- agent_ids = await group.gather.remote("id")
137
- for agent_id in agent_ids:
138
- self._agent_ids.append(agent_id)
139
- self._agentid2group[agent_id] = group
140
-
141
160
  async def gather(self, content: str):
142
161
  """收集智能体的特定信息"""
143
162
  gather_tasks = []
@@ -147,10 +166,29 @@ class AgentSimulation:
147
166
 
148
167
  async def update(self, target_agent_id: str, target_key: str, content: Any):
149
168
  """更新指定智能体的记忆"""
150
- group = self._agentid2group[target_agent_id]
169
+ group = self._agent_uuid2group[target_agent_id]
151
170
  await group.update.remote(target_agent_id, target_key, content)
152
171
 
153
- def default_memory_config_func(self):
172
+ def default_memory_config_institution(self):
173
+ """默认的Memory配置函数"""
174
+ EXTRA_ATTRIBUTES = {
175
+ "type": (int, random.choice([economyv2.ORG_TYPE_BANK, economyv2.ORG_TYPE_GOVERNMENT, economyv2.ORG_TYPE_FIRM, economyv2.ORG_TYPE_NBS, economyv2.ORG_TYPE_UNSPECIFIED])),
176
+ "nominal_gdp": (list, [], True),
177
+ "real_gdp": (list, [], True),
178
+ "unemployment": (list, [], True),
179
+ "wages": (list, [], True),
180
+ "prices": (list, [], True),
181
+ "inventory": (int, 0, True),
182
+ "price": (float, 0.0, True),
183
+ "interest_rate": (float, 0.0, True),
184
+ "bracket_cutoffs": (list, [], True),
185
+ "bracket_rates": (list, [], True),
186
+ "employees": (list, [], True),
187
+ "customers": (list, [], True),
188
+ }
189
+ return EXTRA_ATTRIBUTES, None, None
190
+
191
+ def default_memory_config_citizen(self):
154
192
  """默认的Memory配置函数"""
155
193
  EXTRA_ATTRIBUTES = {
156
194
  # 需求信息
@@ -248,201 +286,6 @@ class AgentSimulation:
248
286
 
249
287
  return EXTRA_ATTRIBUTES, PROFILE, BASE
250
288
 
251
- def get_agent_runtime(self, agent_name: str) -> str:
252
- """获取智能体运行时间"""
253
- if agent_name not in self._agent_run_times:
254
- return "-"
255
- delta = datetime.now() - self._agent_run_times[agent_name]
256
- hours = delta.seconds // 3600
257
- minutes = (delta.seconds % 3600) // 60
258
- seconds = delta.seconds % 60
259
- return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
260
-
261
- def get_total_runtime(self) -> str:
262
- """获取总运行时间"""
263
- delta = datetime.now() - self._start_time
264
- hours = delta.seconds // 3600
265
- minutes = (delta.seconds % 3600) // 60
266
- seconds = delta.seconds % 60
267
- return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
268
-
269
- def export_chat_history(self, agent_name: str | None) -> str:
270
- """导出对话历史
271
-
272
- Args:
273
- agent_name: 可选的智能体名称,如果提供则只导出该智能体的对话
274
-
275
- Returns:
276
- str: JSON格式的对话历史
277
- """
278
- history = (
279
- self._interview_manager.get_agent_history(agent_name)
280
- if agent_name
281
- else self._interview_manager.get_recent_history(limit=1000)
282
- )
283
-
284
- # 转换为易读格式
285
- formatted_history = []
286
- for record in history:
287
- formatted_history.append(
288
- {
289
- "timestamp": record.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
290
- "agent": record.agent_name,
291
- "question": record.question,
292
- "response": record.response,
293
- "blocking": record.blocking,
294
- }
295
- )
296
-
297
- return json.dumps(formatted_history, ensure_ascii=False, indent=2)
298
-
299
- def toggle_agent_block(self, agent_name: str, blocking: bool) -> str:
300
- """切换智能体的阻塞状态
301
-
302
- Args:
303
- agent_name: 能体名称
304
- blocking: True表示阻塞,False表示取消阻塞
305
-
306
- Returns:
307
- str: 状态变更消息
308
- """
309
- if agent_name not in self._agents:
310
- return f"找不到智能体 {agent_name}"
311
-
312
- if blocking and agent_name not in self._blocked_agents:
313
- self._blocked_agents.append(agent_name)
314
- self._agents[agent_name]._blocked = True
315
- return f"已阻塞智能体 {agent_name}"
316
- elif not blocking and agent_name in self._blocked_agents:
317
- self._blocked_agents.remove(agent_name)
318
- self._agents[agent_name]._blocked = False
319
- return f"已取消阻塞智能体 {agent_name}"
320
-
321
- return f"智能体 {agent_name} 状态未变"
322
-
323
- async def interview_agent(self, agent_name: str, question: str) -> str:
324
- """采访指定智能体"""
325
- agent = self._agents.get(agent_name)
326
- if not agent:
327
- return "找不到指定的智能体"
328
-
329
- try:
330
- response = await agent.generate_response(question)
331
- # 记录采访历史
332
- self._interview_manager.add_record(
333
- agent_name,
334
- question,
335
- response,
336
- blocking=(agent_name in self._blocked_agents),
337
- )
338
- return response
339
-
340
- except Exception as e:
341
- logger.error(f"采访过程出错: {str(e)}")
342
- return f"采访过程出现错误: {str(e)}"
343
-
344
- async def submit_survey(self, agent_name: str, survey_id: str) -> str:
345
- """向智能体提交问卷
346
-
347
- Args:
348
- agent_name: 智能体名称
349
- survey_id: 问卷ID
350
-
351
- Returns:
352
- str: 处理结果
353
- """
354
- agent = self._agents.get(agent_name)
355
- if not agent:
356
- return "找不到指定的智能体"
357
-
358
- survey = self._survey_manager.get_survey(survey_id)
359
- if not survey:
360
- return "找不到指定的问卷"
361
-
362
- try:
363
- # 建问卷提示
364
- prompt = f"""请以第一人称回答以下调查问卷:
365
-
366
- 问卷标题: {survey.title}
367
- 问卷说明: {survey.description}
368
-
369
- """
370
- for i, question in enumerate(survey.questions):
371
- prompt += f"\n问题{i+1}. {question.content}"
372
- if question.type in (
373
- QuestionType.SINGLE_CHOICE,
374
- QuestionType.MULTIPLE_CHOICE,
375
- ):
376
- prompt += "\n选项: " + ", ".join(question.options)
377
- elif question.type == QuestionType.RATING:
378
- prompt += (
379
- f"\n(请给出{question.min_rating}-{question.max_rating}的评分)"
380
- )
381
- elif question.type == QuestionType.LIKERT:
382
- prompt += "\n(1-强烈不同意, 2-不同意, 3-中立, 4-同意, 5-强烈同意)"
383
-
384
- # 生成回答
385
- response = await agent.generate_response(prompt)
386
-
387
- # 存储原始回答
388
- self._survey_manager.add_response(
389
- survey_id, agent_name, {"raw_response": response, "parsed": False}
390
- )
391
-
392
- return response
393
-
394
- except Exception as e:
395
- logger.error(f"问卷处理出错: {str(e)}")
396
- return f"问卷处理出现错误: {str(e)}"
397
-
398
- def create_survey(self, **survey_data: dict) -> None:
399
- """创建新问卷
400
-
401
- Args:
402
- survey_data: 问卷数据,包含 title, description, questions
403
-
404
- Returns:
405
- 更新后的问卷列表
406
- """
407
- self._survey_manager.create_survey(**survey_data) # type:ignore
408
-
409
- def get_surveys(self) -> list:
410
- """获取所有问卷"""
411
- return self._survey_manager.get_all_surveys()
412
-
413
- def get_survey_questions(self, survey_id: str) -> dict | None:
414
- """获取指定问卷的问题列表
415
-
416
- Args:
417
- survey_id: 问卷ID
418
-
419
- Returns:
420
- 问卷数据,包含 title, description, questions
421
- """
422
- for _, survey in self._survey_manager._surveys.items():
423
- survey_dict = survey.to_dict()
424
- if survey_dict["id"] == survey_id:
425
- return survey_dict
426
- return None
427
-
428
- async def init_ui(
429
- self,
430
- server_name: str = "127.0.0.1",
431
- server_port: int = 7860,
432
- ):
433
- """初始化UI"""
434
- self._interview_lock = asyncio.Lock()
435
- # 初始化GradioUI
436
- self._ui = InterviewUI(self)
437
- interface = self._ui.create_interface()
438
- interface.queue().launch(
439
- server_name=server_name,
440
- server_port=server_port,
441
- prevent_thread_lock=True,
442
- quiet=True,
443
- )
444
- logger.info(f"Gradio Frontend is running on http://{server_name}:{server_port}")
445
-
446
289
  async def step(self):
447
290
  """运行一步, 即每个智能体执行一次forward"""
448
291
  try:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pycityagent
3
- Version: 2.0.0a10
3
+ Version: 2.0.0a12
4
4
  Summary: LLM-based城市环境agent构建库
5
5
  License: MIT
6
6
  Author: Yuwei Yan
@@ -1,5 +1,5 @@
1
1
  pycityagent/__init__.py,sha256=n56bWkAUEcvjDsb7LcJpaGjlrriSKPnR0yBhwRfEYBA,212
2
- pycityagent/agent.py,sha256=HYy5lI3IZtiyVAEbz9F8gsTP-qDHEENnajVEjF1a6YA,17545
2
+ pycityagent/agent.py,sha256=h0p9G5bLlfiv4yB8bkc9l2Z4jpa7s4uYVwAT4IrxusY,17396
3
3
  pycityagent/economy/__init__.py,sha256=aonY4WHnx-6EGJ4WKrx4S-2jAkYNLtqUA04jp6q8B7w,75
4
4
  pycityagent/economy/econ_client.py,sha256=qQb_kZneEXGBRaS_y5Jdoi95I8GyjKEsDSC4s6V6R7w,10829
5
5
  pycityagent/environment/__init__.py,sha256=awHxlOud-btWbk0FCS4RmGJ13W84oVCkbGfcrhKqihA,240
@@ -21,7 +21,7 @@ pycityagent/environment/sim/person_service.py,sha256=nIvOsoBoqOTDYtsiThg07-4ZBgk
21
21
  pycityagent/environment/sim/road_service.py,sha256=phKTwTyhc_6Ht2mddEXpdENfl-lRXIVY0CHAlw1yHjI,1264
22
22
  pycityagent/environment/sim/sim_env.py,sha256=HI1LcS_FotDKQ6vBnx0e49prXSABOfA20aU9KM-ZkCY,4625
23
23
  pycityagent/environment/sim/social_service.py,sha256=6Iqvq6dz8H2jhLLdtaITc6Js9QnQw-Ylsd5AZgUj3-E,1993
24
- pycityagent/environment/simulator.py,sha256=_2gW_efEAhIhQYQGFWgynQyRMhMxYSlRG38HqYPek-0,11866
24
+ pycityagent/environment/simulator.py,sha256=fQv6D_1OnhUrKTsnnah3wnm9ec8LE1phxRhK1K93Zyg,12466
25
25
  pycityagent/environment/utils/__init__.py,sha256=1m4Q1EfGvNpUsa1bgQzzCyWhfkpElnskNImjjFD3Znc,237
26
26
  pycityagent/environment/utils/base64.py,sha256=hoREzQo3FXMN79pqQLO2jgsDEvudciomyKii7MWljAM,374
27
27
  pycityagent/environment/utils/const.py,sha256=3RMNy7_bE7-23K90j9DFW_tWEzu8s7hSTgKbV-3BFl4,5327
@@ -44,11 +44,11 @@ pycityagent/memory/self_define.py,sha256=poPiexNhOLq_iTgK8s4mK_xoL_DAAcB8kMvInj7
44
44
  pycityagent/memory/state.py,sha256=5W0c1yJ-aaPpE74B2LEcw3Ygpm77tyooHv8NylyrozE,5113
45
45
  pycityagent/memory/utils.py,sha256=wLNlNlZ-AY9VB8kbUIy0UQSYh26FOQABbhmKQkit5o8,850
46
46
  pycityagent/message/__init__.py,sha256=TCjazxqb5DVwbTu1fF0sNvaH_EPXVuj2XQ0p6W-QCLU,55
47
- pycityagent/message/messager.py,sha256=kLwh4SarqDc73hlmE7rhw3qwpeB7YMpSOGl0z7WAELE,2606
47
+ pycityagent/message/messager.py,sha256=Iv4pK83JvHAQSZyGNACryPBey2wRoiok3Hb1eIwHbww,2506
48
48
  pycityagent/simulation/__init__.py,sha256=jYaqaNpzM5M_e_ykISS_M-mIyYdzJXJWhgpfBpA6l5k,111
49
- pycityagent/simulation/agentgroup.py,sha256=2Hv3b9F3thfAfxDnjXPk0z0mXK37dLz-nhbNf_cTOMU,5421
49
+ pycityagent/simulation/agentgroup.py,sha256=DzZooUZstnhObO3X7NjB9-8zBofEu_1NevaIhspuGME,6113
50
50
  pycityagent/simulation/interview.py,sha256=S2uv8MFCB4-u_4Q202VFoPJOIleqpKK9Piju0BDSb_0,1158
51
- pycityagent/simulation/simulation.py,sha256=cKeCjRYAjJMwuy6apjGZ8pEfPLJZiKSiphS7zb5Xojs,16955
51
+ pycityagent/simulation/simulation.py,sha256=4BHavfSHPM81Y5Su-MzzBYM0sR9cHRk4YOv-VkRkzrA,11595
52
52
  pycityagent/simulation/survey/__init__.py,sha256=rxwou8U9KeFSP7rMzXtmtp2fVFZxK4Trzi-psx9LPIs,153
53
53
  pycityagent/simulation/survey/manager.py,sha256=Tqini-4-uBZDsChVVL4ezlgXatnrxAfvceAZZRP8E48,2062
54
54
  pycityagent/simulation/survey/models.py,sha256=sY4OrrG1h9iBnjBsyDage4T3mUFPBHHZQe-ORtwSjKc,1305
@@ -65,6 +65,6 @@ pycityagent/workflow/block.py,sha256=6EmiRMLdOZC1wMlmLMIjfrp9TuiI7Gw4s3nnXVMbrnw
65
65
  pycityagent/workflow/prompt.py,sha256=tY69nDO8fgYfF_dOA-iceR8pAhkYmCqoox8uRPqEuGY,2956
66
66
  pycityagent/workflow/tool.py,sha256=wD9WZ5rma6HCKugtHTwbShNE0f-Rjlwvn_1be3fCAsk,6682
67
67
  pycityagent/workflow/trigger.py,sha256=t5X_i0WtL32bipZSsq_E3UUyYYudYLxQUpvxbgClp2s,5683
68
- pycityagent-2.0.0a10.dist-info/METADATA,sha256=wh_nw6ryC0X_r6ioORJE5NqlxL_rblJ8LUYNaTdL9SE,7622
69
- pycityagent-2.0.0a10.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
70
- pycityagent-2.0.0a10.dist-info/RECORD,,
68
+ pycityagent-2.0.0a12.dist-info/METADATA,sha256=DNsYXJBF2nVXYdiNJrhDty3AwVMkqmPbnm_Ld-Vtt8k,7622
69
+ pycityagent-2.0.0a12.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
70
+ pycityagent-2.0.0a12.dist-info/RECORD,,