pycityagent 2.0.0a14__py3-none-any.whl → 2.0.0a15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pycityagent/__init__.py CHANGED
@@ -4,5 +4,19 @@ Pycityagent: 城市智能体构建框架
4
4
 
5
5
  from .agent import Agent, CitizenAgent, InstitutionAgent
6
6
  from .environment import Simulator
7
+ import logging
8
+
9
+ # 创建一个 pycityagent 记录器
10
+ logger = logging.getLogger("pycityagent")
11
+ logger.setLevel(logging.WARNING) # 默认级别
12
+
13
+ # 如果没有处理器,则添加一个
14
+ if not logger.hasHandlers():
15
+ handler = logging.StreamHandler()
16
+ formatter = logging.Formatter(
17
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
18
+ )
19
+ handler.setFormatter(formatter)
20
+ logger.addHandler(handler)
7
21
 
8
22
  __all__ = ["Agent", "Simulator", "CitizenAgent", "InstitutionAgent"]
pycityagent/agent.py CHANGED
@@ -2,11 +2,9 @@
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  import asyncio
5
- import json
6
5
  from uuid import UUID
7
6
  from copy import deepcopy
8
7
  from datetime import datetime
9
- import time
10
8
  from enum import Enum
11
9
  import logging
12
10
  import random
@@ -28,6 +26,8 @@ from .environment import Simulator
28
26
  from .llm import LLM
29
27
  from .memory import Memory
30
28
 
29
+ logger = logging.getLogger("pycityagent")
30
+
31
31
 
32
32
  class AgentType(Enum):
33
33
  """
@@ -73,7 +73,7 @@ class Agent(ABC):
73
73
  """
74
74
  self._name = name
75
75
  self._type = type
76
- self._uuid = uuid.uuid4()
76
+ self._uuid = str(uuid.uuid4())
77
77
  self._llm_client = llm_client
78
78
  self._economy_client = economy_client
79
79
  self._messager = messager
@@ -123,12 +123,18 @@ class Agent(ABC):
123
123
  """
124
124
  self._memory = memory
125
125
 
126
- def set_exp_id(self, exp_id: str|UUID):
126
+ def set_exp_id(self, exp_id: str):
127
127
  """
128
128
  Set the exp_id of the agent.
129
129
  """
130
130
  self._exp_id = exp_id
131
131
 
132
+ def set_avro_file(self, avro_file: Dict[str, str]):
133
+ """
134
+ Set the avro file of the agent.
135
+ """
136
+ self._avro_file = avro_file
137
+
132
138
  @property
133
139
  def uuid(self):
134
140
  """The Agent's UUID"""
@@ -214,8 +220,10 @@ class Agent(ABC):
214
220
 
215
221
  async def _process_survey(self, survey: dict):
216
222
  survey_response = await self.generate_user_survey_response(survey)
223
+ if self._avro_file is None:
224
+ return
217
225
  response_to_avro = [{
218
- "id": str(self._uuid),
226
+ "id": self._uuid,
219
227
  "day": await self._simulator.get_simulator_day(),
220
228
  "t": await self._simulator.get_simulator_second_from_start_of_day(),
221
229
  "survey_id": survey["id"],
@@ -264,7 +272,7 @@ class Agent(ABC):
264
272
 
265
273
  async def _process_interview(self, payload: dict):
266
274
  auros = [{
267
- "id": str(self._uuid),
275
+ "id": self._uuid,
268
276
  "day": await self._simulator.get_simulator_day(),
269
277
  "t": await self._simulator.get_simulator_second_from_start_of_day(),
270
278
  "type": 2,
@@ -275,7 +283,7 @@ class Agent(ABC):
275
283
  question = payload["content"]
276
284
  response = await self.generate_user_chat_response(question)
277
285
  auros.append({
278
- "id": str(self._uuid),
286
+ "id": self._uuid,
279
287
  "day": await self._simulator.get_simulator_day(),
280
288
  "t": await self._simulator.get_simulator_second_from_start_of_day(),
281
289
  "type": 2,
@@ -283,15 +291,17 @@ class Agent(ABC):
283
291
  "content": response,
284
292
  "created_at": int(datetime.now().timestamp() * 1000),
285
293
  })
294
+ if self._avro_file is None:
295
+ return
286
296
  with open(self._avro_file["dialog"], "a+b") as f:
287
297
  fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
288
298
 
289
299
  async def process_agent_chat_response(self, payload: dict) -> str:
290
- logging.info(f"Agent {self._uuid} received agent chat response: {payload}")
300
+ logger.info(f"Agent {self._uuid} received agent chat response: {payload}")
291
301
 
292
302
  async def _process_agent_chat(self, payload: dict):
293
303
  auros = [{
294
- "id": str(self._uuid),
304
+ "id": self._uuid,
295
305
  "day": payload["day"],
296
306
  "t": payload["t"],
297
307
  "type": 1,
@@ -300,6 +310,8 @@ class Agent(ABC):
300
310
  "created_at": int(datetime.now().timestamp() * 1000),
301
311
  }]
302
312
  asyncio.create_task(self.process_agent_chat_response(payload))
313
+ if self._avro_file is None:
314
+ return
303
315
  with open(self._avro_file["dialog"], "a+b") as f:
304
316
  fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
305
317
 
@@ -307,19 +319,19 @@ class Agent(ABC):
307
319
  async def handle_agent_chat_message(self, payload: dict):
308
320
  """处理收到的消息,识别发送者"""
309
321
  # 从消息中解析发送者 ID 和消息内容
310
- logging.info(f"Agent {self._uuid} received agent chat message: {payload}")
322
+ logger.info(f"Agent {self._uuid} received agent chat message: {payload}")
311
323
  asyncio.create_task(self._process_agent_chat(payload))
312
324
 
313
325
  async def handle_user_chat_message(self, payload: dict):
314
326
  """处理收到的消息,识别发送者"""
315
327
  # 从消息中解析发送者 ID 和消息内容
316
- logging.info(f"Agent {self._uuid} received user chat message: {payload}")
328
+ logger.info(f"Agent {self._uuid} received user chat message: {payload}")
317
329
  asyncio.create_task(self._process_interview(payload))
318
330
 
319
331
  async def handle_user_survey_message(self, payload: dict):
320
332
  """处理收到的消息,识别发送者"""
321
333
  # 从消息中解析发送者 ID 和消息内容
322
- logging.info(f"Agent {self._uuid} received user survey message: {payload}")
334
+ logger.info(f"Agent {self._uuid} received user survey message: {payload}")
323
335
  asyncio.create_task(self._process_survey(payload["data"]))
324
336
 
325
337
  async def handle_gather_message(self, payload: str):
@@ -327,7 +339,7 @@ class Agent(ABC):
327
339
 
328
340
  # MQTT send message
329
341
  async def _send_message(
330
- self, to_agent_uuid: UUID, payload: dict, sub_topic: str
342
+ self, to_agent_uuid: str, payload: dict, sub_topic: str
331
343
  ):
332
344
  """通过 Messager 发送消息"""
333
345
  if self._messager is None:
@@ -336,7 +348,7 @@ class Agent(ABC):
336
348
  await self._messager.send_message(topic, payload)
337
349
 
338
350
  async def send_message_to_agent(
339
- self, to_agent_uuid: UUID, content: str
351
+ self, to_agent_uuid: str, content: str
340
352
  ):
341
353
  """通过 Messager 发送消息"""
342
354
  if self._messager is None:
@@ -350,14 +362,16 @@ class Agent(ABC):
350
362
  }
351
363
  await self._send_message(to_agent_uuid, payload, "agent-chat")
352
364
  auros = [{
353
- "id": str(self._uuid),
365
+ "id": self._uuid,
354
366
  "day": await self._simulator.get_simulator_day(),
355
367
  "t": await self._simulator.get_simulator_second_from_start_of_day(),
356
368
  "type": 1,
357
- "speaker": str(self._uuid),
369
+ "speaker": self._uuid,
358
370
  "content": content,
359
371
  "created_at": int(datetime.now().timestamp() * 1000),
360
372
  }]
373
+ if self._avro_file is None:
374
+ return
361
375
  with open(self._avro_file["dialog"], "a+b") as f:
362
376
  fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
363
377
 
@@ -414,7 +428,7 @@ class CitizenAgent(Agent):
414
428
  person_template (dict, optional): The person template in dict format. Defaults to PersonService.default_dict_person().
415
429
  """
416
430
  if self._simulator is None:
417
- logging.warning("Simulator is not set")
431
+ logger.warning("Simulator is not set")
418
432
  return
419
433
  if not self._has_bound_to_simulator:
420
434
  FROM_MEMORY_KEYS = {
@@ -432,7 +446,7 @@ class CitizenAgent(Agent):
432
446
  # ATTENTION:模拟器分配的id从0开始
433
447
  if person_id >= 0:
434
448
  await simulator.get_person(person_id)
435
- logging.debug(f"Binding to Person `{person_id}` already in Simulator")
449
+ logger.debug(f"Binding to Person `{person_id}` already in Simulator")
436
450
  else:
437
451
  dict_person = deepcopy(self._person_template)
438
452
  for _key in FROM_MEMORY_KEYS:
@@ -447,7 +461,7 @@ class CitizenAgent(Agent):
447
461
  )
448
462
  person_id = resp["person_id"]
449
463
  await memory.update("id", person_id, protect_llm_read_only_fields=False)
450
- logging.debug(
464
+ logger.debug(
451
465
  f"Binding to Person `{person_id}` just added to Simulator"
452
466
  )
453
467
  # 防止模拟器还没有到prepare阶段导致get_person出错
@@ -456,7 +470,7 @@ class CitizenAgent(Agent):
456
470
 
457
471
  async def _bind_to_economy(self):
458
472
  if self._economy_client is None:
459
- logging.warning("Economy client is not set")
473
+ logger.warning("Economy client is not set")
460
474
  return
461
475
  if not self._has_bound_to_economy:
462
476
  if self._has_bound_to_simulator:
@@ -473,7 +487,7 @@ class CitizenAgent(Agent):
473
487
  )
474
488
  self._has_bound_to_economy = True
475
489
  else:
476
- logging.debug(
490
+ logger.debug(
477
491
  f"Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
478
492
  )
479
493
 
@@ -523,7 +537,7 @@ class InstitutionAgent(Agent):
523
537
 
524
538
  async def _bind_to_economy(self):
525
539
  if self._economy_client is None:
526
- logging.debug("Economy client is not set")
540
+ logger.debug("Economy client is not set")
527
541
  return
528
542
  if not self._has_bound_to_economy:
529
543
  # TODO: More general id generation
@@ -599,7 +613,7 @@ class InstitutionAgent(Agent):
599
613
  }
600
614
  )
601
615
  except Exception as e:
602
- logging.error(f"Failed to bind to Economy: {e}")
616
+ logger.error(f"Failed to bind to Economy: {e}")
603
617
  self._has_bound_to_economy = True
604
618
 
605
619
  async def handle_gather_message(self, payload: dict):
@@ -615,11 +629,11 @@ class InstitutionAgent(Agent):
615
629
  "content": content,
616
630
  })
617
631
 
618
- async def gather_messages(self, agent_ids: list[UUID], target: str) -> List[dict]:
632
+ async def gather_messages(self, agent_uuids: list[str], target: str) -> List[dict]:
619
633
  """从多个智能体收集消息
620
634
 
621
635
  Args:
622
- agent_ids: 目标智能体ID列表
636
+ agent_uuids: 目标智能体UUID列表
623
637
  target: 要收集的信息类型
624
638
 
625
639
  Returns:
@@ -627,18 +641,17 @@ class InstitutionAgent(Agent):
627
641
  """
628
642
  # 为每个agent创建Future
629
643
  futures = {}
630
- for agent_id in agent_ids:
631
- response_key = str(agent_id)
632
- futures[response_key] = asyncio.Future()
633
- self._gather_responses[response_key] = futures[response_key]
644
+ for agent_uuid in agent_uuids:
645
+ futures[agent_uuid] = asyncio.Future()
646
+ self._gather_responses[agent_uuid] = futures[agent_uuid]
634
647
 
635
648
  # 发送gather请求
636
649
  payload = {
637
650
  "from": self._uuid,
638
651
  "target": target,
639
652
  }
640
- for agent_id in agent_ids:
641
- await self._send_message(agent_id, payload, "gather")
653
+ for agent_uuid in agent_uuids:
654
+ await self._send_message(agent_uuid, payload, "gather")
642
655
 
643
656
  try:
644
657
  # 等待所有响应
@@ -20,6 +20,7 @@ from shapely.strtree import STRtree
20
20
  from .sim import CityClient, ControlSimEnv
21
21
  from .utils.const import *
22
22
 
23
+ logger = logging.getLogger("pycityagent")
23
24
 
24
25
  class Simulator:
25
26
  """
@@ -72,7 +73,7 @@ class Simulator:
72
73
  else:
73
74
  self._client = CityClient(config["simulator"]["server"], secure=False)
74
75
  else:
75
- logging.warning(
76
+ logger.warning(
76
77
  "No simulator config found, no simulator client will be used"
77
78
  )
78
79
  self.map = SimMap(
@@ -285,7 +286,7 @@ class Simulator:
285
286
  reset_position["aoi_position"] = {"aoi_id": aoi_id}
286
287
  if poi_id is not None:
287
288
  reset_position["aoi_position"]["poi_id"] = poi_id
288
- logging.debug(
289
+ logger.debug(
289
290
  f"Setting person {person_id} pos to AoiPosition {reset_position}"
290
291
  )
291
292
  await self._client.person_service.ResetPersonPosition(
@@ -298,14 +299,14 @@ class Simulator:
298
299
  }
299
300
  if s is not None:
300
301
  reset_position["lane_position"]["s"] = s
301
- logging.debug(
302
+ logger.debug(
302
303
  f"Setting person {person_id} pos to LanePosition {reset_position}"
303
304
  )
304
305
  await self._client.person_service.ResetPersonPosition(
305
306
  {"person_id": person_id, "position": reset_position}
306
307
  )
307
308
  else:
308
- logging.debug(
309
+ logger.debug(
309
310
  f"Neither aoi or lane pos provided for person {person_id} position reset!!"
310
311
  )
311
312
 
@@ -13,6 +13,7 @@ from .profile import ProfileMemory
13
13
  from .self_define import DynamicMemory
14
14
  from .state import StateMemory
15
15
 
16
+ logger = logging.getLogger("pycityagent")
16
17
 
17
18
  class Memory:
18
19
  """
@@ -83,7 +84,7 @@ class Memory:
83
84
  _type.extend(_value)
84
85
  _value = deepcopy(_type)
85
86
  else:
86
- logging.warning(f"type `{_type}` is not supported!")
87
+ logger.warning(f"type `{_type}` is not supported!")
87
88
  pass
88
89
  except TypeError as e:
89
90
  pass
@@ -99,7 +100,7 @@ class Memory:
99
100
  or k in STATE_ATTRIBUTES
100
101
  or k == TIME_STAMP_KEY
101
102
  ):
102
- logging.warning(f"key `{k}` already declared in memory!")
103
+ logger.warning(f"key `{k}` already declared in memory!")
103
104
  continue
104
105
 
105
106
  _dynamic_config[k] = deepcopy(_value)
@@ -112,19 +113,19 @@ class Memory:
112
113
  if profile is not None:
113
114
  for k, v in profile.items():
114
115
  if k not in PROFILE_ATTRIBUTES:
115
- logging.warning(f"key `{k}` is not a correct `profile` field!")
116
+ logger.warning(f"key `{k}` is not a correct `profile` field!")
116
117
  continue
117
118
  _profile_config[k] = v
118
119
  if motion is not None:
119
120
  for k, v in motion.items():
120
121
  if k not in STATE_ATTRIBUTES:
121
- logging.warning(f"key `{k}` is not a correct `motion` field!")
122
+ logger.warning(f"key `{k}` is not a correct `motion` field!")
122
123
  continue
123
124
  _state_config[k] = v
124
125
  if base is not None:
125
126
  for k, v in base.items():
126
127
  if k not in STATE_ATTRIBUTES:
127
- logging.warning(f"key `{k}` is not a correct `base` field!")
128
+ logger.warning(f"key `{k}` is not a correct `base` field!")
128
129
  continue
129
130
  _state_config[k] = v
130
131
  self._state = StateMemory(
@@ -182,7 +183,7 @@ class Memory:
182
183
  """更新记忆值并在必要时更新embedding"""
183
184
  if protect_llm_read_only_fields:
184
185
  if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
185
- logging.warning(f"Trying to write protected key `{key}`!")
186
+ logger.warning(f"Trying to write protected key `{key}`!")
186
187
  return
187
188
  for _mem in [self._state, self._profile, self._dynamic]:
188
189
  try:
@@ -208,7 +209,7 @@ class Memory:
208
209
  elif isinstance(original_value, deque):
209
210
  original_value.extend(deque(value))
210
211
  else:
211
- logging.debug(
212
+ logger.debug(
212
213
  f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
213
214
  )
214
215
  await _mem.update(key, value, store_snapshot)
@@ -10,6 +10,8 @@ from typing import Any, Callable, Dict, Optional, Sequence, Union
10
10
 
11
11
  from .const import *
12
12
 
13
+ logger = logging.getLogger("pycityagent")
14
+
13
15
 
14
16
  class MemoryUnit:
15
17
  def __init__(
@@ -57,7 +59,7 @@ class MemoryUnit:
57
59
  orig_v = self._content[k]
58
60
  orig_type, new_type = type(orig_v), type(v)
59
61
  if not orig_type == new_type:
60
- logging.debug(
62
+ logger.debug(
61
63
  f"Type warning: The type of the value for key '{k}' is changing from `{orig_type.__name__}` to `{new_type.__name__}`!"
62
64
  )
63
65
  self._content.update(content)
@@ -82,7 +84,7 @@ class MemoryUnit:
82
84
  await self._lock.acquire()
83
85
  values = self._content[key]
84
86
  if not isinstance(values, Sequence):
85
- logging.warning(
87
+ logger.warning(
86
88
  f"the value stored in key `{key}` is not `sequence`, return value `{values}` instead!"
87
89
  )
88
90
  return values
@@ -93,7 +95,7 @@ class MemoryUnit:
93
95
  )
94
96
  top_k = len(values) if top_k is None else top_k
95
97
  if len(_sorted_values_with_idx) < top_k:
96
- logging.debug(
98
+ logger.debug(
97
99
  f"Length of values {len(_sorted_values_with_idx)} is less than top_k {top_k}, returning all values."
98
100
  )
99
101
  self._lock.release()
@@ -149,7 +151,7 @@ class MemoryBase(ABC):
149
151
  if recent_n is None:
150
152
  return _list_units
151
153
  if len(_memories) < recent_n:
152
- logging.debug(
154
+ logger.debug(
153
155
  f"Length of memory {len(_memories)} is less than recent_n {recent_n}, returning all available memories."
154
156
  )
155
157
  return _list_units[-recent_n:]
@@ -5,6 +5,7 @@ import logging
5
5
  import math
6
6
  from aiomqtt import Client
7
7
 
8
+ logger = logging.getLogger("pycityagent")
8
9
 
9
10
  class Messager:
10
11
  def __init__(
@@ -21,15 +22,15 @@ class Messager:
21
22
  try:
22
23
  await self.client.__aenter__()
23
24
  self.connected = True
24
- logging.info("Connected to MQTT Broker")
25
+ logger.info("Connected to MQTT Broker")
25
26
  except Exception as e:
26
27
  self.connected = False
27
- logging.error(f"Failed to connect to MQTT Broker: {e}")
28
+ logger.error(f"Failed to connect to MQTT Broker: {e}")
28
29
 
29
30
  async def disconnect(self):
30
31
  await self.client.__aexit__(None, None, None)
31
32
  self.connected = False
32
- logging.info("Disconnected from MQTT Broker")
33
+ logger.info("Disconnected from MQTT Broker")
33
34
 
34
35
  def is_connected(self):
35
36
  """检查是否成功连接到 Broker"""
@@ -37,13 +38,13 @@ class Messager:
37
38
 
38
39
  async def subscribe(self, topic, agent):
39
40
  if not self.is_connected():
40
- logging.error(
41
+ logger.error(
41
42
  f"Cannot subscribe to {topic} because not connected to the Broker."
42
43
  )
43
44
  return
44
45
  await self.client.subscribe(topic)
45
46
  self.subscribers[topic] = agent
46
- logging.info(f"Subscribed to {topic} for Agent {agent._uuid}")
47
+ logger.info(f"Subscribed to {topic} for Agent {agent._uuid}")
47
48
 
48
49
  async def receive_messages(self):
49
50
  """监听并将消息存入队列"""
@@ -61,11 +62,11 @@ class Messager:
61
62
  """通过 Messager 发送消息"""
62
63
  message = json.dumps(payload, default=str)
63
64
  await self.client.publish(topic, message)
64
- logging.info(f"Message sent to {topic}: {message}")
65
+ logger.info(f"Message sent to {topic}: {message}")
65
66
 
66
67
  async def start_listening(self):
67
68
  """启动消息监听任务"""
68
69
  if self.is_connected():
69
70
  asyncio.create_task(self.receive_messages())
70
71
  else:
71
- logging.error("Cannot start listening because not connected to the Broker.")
72
+ logger.error("Cannot start listening because not connected to the Broker.")
@@ -2,27 +2,41 @@ import asyncio
2
2
  from datetime import datetime
3
3
  import json
4
4
  import logging
5
+ from pathlib import Path
5
6
  import uuid
6
7
  import fastavro
7
8
  import ray
8
9
  from uuid import UUID
9
- from pycityagent.agent import Agent, CitizenAgent
10
+ from pycityagent.agent import Agent, CitizenAgent, InstitutionAgent
10
11
  from pycityagent.economy.econ_client import EconomyClient
11
12
  from pycityagent.environment.simulator import Simulator
12
13
  from pycityagent.llm.llm import LLM
13
14
  from pycityagent.llm.llmconfig import LLMConfig
14
15
  from pycityagent.message import Messager
15
- from pycityagent.utils import STATUS_SCHEMA
16
+ from pycityagent.utils import STATUS_SCHEMA, PROFILE_SCHEMA, DIALOG_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
16
17
  from typing import Any
17
18
 
19
+ logger = logging.getLogger("pycityagent")
20
+
18
21
  @ray.remote
19
22
  class AgentGroup:
20
- def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, avro_file: dict):
23
+ def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int = logging.WARNING):
24
+ logger.setLevel(logging_level)
25
+ self._uuid = str(uuid.uuid4())
21
26
  self.agents = agents
22
27
  self.config = config
23
28
  self.exp_id = exp_id
24
- self.avro_file = avro_file
25
- self._uuid = uuid.uuid4()
29
+ self.enable_avro = enable_avro
30
+ self.avro_path = avro_path / f"{self._uuid}"
31
+ if enable_avro:
32
+ self.avro_path.mkdir(parents=True, exist_ok=True)
33
+ self.avro_file = {
34
+ "profile": self.avro_path / f"profile.avro",
35
+ "dialog": self.avro_path / f"dialog.avro",
36
+ "status": self.avro_path / f"status.avro",
37
+ "survey": self.avro_path / f"survey.avro",
38
+ }
39
+
26
40
  self.messager = Messager(
27
41
  hostname=config["simulator_request"]["mqtt"]["server"],
28
42
  port=config["simulator_request"]["mqtt"]["port"],
@@ -33,16 +47,16 @@ class AgentGroup:
33
47
  self.id2agent = {}
34
48
  # Step:1 prepare LLM client
35
49
  llmConfig = LLMConfig(config["llm_request"])
36
- logging.info("-----Creating LLM client in remote...")
50
+ logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
37
51
  self.llm = LLM(llmConfig)
38
52
 
39
53
  # Step:2 prepare Simulator
40
- logging.info("-----Creating Simulator in remote...")
54
+ logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
41
55
  self.simulator = Simulator(config["simulator_request"])
42
56
 
43
57
  # Step:3 prepare Economy client
44
58
  if "economy" in config["simulator_request"]:
45
- logging.info("-----Creating Economy client in remote...")
59
+ logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
46
60
  self.economy_client = EconomyClient(
47
61
  config["simulator_request"]["economy"]["server"]
48
62
  )
@@ -56,11 +70,16 @@ class AgentGroup:
56
70
  if self.economy_client is not None:
57
71
  agent.set_economy_client(self.economy_client)
58
72
  agent.set_messager(self.messager)
73
+ if self.enable_avro:
74
+ agent.set_avro_file(self.avro_file)
59
75
 
60
76
  async def init_agents(self):
77
+ logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
78
+ logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
61
79
  for agent in self.agents:
62
80
  await agent.bind_to_simulator()
63
81
  self.id2agent = {agent._uuid: agent for agent in self.agents}
82
+ logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
64
83
  await self.messager.connect()
65
84
  if self.messager.is_connected():
66
85
  await self.messager.start_listening()
@@ -74,27 +93,65 @@ class AgentGroup:
74
93
  await self.messager.subscribe(topic, agent)
75
94
  topic = f"exps/{self.exp_id}/agents/{agent._uuid}/gather"
76
95
  await self.messager.subscribe(topic, agent)
77
- self.initialized = True
78
96
  self.message_dispatch_task = asyncio.create_task(self.message_dispatch())
79
-
97
+ if self.enable_avro:
98
+ logger.debug(f"-----Creating Avro files in AgentGroup {self._uuid} ...")
99
+ # profile
100
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
101
+ filename = self.avro_file["profile"]
102
+ with open(filename, "wb") as f:
103
+ profiles = []
104
+ for agent in self.agents:
105
+ profile = await agent.memory._profile.export()
106
+ profile = profile[0]
107
+ profile['id'] = agent._uuid
108
+ profiles.append(profile)
109
+ fastavro.writer(f, PROFILE_SCHEMA, profiles)
110
+
111
+ # dialog
112
+ filename = self.avro_file["dialog"]
113
+ with open(filename, "wb") as f:
114
+ dialogs = []
115
+ fastavro.writer(f, DIALOG_SCHEMA, dialogs)
116
+
117
+ # status
118
+ filename = self.avro_file["status"]
119
+ with open(filename, "wb") as f:
120
+ statuses = []
121
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
122
+ fastavro.writer(f, STATUS_SCHEMA, statuses)
123
+ else:
124
+ fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, statuses)
125
+
126
+ # survey
127
+ filename = self.avro_file["survey"]
128
+ with open(filename, "wb") as f:
129
+ surveys = []
130
+ fastavro.writer(f, SURVEY_SCHEMA, surveys)
131
+ self.initialized = True
132
+ logger.debug(f"-----AgentGroup {self._uuid} initialized")
133
+
80
134
  async def gather(self, content: str):
135
+ logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
81
136
  results = {}
82
137
  for agent in self.agents:
83
138
  results[agent._uuid] = await agent.memory.get(content)
84
139
  return results
85
140
 
86
141
  async def update(self, target_agent_uuid: str, target_key: str, content: Any):
142
+ logger.debug(f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}")
87
143
  agent = self.id2agent[target_agent_uuid]
88
144
  await agent.memory.update(target_key, content)
89
145
 
90
146
  async def message_dispatch(self):
147
+ logger.debug(f"-----Starting message dispatch for group {self._uuid}")
91
148
  while True:
92
149
  if not self.messager.is_connected():
93
- logging.warning("Messager is not connected. Skipping message processing.")
150
+ logger.warning("Messager is not connected. Skipping message processing.")
94
151
 
95
152
  # Step 1: 获取消息
96
153
  messages = await self.messager.fetch_messages()
97
- logging.info(f"Group {self._uuid} received {len(messages)} messages")
154
+ logger.info(f"Group {self._uuid} received {len(messages)} messages")
98
155
 
99
156
  # Step 2: 分发消息到对应的 Agent
100
157
  for message in messages:
@@ -109,8 +166,8 @@ class AgentGroup:
109
166
  # 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
110
167
  _, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
111
168
 
112
- if uuid.UUID(agent_uuid) in self.id2agent:
113
- agent = self.id2agent[uuid.UUID(agent_uuid)]
169
+ if agent_uuid in self.id2agent:
170
+ agent = self.id2agent[agent_uuid]
114
171
  # topic_type: agent-chat, user-chat, user-survey, gather
115
172
  if topic_type == "agent-chat":
116
173
  await agent.handle_agent_chat_message(payload)
@@ -123,46 +180,73 @@ class AgentGroup:
123
180
 
124
181
  await asyncio.sleep(0.5)
125
182
 
183
+ async def save_status(self):
184
+ if self.enable_avro:
185
+ logger.debug(f"-----Saving status for group {self._uuid}")
186
+ avros = []
187
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
188
+ for agent in self.agents:
189
+ position = await agent.memory.get("position")
190
+ lng = position["longlat_position"]["longitude"]
191
+ lat = position["longlat_position"]["latitude"]
192
+ if "aoi_position" in position:
193
+ parent_id = position["aoi_position"]["aoi_id"]
194
+ elif "lane_position" in position:
195
+ parent_id = position["lane_position"]["lane_id"]
196
+ else:
197
+ # BUG: 需要处理
198
+ parent_id = -1
199
+ needs = await agent.memory.get("needs")
200
+ action = await agent.memory.get("current_step")
201
+ action = action["intention"]
202
+ avro = {
203
+ "id": agent._uuid,
204
+ "day": await self.simulator.get_simulator_day(),
205
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
206
+ "lng": lng,
207
+ "lat": lat,
208
+ "parent_id": parent_id,
209
+ "action": action,
210
+ "hungry": needs["hungry"],
211
+ "tired": needs["tired"],
212
+ "safe": needs["safe"],
213
+ "social": needs["social"],
214
+ "created_at": int(datetime.now().timestamp() * 1000),
215
+ }
216
+ avros.append(avro)
217
+ with open(self.avro_file["status"], "a+b") as f:
218
+ fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
219
+ else:
220
+ for agent in self.agents:
221
+ avro = {
222
+ "id": agent._uuid,
223
+ "day": await self.simulator.get_simulator_day(),
224
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
225
+ "type": await agent.memory.get("type"),
226
+ "nominal_gdp": await agent.memory.get("nominal_gdp"),
227
+ "real_gdp": await agent.memory.get("real_gdp"),
228
+ "unemployment": await agent.memory.get("unemployment"),
229
+ "wages": await agent.memory.get("wages"),
230
+ "prices": await agent.memory.get("prices"),
231
+ "inventory": await agent.memory.get("inventory"),
232
+ "price": await agent.memory.get("price"),
233
+ "interest_rate": await agent.memory.get("interest_rate"),
234
+ "bracket_cutoffs": await agent.memory.get("bracket_cutoffs"),
235
+ "bracket_rates": await agent.memory.get("bracket_rates"),
236
+ "employees": await agent.memory.get("employees"),
237
+ "customers": await agent.memory.get("customers"),
238
+ }
239
+ avros.append(avro)
240
+ with open(self.avro_file["status"], "a+b") as f:
241
+ fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
242
+
126
243
  async def step(self):
127
244
  if not self.initialized:
128
245
  await self.init_agents()
129
246
 
130
247
  tasks = [agent.run() for agent in self.agents]
131
248
  await asyncio.gather(*tasks)
132
- avros = []
133
- for agent in self.agents:
134
- if not issubclass(type(agent), CitizenAgent):
135
- continue
136
- position = await agent.memory.get("position")
137
- lng = position["longlat_position"]["longitude"]
138
- lat = position["longlat_position"]["latitude"]
139
- if "aoi_position" in position:
140
- parent_id = position["aoi_position"]["aoi_id"]
141
- elif "lane_position" in position:
142
- parent_id = position["lane_position"]["lane_id"]
143
- else:
144
- # BUG: 需要处理
145
- parent_id = -1
146
- needs = await agent.memory.get("needs")
147
- action = await agent.memory.get("current_step")
148
- action = action["intention"]
149
- avro = {
150
- "id": str(agent._uuid), # uuid as string
151
- "day": await self.simulator.get_simulator_day(),
152
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
153
- "lng": lng,
154
- "lat": lat,
155
- "parent_id": parent_id,
156
- "action": action,
157
- "hungry": needs["hungry"],
158
- "tired": needs["tired"],
159
- "safe": needs["safe"],
160
- "social": needs["social"],
161
- "created_at": int(datetime.now().timestamp() * 1000),
162
- }
163
- avros.append(avro)
164
- with open(self.avro_file["status"], "a+b") as f:
165
- fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
249
+ await self.save_status()
166
250
 
167
251
  async def run(self, day: int = 1):
168
252
  """运行模拟器
@@ -185,5 +269,5 @@ class AgentGroup:
185
269
  await self.step()
186
270
 
187
271
  except Exception as e:
188
- logging.error(f"模拟器运行错误: {str(e)}")
272
+ logger.error(f"模拟器运行错误: {str(e)}")
189
273
  raise
@@ -4,21 +4,22 @@ import logging
4
4
  import os
5
5
  from pathlib import Path
6
6
  import uuid
7
- from datetime import datetime
7
+ from datetime import datetime, timezone
8
8
  import random
9
9
  from typing import Dict, List, Optional, Callable, Union,Any
10
- import fastavro
11
10
  from mosstool.map._map_util.const import AOI_START_ID
12
11
  import pycityproto.city.economy.v2.economy_pb2 as economyv2
12
+ from pycityagent.environment.simulator import Simulator
13
13
  from pycityagent.memory.memory import Memory
14
14
  from pycityagent.message.messager import Messager
15
15
  from pycityagent.survey import Survey
16
- from pycityagent.utils.avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA
16
+ import yaml
17
+ from concurrent.futures import ThreadPoolExecutor
17
18
 
18
19
  from ..agent import Agent, InstitutionAgent
19
20
  from .agentgroup import AgentGroup
20
21
 
21
- logger = logging.getLogger(__name__)
22
+ logger = logging.getLogger("pycityagent")
22
23
 
23
24
 
24
25
  class AgentSimulation:
@@ -29,19 +30,24 @@ class AgentSimulation:
29
30
  agent_class: Union[type[Agent], list[type[Agent]]],
30
31
  config: dict,
31
32
  agent_prefix: str = "agent_",
33
+ exp_name: str = "default_experiment",
34
+ logging_level: int = logging.WARNING
32
35
  ):
33
36
  """
34
37
  Args:
35
38
  agent_class: 智能体类
36
39
  config: 配置
37
40
  agent_prefix: 智能体名称前缀
41
+ exp_name: 实验名称
38
42
  """
39
- self.exp_id = uuid.uuid4()
43
+ self.exp_id = str(uuid.uuid4())
40
44
  if isinstance(agent_class, list):
41
45
  self.agent_class = agent_class
42
46
  else:
43
47
  self.agent_class = [agent_class]
48
+ self.logging_level = logging_level
44
49
  self.config = config
50
+ self._simulator = Simulator(config["simulator_request"])
45
51
  self.agent_prefix = agent_prefix
46
52
  self._agents: Dict[uuid.UUID, Agent] = {}
47
53
  self._groups: Dict[str, AgentGroup] = {}
@@ -61,13 +67,10 @@ class AgentSimulation:
61
67
  asyncio.create_task(self._messager.connect())
62
68
 
63
69
  self._enable_avro = config["storage"]["avro"]["enabled"]
64
- self._avro_path = Path(config["storage"]["avro"]["path"])
65
- self._avro_file = {
66
- "profile": self._avro_path / f"{self.exp_id}_profile.avro",
67
- "dialog": self._avro_path / f"{self.exp_id}_dialog.avro",
68
- "status": self._avro_path / f"{self.exp_id}_status.avro",
69
- "survey": self._avro_path / f"{self.exp_id}_survey.avro",
70
- }
70
+ if not self._enable_avro:
71
+ logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
72
+ self._avro_path = Path(config["storage"]["avro"]["path"]) / f"{self.exp_id}"
73
+ self._avro_path.mkdir(parents=True, exist_ok=True)
71
74
 
72
75
  self._enable_pgsql = config["storage"]["pgsql"]["enabled"]
73
76
  self._pgsql_host = config["storage"]["pgsql"]["host"]
@@ -76,6 +79,24 @@ class AgentSimulation:
76
79
  self._pgsql_user = config["storage"]["pgsql"]["user"]
77
80
  self._pgsql_password = config["storage"]["pgsql"]["password"]
78
81
 
82
+ # 添加实验信息相关的属性
83
+ self._exp_info = {
84
+ "id": self.exp_id,
85
+ "name": exp_name,
86
+ "num_day": 0, # 将在 run 方法中更新
87
+ "status": 0,
88
+ "cur_day": 0,
89
+ "cur_t": 0.0,
90
+ "config": json.dumps(config),
91
+ "error": "",
92
+ "created_at": datetime.now(timezone.utc).isoformat()
93
+ }
94
+
95
+ # 创建异步任务保存实验信息
96
+ self._exp_info_file = self._avro_path / "experiment_info.yaml"
97
+ with open(self._exp_info_file, 'w') as f:
98
+ yaml.dump(self._exp_info, f)
99
+
79
100
  @property
80
101
  def agents(self):
81
102
  return self._agents
@@ -91,6 +112,11 @@ class AgentSimulation:
91
112
  @property
92
113
  def agent_uuid2group(self):
93
114
  return self._agent_uuid2group
115
+
116
+ def create_remote_group(self, group_name: str, agents: list[Agent], config: dict, exp_id: str, enable_avro: bool, avro_path: Path, logging_level: int = logging.WARNING):
117
+ """创建远程组"""
118
+ group = AgentGroup.remote(agents, config, exp_id, enable_avro, avro_path, logging_level)
119
+ return group_name, group, agents
94
120
 
95
121
  async def init_agents(
96
122
  self,
@@ -112,7 +138,7 @@ class AgentSimulation:
112
138
  raise ValueError("agent_class和agent_count的长度不一致")
113
139
 
114
140
  if memory_config_func is None:
115
- logging.warning(
141
+ logger.warning(
116
142
  "memory_config_func is None, using default memory config function"
117
143
  )
118
144
  memory_config_func = []
@@ -125,17 +151,21 @@ class AgentSimulation:
125
151
  memory_config_func = [memory_config_func]
126
152
 
127
153
  if len(memory_config_func) != len(agent_count):
128
- logging.warning(
154
+ logger.warning(
129
155
  "memory_config_func和agent_count的长度不一致,使用默认的memory_config"
130
156
  )
131
157
  memory_config_func = []
132
158
  for agent_class in self.agent_class:
133
- if agent_class == InstitutionAgent:
159
+ if issubclass(agent_class, InstitutionAgent):
134
160
  memory_config_func.append(self.default_memory_config_institution)
135
161
  else:
136
162
  memory_config_func.append(self.default_memory_config_citizen)
137
163
 
164
+ # 使用线程池并行创建 AgentGroup
165
+ group_creation_params = []
138
166
  class_init_index = 0
167
+
168
+ # 首先收集所有需要创建的组的参数
139
169
  for i in range(len(self.agent_class)):
140
170
  agent_class = self.agent_class[i]
141
171
  agent_count_i = agent_count[i]
@@ -153,7 +183,6 @@ class AgentSimulation:
153
183
  agent = agent_class(
154
184
  name=agent_name,
155
185
  memory=memory,
156
- avro_file=self._avro_file,
157
186
  )
158
187
 
159
188
  self._agents[agent._uuid] = agent
@@ -161,66 +190,51 @@ class AgentSimulation:
161
190
 
162
191
  # 计算需要的组数,向上取整以处理不足一组的情况
163
192
  num_group = (agent_count_i + group_size - 1) // group_size
164
-
193
+
165
194
  for k in range(num_group):
166
- # 计算当前组的起始和结束索引
167
195
  start_idx = class_init_index + k * group_size
168
196
  end_idx = min(
169
- class_init_index + start_idx + group_size,
170
- class_init_index + agent_count_i,
197
+ class_init_index + (k + 1) * group_size, # 修正了索引计算
198
+ class_init_index + agent_count_i
171
199
  )
172
-
173
- # 获取当前组的agents
200
+
174
201
  agents = list(self._agents.values())[start_idx:end_idx]
175
202
  group_name = f"AgentType_{i}_Group_{k}"
176
- group = AgentGroup.remote(agents, self.config, self.exp_id, self._avro_file)
177
- self._groups[group_name] = group
178
- for agent in agents:
179
- self._agent_uuid2group[agent._uuid] = group
180
-
181
- class_init_index += agent_count_i # 更新类初始索引
182
-
203
+
204
+ # 收集创建参数
205
+ group_creation_params.append((
206
+ group_name,
207
+ agents
208
+ ))
209
+
210
+ class_init_index += agent_count_i
211
+
212
+ # 收集所有创建组的参数
213
+ creation_tasks = []
214
+ for group_name, agents in group_creation_params:
215
+ # 直接创建异步任务
216
+ group = AgentGroup.remote(agents, self.config, self.exp_id,
217
+ self._enable_avro, self._avro_path,
218
+ self.logging_level)
219
+ creation_tasks.append((group_name, group, agents))
220
+
221
+ # 更新数据结构
222
+ for group_name, group, agents in creation_tasks:
223
+ self._groups[group_name] = group
224
+ for agent in agents:
225
+ self._agent_uuid2group[agent._uuid] = group
226
+
227
+ # 并行初始化所有组的agents
183
228
  init_tasks = []
184
229
  for group in self._groups.values():
185
230
  init_tasks.append(group.init_agents.remote())
186
231
  await asyncio.gather(*init_tasks)
232
+
233
+ # 设置用户主题
187
234
  for uuid, agent in self._agents.items():
188
235
  self._user_chat_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-chat"
189
236
  self._user_survey_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-survey"
190
237
 
191
- # save profile
192
- if self._enable_avro:
193
- self._avro_path.mkdir(parents=True, exist_ok=True)
194
- # profile
195
- filename = self._avro_file["profile"]
196
- with open(filename, "wb") as f:
197
- profiles = []
198
- for agent in self._agents.values():
199
- profile = await agent.memory._profile.export()
200
- profile = profile[0]
201
- profile['id'] = str(agent._uuid)
202
- profiles.append(profile)
203
- fastavro.writer(f, PROFILE_SCHEMA, profiles)
204
-
205
- # dialog
206
- filename = self._avro_file["dialog"]
207
- with open(filename, "wb") as f:
208
- dialogs = []
209
- fastavro.writer(f, DIALOG_SCHEMA, dialogs)
210
-
211
- # status
212
- filename = self._avro_file["status"]
213
- with open(filename, "wb") as f:
214
- statuses = []
215
- fastavro.writer(f, STATUS_SCHEMA, statuses)
216
-
217
- # survey
218
- filename = self._avro_file["survey"]
219
- with open(filename, "wb") as f:
220
- surveys = []
221
- fastavro.writer(f, SURVEY_SCHEMA, surveys)
222
-
223
-
224
238
  async def gather(self, content: str):
225
239
  """收集智能体的特定信息"""
226
240
  gather_tasks = []
@@ -228,10 +242,10 @@ class AgentSimulation:
228
242
  gather_tasks.append(group.gather.remote(content))
229
243
  return await asyncio.gather(*gather_tasks)
230
244
 
231
- async def update(self, target_agent_id: str, target_key: str, content: Any):
245
+ async def update(self, target_agent_uuid: uuid.UUID, target_key: str, content: Any):
232
246
  """更新指定智能体的记忆"""
233
- group = self._agent_uuid2group[target_agent_id]
234
- await group.update.remote(target_agent_id, target_key, content)
247
+ group = self._agent_uuid2group[target_agent_uuid]
248
+ await group.update.remote(target_agent_uuid, target_key, content)
235
249
 
236
250
  def default_memory_config_institution(self):
237
251
  """默认的Memory配置函数"""
@@ -388,23 +402,88 @@ class AgentSimulation:
388
402
  logger.error(f"运行错误: {str(e)}")
389
403
  raise
390
404
 
391
- async def run(
392
- self,
393
- day: int = 1,
394
- ):
395
- """运行模拟器
405
+ async def _save_exp_info(self) -> None:
406
+ """异步保存实验信息到YAML文件"""
407
+ try:
408
+ with open(self._exp_info_file, 'w') as f:
409
+ yaml.dump(self._exp_info, f)
410
+ except Exception as e:
411
+ logger.error(f"保存实验信息失败: {str(e)}")
412
+
413
+ async def _update_exp_status(self, status: int, error: str = "") -> None:
414
+ """更新实验状态并保存"""
415
+ self._exp_info["status"] = status
416
+ self._exp_info["error"] = error
417
+ await self._save_exp_info()
396
418
 
419
+ async def _monitor_exp_status(self, stop_event: asyncio.Event):
420
+ """监控实验状态并更新
421
+
397
422
  Args:
398
- day: 运行天数,默认为1天
423
+ stop_event: 用于通知监控任务停止的事件
399
424
  """
400
425
  try:
401
- # 获取开始时间
402
- tasks = []
403
- for group in self._groups.values():
404
- tasks.append(group.run.remote(day))
426
+ while not stop_event.is_set():
427
+ # 更新实验状态
428
+ # 假设所有group的cur_day和cur_t是同步的,取第一个即可
429
+ self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
430
+ self._exp_info["cur_t"] = await self._simulator.get_simulator_second_from_start_of_day()
431
+ await self._save_exp_info()
432
+
433
+ await asyncio.sleep(1) # 避免过于频繁的更新
434
+ except asyncio.CancelError:
435
+ # 正常取消,不需要特殊处理
436
+ pass
437
+ except Exception as e:
438
+ logger.error(f"监控实验状态时发生错误: {str(e)}")
439
+ raise
405
440
 
406
- await asyncio.gather(*tasks)
441
+ async def run(
442
+ self,
443
+ day: int = 1,
444
+ ):
445
+ """运行模拟器"""
446
+ try:
447
+ self._exp_info["num_day"] += day
448
+ await self._update_exp_status(1) # 更新状态为运行中
449
+
450
+ # 创建停止事件
451
+ stop_event = asyncio.Event()
452
+ # 创建监控任务
453
+ monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
454
+
455
+ try:
456
+ tasks = []
457
+ for group in self._groups.values():
458
+ tasks.append(group.run.remote())
459
+
460
+ # 等待所有group运行完成
461
+ await asyncio.gather(*tasks)
462
+
463
+ finally:
464
+ # 设置停止事件
465
+ stop_event.set()
466
+ # 等待监控任务结束
467
+ await monitor_task
468
+
469
+ # 运行成功后更新状态
470
+ await self._update_exp_status(2)
407
471
 
408
472
  except Exception as e:
409
- logger.error(f"模拟器运行错误: {str(e)}")
410
- raise
473
+ error_msg = f"模拟器运行错误: {str(e)}"
474
+ logger.error(error_msg)
475
+ await self._update_exp_status(3, error_msg)
476
+ raise e
477
+
478
+ async def __aenter__(self):
479
+ """异步上下文管理器入口"""
480
+ return self
481
+
482
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
483
+ """异步上下文管理器出口"""
484
+ if exc_type is not None:
485
+ # 如果发生异常,更新状态为错误
486
+ await self._update_exp_status(3, str(exc_val))
487
+ elif self._exp_info["status"] != 3:
488
+ # 如果没有发生异常且状态不是错误,则更新为完成
489
+ await self._update_exp_status(2)
@@ -1,7 +1,7 @@
1
- from .avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA
1
+ from .avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
2
2
  from .survey_util import process_survey_for_llm
3
3
 
4
4
  __all__ = [
5
- "PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA",
5
+ "PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA", "INSTITUTION_STATUS_SCHEMA",
6
6
  "process_survey_for_llm"
7
7
  ]
@@ -66,6 +66,31 @@ STATUS_SCHEMA = {
66
66
  ],
67
67
  }
68
68
 
69
+ INSTITUTION_STATUS_SCHEMA = {
70
+ "doc": "Institution状态",
71
+ "name": "InstitutionStatus",
72
+ "namespace": "com.socialcity",
73
+ "type": "record",
74
+ "fields": [
75
+ {"name": "id", "type": "string"}, # uuid as string
76
+ {"name": "day", "type": "int"},
77
+ {"name": "t", "type": "float"},
78
+ {"name": "type", "type": "int"},
79
+ {"name": "nominal_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
80
+ {"name": "real_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
81
+ {"name": "unemployment", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
82
+ {"name": "wages", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
83
+ {"name": "prices", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
84
+ {"name": "inventory", "type": ["int", "null"]},
85
+ {"name": "price", "type": ["float", "null"]},
86
+ {"name": "interest_rate", "type": ["float", "null"]},
87
+ {"name": "bracket_cutoffs", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
88
+ {"name": "bracket_rates", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
89
+ {"name": "employees", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
90
+ {"name": "customers", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
91
+ ],
92
+ }
93
+
69
94
  SURVEY_SCHEMA = {
70
95
  "doc": "Agent问卷",
71
96
  "name": "AgentSurvey",
@@ -82,4 +107,4 @@ SURVEY_SCHEMA = {
82
107
  "type": {"type": "long", "logicalType": "timestamp-millis"},
83
108
  },
84
109
  ],
85
- }
110
+ }
@@ -1,6 +1,3 @@
1
- import asyncio
2
- import logging
3
- from copy import deepcopy
4
1
  from typing import Any, Callable, Dict, List, Optional, Union
5
2
 
6
3
  from ..agent import Agent
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pycityagent
3
- Version: 2.0.0a14
3
+ Version: 2.0.0a15
4
4
  Summary: LLM-based城市环境agent构建库
5
5
  License: MIT
6
6
  Author: Yuwei Yan
@@ -1,5 +1,5 @@
1
- pycityagent/__init__.py,sha256=n56bWkAUEcvjDsb7LcJpaGjlrriSKPnR0yBhwRfEYBA,212
2
- pycityagent/agent.py,sha256=0ZqHXImR05ETA0vt9t5GDS4AgYzza3-Zwuua2OckwTw,22788
1
+ pycityagent/__init__.py,sha256=EDxt3Su3lH1IMh9suNw7GeGL7UrXeWiZTw5KWNznDzc,637
2
+ pycityagent/agent.py,sha256=fcuKX6FtMzjNP8lVep9pG-9KHzHQwJ8IymJbmLKMfu0,23109
3
3
  pycityagent/economy/__init__.py,sha256=aonY4WHnx-6EGJ4WKrx4S-2jAkYNLtqUA04jp6q8B7w,75
4
4
  pycityagent/economy/econ_client.py,sha256=wcuNtcpkSijJwNkt2mXw3SshYy4SBy6qbvJ0VQ7Aovo,10854
5
5
  pycityagent/environment/__init__.py,sha256=awHxlOud-btWbk0FCS4RmGJ13W84oVCkbGfcrhKqihA,240
@@ -21,7 +21,7 @@ pycityagent/environment/sim/person_service.py,sha256=nIvOsoBoqOTDYtsiThg07-4ZBgk
21
21
  pycityagent/environment/sim/road_service.py,sha256=phKTwTyhc_6Ht2mddEXpdENfl-lRXIVY0CHAlw1yHjI,1264
22
22
  pycityagent/environment/sim/sim_env.py,sha256=HI1LcS_FotDKQ6vBnx0e49prXSABOfA20aU9KM-ZkCY,4625
23
23
  pycityagent/environment/sim/social_service.py,sha256=6Iqvq6dz8H2jhLLdtaITc6Js9QnQw-Ylsd5AZgUj3-E,1993
24
- pycityagent/environment/simulator.py,sha256=fQv6D_1OnhUrKTsnnah3wnm9ec8LE1phxRhK1K93Zyg,12466
24
+ pycityagent/environment/simulator.py,sha256=K7IyhiGC9BxanW28bpML4M0YREdMp1h7yMoWBlbf3RY,12504
25
25
  pycityagent/environment/utils/__init__.py,sha256=1m4Q1EfGvNpUsa1bgQzzCyWhfkpElnskNImjjFD3Znc,237
26
26
  pycityagent/environment/utils/base64.py,sha256=hoREzQo3FXMN79pqQLO2jgsDEvudciomyKii7MWljAM,374
27
27
  pycityagent/environment/utils/const.py,sha256=3RMNy7_bE7-23K90j9DFW_tWEzu8s7hSTgKbV-3BFl4,5327
@@ -37,22 +37,22 @@ pycityagent/llm/llmconfig.py,sha256=4Ylf4OFSBEFy8jrOneeX0HvPhWEaF5jGvy1HkXK08Ro,
37
37
  pycityagent/llm/utils.py,sha256=hoNPhvomb1u6lhFX0GctFipw74hVKb7bvUBDqwBzBYw,160
38
38
  pycityagent/memory/__init__.py,sha256=Hs2NhYpIG-lvpwPWwj4DydB1sxtjz7cuA4iDAzCXnjI,243
39
39
  pycityagent/memory/const.py,sha256=6zpJPJXWoH9-yf4RARYYff586agCoud9BRn7sPERB1g,932
40
- pycityagent/memory/memory.py,sha256=sDbaqr1Koqf_9joMtG9PmmVxJZ6Rq7nAZO6EO0OdVgo,18148
41
- pycityagent/memory/memory_base.py,sha256=bd2q0qNu5hCRd2u4cPxE3bBA2OaoAD1oR4-vbRdbd_s,5594
40
+ pycityagent/memory/memory.py,sha256=FjKVL_MgNBnSc0sox2tuxLqXg9_MQQr9vYdRDHMdDL4,18183
41
+ pycityagent/memory/memory_base.py,sha256=euKZRCs4dbcKxjlZzpLCTnH066DAtRjj5g1JFKD40qQ,5633
42
42
  pycityagent/memory/profile.py,sha256=s4LnxSPGSjIGZXHXkkd8mMa6uYYZrytgyQdWjcaqGf4,5182
43
43
  pycityagent/memory/self_define.py,sha256=poPiexNhOLq_iTgK8s4mK_xoL_DAAcB8kMvInj7iE5E,5179
44
44
  pycityagent/memory/state.py,sha256=5W0c1yJ-aaPpE74B2LEcw3Ygpm77tyooHv8NylyrozE,5113
45
45
  pycityagent/memory/utils.py,sha256=wLNlNlZ-AY9VB8kbUIy0UQSYh26FOQABbhmKQkit5o8,850
46
46
  pycityagent/message/__init__.py,sha256=TCjazxqb5DVwbTu1fF0sNvaH_EPXVuj2XQ0p6W-QCLU,55
47
- pycityagent/message/messager.py,sha256=Iv4pK83JvHAQSZyGNACryPBey2wRoiok3Hb1eIwHbww,2506
47
+ pycityagent/message/messager.py,sha256=W_OVlNGcreHSBf6v-DrEnfNCXExB78ySr0w26MSncfU,2541
48
48
  pycityagent/simulation/__init__.py,sha256=jYaqaNpzM5M_e_ykISS_M-mIyYdzJXJWhgpfBpA6l5k,111
49
- pycityagent/simulation/agentgroup.py,sha256=ZJIeQbFmFwK3iUiwUpa4TG-k1Eb4vOTgy5ybv-CE1Rc,7797
50
- pycityagent/simulation/simulation.py,sha256=rDQRgKmJkjfeQtvXyW9LeqBVVX0iFh6snBEsdg1bCag,15574
49
+ pycityagent/simulation/agentgroup.py,sha256=JwfssUtVrOgSnJCan4jcIcSHLjWBCwYxqOPT-AXA2sE,12514
50
+ pycityagent/simulation/simulation.py,sha256=G68P1EJ3JceA3zID2O6AGd_KdhhYy5XVZVUgkfJHypc,18897
51
51
  pycityagent/survey/__init__.py,sha256=rxwou8U9KeFSP7rMzXtmtp2fVFZxK4Trzi-psx9LPIs,153
52
52
  pycityagent/survey/manager.py,sha256=N4-Q8vve4L-PLaFikAgVB4Z8BNyFDEd6WjEk3AyMTgs,1764
53
53
  pycityagent/survey/models.py,sha256=Z2gHEazQRj0TkTz5qbh4Uy_JrU_FZGWpOLwjN0RoUrY,3547
54
- pycityagent/utils/__init__.py,sha256=_0AV01UAHp4pTj2mlJ4m8LTH0toWL9j3DouNC6sDuUQ,249
55
- pycityagent/utils/avro_schema.py,sha256=MsRG0CsYAn2UjSlSXgCF8-3076VywpALEmz1mW-UJB0,2789
54
+ pycityagent/utils/__init__.py,sha256=xXEMhVfFeOJUXjczaHv9DJqYNp57rc6FibtS7CfrVbA,305
55
+ pycityagent/utils/avro_schema.py,sha256=DHM3bOo8m0dJf8oSwyOWzVeXrH6OERmzA_a5vS4So4M,4255
56
56
  pycityagent/utils/decorators.py,sha256=Gk3r41hfk6awui40tbwpq3C7wC7jHaRmLRlcJFlLQCE,3160
57
57
  pycityagent/utils/parsers/__init__.py,sha256=AN2xgiPxszWK4rpX7zrqRsqNwfGF3WnCA5-PFTvbaKk,281
58
58
  pycityagent/utils/parsers/code_block_parser.py,sha256=Cs2Z_hm9VfNCpPPll1TwteaJF-HAQPs-3RApsOekFm4,1173
@@ -62,8 +62,8 @@ pycityagent/utils/survey_util.py,sha256=Be9nptmu2JtesFNemPgORh_2GsN7rcDYGQS9Zfvc
62
62
  pycityagent/workflow/__init__.py,sha256=EyCcjB6LyBim-5iAOPe4m2qfvghEPqu1ZdGfy4KPeZ8,551
63
63
  pycityagent/workflow/block.py,sha256=6EmiRMLdOZC1wMlmLMIjfrp9TuiI7Gw4s3nnXVMbrnw,6031
64
64
  pycityagent/workflow/prompt.py,sha256=tY69nDO8fgYfF_dOA-iceR8pAhkYmCqoox8uRPqEuGY,2956
65
- pycityagent/workflow/tool.py,sha256=wD9WZ5rma6HCKugtHTwbShNE0f-Rjlwvn_1be3fCAsk,6682
65
+ pycityagent/workflow/tool.py,sha256=zMvz3BV4QBs5TqyQ3ziJxj4pCfL2uqUI3A1FbT1gd3Q,6626
66
66
  pycityagent/workflow/trigger.py,sha256=t5X_i0WtL32bipZSsq_E3UUyYYudYLxQUpvxbgClp2s,5683
67
- pycityagent-2.0.0a14.dist-info/METADATA,sha256=O6cGdAHKuG6W0Ya_Du_wEP8_xfM4xmOetYWb6rPj-WM,7705
68
- pycityagent-2.0.0a14.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
69
- pycityagent-2.0.0a14.dist-info/RECORD,,
67
+ pycityagent-2.0.0a15.dist-info/METADATA,sha256=8ONKHaTIPOGja6FKJDE9tZ-pLZL2aeXQufB-LWkElR8,7705
68
+ pycityagent-2.0.0a15.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
69
+ pycityagent-2.0.0a15.dist-info/RECORD,,