pycityagent 2.0.0a20__py3-none-any.whl → 2.0.0a22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,15 @@
1
- from typing import Any, Awaitable, TypeVar, Union, Dict
2
- from google.protobuf.message import Message
1
+ from collections.abc import Awaitable
2
+ from typing import Any, TypeVar, Union
3
+
3
4
  from google.protobuf.json_format import MessageToDict
5
+ from google.protobuf.message import Message
4
6
 
5
7
  __all__ = ["parse", "async_parse"]
6
8
 
7
9
  T = TypeVar("T", bound=Message)
8
10
 
9
11
 
10
- def parse(res: T, dict_return: bool) -> Union[Dict[str, Any], T]:
12
+ def parse(res: T, dict_return: bool) -> Union[dict[str, Any], T]:
11
13
  """
12
14
  将Protobuf返回值转换为dict或者原始值
13
15
  Convert Protobuf return value to dict or original value
@@ -23,7 +25,7 @@ def parse(res: T, dict_return: bool) -> Union[Dict[str, Any], T]:
23
25
  return res
24
26
 
25
27
 
26
- async def async_parse(res: Awaitable[T], dict_return: bool) -> Union[Dict[str, Any], T]:
28
+ async def async_parse(res: Awaitable[T], dict_return: bool) -> Union[dict[str, Any], T]:
27
29
  """
28
30
  将Protobuf await返回值转换为dict或者原始值
29
31
  Convert Protobuf await return value to dict or original value
@@ -6,7 +6,8 @@ import asyncio
6
6
  import logging
7
7
  import time
8
8
  from abc import ABC, abstractmethod
9
- from typing import Any, Callable, Dict, Optional, Sequence, Union
9
+ from collections.abc import Callable, Sequence
10
+ from typing import Any, Optional, Union
10
11
 
11
12
  from .const import *
12
13
 
@@ -16,8 +17,8 @@ logger = logging.getLogger("pycityagent")
16
17
  class MemoryUnit:
17
18
  def __init__(
18
19
  self,
19
- content: Optional[Dict] = None,
20
- required_attributes: Optional[Dict] = None,
20
+ content: Optional[dict] = None,
21
+ required_attributes: Optional[dict] = None,
21
22
  activate_timestamp: bool = False,
22
23
  ) -> None:
23
24
  self._content = {}
@@ -52,7 +53,7 @@ class MemoryUnit:
52
53
  else:
53
54
  setattr(self, f"{SELF_DEFINE_PREFIX}{property_name}", property_value)
54
55
 
55
- async def update(self, content: Dict) -> None:
56
+ async def update(self, content: dict) -> None:
56
57
  await self._lock.acquire()
57
58
  for k, v in content.items():
58
59
  if k in self._content:
@@ -111,14 +112,14 @@ class MemoryUnit:
111
112
 
112
113
  async def dict_values(
113
114
  self,
114
- ) -> Dict[Any, Any]:
115
+ ) -> dict[Any, Any]:
115
116
  return self._content
116
117
 
117
118
 
118
119
  class MemoryBase(ABC):
119
120
 
120
121
  def __init__(self) -> None:
121
- self._memories: Dict[Any, Dict] = {}
122
+ self._memories: dict[Any, dict] = {}
122
123
  self._lock = asyncio.Lock()
123
124
 
124
125
  @abstractmethod
@@ -2,8 +2,9 @@
2
2
  Agent Profile
3
3
  """
4
4
 
5
+ from collections.abc import Callable, Sequence
5
6
  from copy import deepcopy
6
- from typing import Any, Callable, Dict, Optional, Sequence, Union, cast
7
+ from typing import Any, Optional, Union, cast
7
8
 
8
9
  from ..utils.decorators import lock_decorator
9
10
  from .const import *
@@ -14,7 +15,7 @@ from .utils import convert_msg_to_sequence
14
15
  class ProfileMemoryUnit(MemoryUnit):
15
16
  def __init__(
16
17
  self,
17
- content: Optional[Dict] = None,
18
+ content: Optional[dict] = None,
18
19
  activate_timestamp: bool = False,
19
20
  ) -> None:
20
21
  super().__init__(
@@ -28,7 +29,7 @@ class ProfileMemory(MemoryBase):
28
29
  def __init__(
29
30
  self,
30
31
  msg: Optional[
31
- Union[ProfileMemoryUnit, Sequence[ProfileMemoryUnit], Dict, Sequence[Dict]]
32
+ Union[ProfileMemoryUnit, Sequence[ProfileMemoryUnit], dict, Sequence[dict]]
32
33
  ] = None,
33
34
  activate_timestamp: bool = False,
34
35
  ) -> None:
@@ -74,7 +75,7 @@ class ProfileMemory(MemoryBase):
74
75
  @lock_decorator
75
76
  async def load(
76
77
  self,
77
- snapshots: Union[Dict, Sequence[Dict]],
78
+ snapshots: Union[dict, Sequence[dict]],
78
79
  reset_memory: bool = False,
79
80
  ) -> None:
80
81
  if reset_memory:
@@ -91,7 +92,7 @@ class ProfileMemory(MemoryBase):
91
92
  @lock_decorator
92
93
  async def export(
93
94
  self,
94
- ) -> Sequence[Dict]:
95
+ ) -> Sequence[dict]:
95
96
  _res = []
96
97
  for m in self._memories.keys():
97
98
  m = cast(ProfileMemoryUnit, m)
@@ -145,7 +146,7 @@ class ProfileMemory(MemoryBase):
145
146
  self._memories[unit] = {}
146
147
 
147
148
  @lock_decorator
148
- async def update_dict(self, to_update_dict: Dict, store_snapshot: bool = False):
149
+ async def update_dict(self, to_update_dict: dict, store_snapshot: bool = False):
149
150
  _latest_memories = self._fetch_recent_memory()
150
151
  _latest_memory: ProfileMemoryUnit = _latest_memories[-1]
151
152
  if not store_snapshot:
@@ -2,8 +2,9 @@
2
2
  Self Define Data
3
3
  """
4
4
 
5
+ from collections.abc import Callable, Sequence
5
6
  from copy import deepcopy
6
- from typing import Any, Callable, Dict, Optional, Sequence, Union, cast
7
+ from typing import Any, Optional, Union, cast
7
8
 
8
9
  from ..utils.decorators import lock_decorator
9
10
  from .const import *
@@ -14,8 +15,8 @@ from .utils import convert_msg_to_sequence
14
15
  class DynamicMemoryUnit(MemoryUnit):
15
16
  def __init__(
16
17
  self,
17
- content: Optional[Dict] = None,
18
- required_attributes: Optional[Dict] = None,
18
+ content: Optional[dict] = None,
19
+ required_attributes: Optional[dict] = None,
19
20
  activate_timestamp: bool = False,
20
21
  ) -> None:
21
22
  super().__init__(
@@ -29,7 +30,7 @@ class DynamicMemory(MemoryBase):
29
30
 
30
31
  def __init__(
31
32
  self,
32
- required_attributes: Dict[Any, Any],
33
+ required_attributes: dict[Any, Any],
33
34
  activate_timestamp: bool = False,
34
35
  ) -> None:
35
36
  super().__init__()
@@ -69,7 +70,7 @@ class DynamicMemory(MemoryBase):
69
70
  @lock_decorator
70
71
  async def load(
71
72
  self,
72
- snapshots: Union[Dict, Sequence[Dict]],
73
+ snapshots: Union[dict, Sequence[dict]],
73
74
  reset_memory: bool = False,
74
75
  ) -> None:
75
76
  if reset_memory:
@@ -86,7 +87,7 @@ class DynamicMemory(MemoryBase):
86
87
  @lock_decorator
87
88
  async def export(
88
89
  self,
89
- ) -> Sequence[Dict]:
90
+ ) -> Sequence[dict]:
90
91
  _res = []
91
92
  for m in self._memories.keys():
92
93
  m = cast(DynamicMemoryUnit, m)
@@ -143,7 +144,7 @@ class DynamicMemory(MemoryBase):
143
144
  self._memories[unit] = {}
144
145
 
145
146
  @lock_decorator
146
- async def update_dict(self, to_update_dict: Dict, store_snapshot: bool = False):
147
+ async def update_dict(self, to_update_dict: dict, store_snapshot: bool = False):
147
148
  _latest_memories = self._fetch_recent_memory()
148
149
  _latest_memory: DynamicMemoryUnit = _latest_memories[-1]
149
150
  if not store_snapshot:
@@ -2,8 +2,9 @@
2
2
  Agent State
3
3
  """
4
4
 
5
+ from collections.abc import Callable, Sequence
5
6
  from copy import deepcopy
6
- from typing import Any, Callable, Dict, Optional, Sequence, Union, cast
7
+ from typing import Any, Optional, Union, cast
7
8
 
8
9
  from ..utils.decorators import lock_decorator
9
10
  from .const import *
@@ -14,7 +15,7 @@ from .utils import convert_msg_to_sequence
14
15
  class StateMemoryUnit(MemoryUnit):
15
16
  def __init__(
16
17
  self,
17
- content: Optional[Dict] = None,
18
+ content: Optional[dict] = None,
18
19
  activate_timestamp: bool = False,
19
20
  ) -> None:
20
21
  super().__init__(
@@ -28,7 +29,7 @@ class StateMemory(MemoryBase):
28
29
  def __init__(
29
30
  self,
30
31
  msg: Optional[
31
- Union[MemoryUnit, Sequence[MemoryUnit], Dict, Sequence[Dict]]
32
+ Union[MemoryUnit, Sequence[MemoryUnit], dict, Sequence[dict]]
32
33
  ] = None,
33
34
  activate_timestamp: bool = False,
34
35
  ) -> None:
@@ -73,7 +74,7 @@ class StateMemory(MemoryBase):
73
74
  @lock_decorator
74
75
  async def load(
75
76
  self,
76
- snapshots: Union[Dict, Sequence[Dict]],
77
+ snapshots: Union[dict, Sequence[dict]],
77
78
  reset_memory: bool = False,
78
79
  ) -> None:
79
80
 
@@ -91,7 +92,7 @@ class StateMemory(MemoryBase):
91
92
  @lock_decorator
92
93
  async def export(
93
94
  self,
94
- ) -> Sequence[Dict]:
95
+ ) -> Sequence[dict]:
95
96
 
96
97
  _res = []
97
98
  for m in self._memories.keys():
@@ -151,7 +152,7 @@ class StateMemory(MemoryBase):
151
152
  self._memories[unit] = {}
152
153
 
153
154
  @lock_decorator
154
- async def update_dict(self, to_update_dict: Dict, store_snapshot: bool = False):
155
+ async def update_dict(self, to_update_dict: dict, store_snapshot: bool = False):
155
156
 
156
157
  _latest_memories = self._fetch_recent_memory()
157
158
  _latest_memory: StateMemoryUnit = _latest_memories[-1]
@@ -1,4 +1,5 @@
1
- from typing import Any, Sequence, Union
1
+ from collections.abc import Sequence
2
+ from typing import Any, Union
2
3
 
3
4
  from .memory_base import MemoryUnit
4
5
 
@@ -3,5 +3,6 @@
3
3
  """
4
4
 
5
5
  from .simulation import AgentSimulation
6
+ from .storage.pg import PgWriter, create_pg_tables
6
7
 
7
- __all__ = ["AgentSimulation"]
8
+ __all__ = ["AgentSimulation", "PgWriter", "create_pg_tables"]
@@ -3,7 +3,7 @@ import json
3
3
  import logging
4
4
  import time
5
5
  import uuid
6
- from datetime import datetime
6
+ from datetime import datetime, timezone
7
7
  from pathlib import Path
8
8
  from typing import Any
9
9
  from uuid import UUID
@@ -35,7 +35,7 @@ class AgentGroup:
35
35
  enable_avro: bool,
36
36
  avro_path: Path,
37
37
  enable_pgsql: bool,
38
- pgsql_copy_writer: ray.ObjectRef,
38
+ pgsql_writer: ray.ObjectRef,
39
39
  mlflow_run_id: str,
40
40
  logging_level: int,
41
41
  ):
@@ -45,6 +45,7 @@ class AgentGroup:
45
45
  self.config = config
46
46
  self.exp_id = exp_id
47
47
  self.enable_avro = enable_avro
48
+ self.enable_pgsql = enable_pgsql
48
49
  if enable_avro:
49
50
  self.avro_path = avro_path / f"{self._uuid}"
50
51
  self.avro_path.mkdir(parents=True, exist_ok=True)
@@ -54,6 +55,8 @@ class AgentGroup:
54
55
  "status": self.avro_path / f"status.avro",
55
56
  "survey": self.avro_path / f"survey.avro",
56
57
  }
58
+ if self.enable_pgsql:
59
+ pass
57
60
 
58
61
  self.messager = Messager(
59
62
  hostname=config["simulator_request"]["mqtt"]["server"],
@@ -61,6 +64,8 @@ class AgentGroup:
61
64
  username=config["simulator_request"]["mqtt"].get("username", None),
62
65
  password=config["simulator_request"]["mqtt"].get("password", None),
63
66
  )
67
+ self._pgsql_writer = pgsql_writer
68
+ self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
64
69
  self.initialized = False
65
70
  self.id2agent = {}
66
71
  # Step:1 prepare LLM client
@@ -105,6 +110,8 @@ class AgentGroup:
105
110
  agent.set_messager(self.messager)
106
111
  if self.enable_avro:
107
112
  agent.set_avro_file(self.avro_file) # type: ignore
113
+ if self.enable_pgsql:
114
+ agent.set_pgsql_writer(self._pgsql_writer)
108
115
 
109
116
  async def init_agents(self):
110
117
  logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
@@ -161,6 +168,20 @@ class AgentGroup:
161
168
  with open(filename, "wb") as f:
162
169
  surveys = []
163
170
  fastavro.writer(f, SURVEY_SCHEMA, surveys)
171
+
172
+ if self.enable_pgsql:
173
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
174
+ profiles: list[Any] = []
175
+ for agent in self.agents:
176
+ profile = await agent.memory._profile.export()
177
+ profile = profile[0]
178
+ profile["id"] = agent._uuid
179
+ profiles.append(
180
+ (agent._uuid, profile.get("name", ""), json.dumps(profile))
181
+ )
182
+ await self._pgsql_writer.async_write_profile.remote( # type:ignore
183
+ profiles
184
+ )
164
185
  self.initialized = True
165
186
  logger.debug(f"-----AgentGroup {self._uuid} initialized")
166
187
 
@@ -218,11 +239,13 @@ class AgentGroup:
218
239
  await asyncio.sleep(0.5)
219
240
 
220
241
  async def save_status(self):
242
+ _statuses_time_list: list[tuple[dict, datetime]] = []
221
243
  if self.enable_avro:
222
244
  logger.debug(f"-----Saving status for group {self._uuid}")
223
245
  avros = []
224
246
  if not issubclass(type(self.agents[0]), InstitutionAgent):
225
247
  for agent in self.agents:
248
+ _date_time = datetime.now(timezone.utc)
226
249
  position = await agent.memory.get("position")
227
250
  lng = position["longlat_position"]["longitude"]
228
251
  lat = position["longlat_position"]["latitude"]
@@ -248,13 +271,15 @@ class AgentGroup:
248
271
  "tired": needs["tired"],
249
272
  "safe": needs["safe"],
250
273
  "social": needs["social"],
251
- "created_at": int(datetime.now().timestamp() * 1000),
274
+ "created_at": int(_date_time.timestamp() * 1000),
252
275
  }
253
276
  avros.append(avro)
277
+ _statuses_time_list.append((avro, _date_time))
254
278
  with open(self.avro_file["status"], "a+b") as f:
255
279
  fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
256
280
  else:
257
281
  for agent in self.agents:
282
+ _date_time = datetime.now(timezone.utc)
258
283
  avro = {
259
284
  "id": agent._uuid,
260
285
  "day": await self.simulator.get_simulator_day(),
@@ -274,8 +299,109 @@ class AgentGroup:
274
299
  "customers": await agent.memory.get("customers"),
275
300
  }
276
301
  avros.append(avro)
302
+ _statuses_time_list.append((avro, _date_time))
277
303
  with open(self.avro_file["status"], "a+b") as f:
278
304
  fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
305
+ if self.enable_pgsql:
306
+ # data already acquired from Avro part
307
+ if len(_statuses_time_list) > 0:
308
+ for _status_dict, _date_time in _statuses_time_list:
309
+ for key in ["lng", "lat", "parent_id"]:
310
+ if key not in _status_dict:
311
+ _status_dict[key] = -1
312
+ for key in [
313
+ "action",
314
+ ]:
315
+ if key not in _status_dict:
316
+ _status_dict[key] = ""
317
+ _status_dict["created_at"] = _date_time
318
+ else:
319
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
320
+ for agent in self.agents:
321
+ _date_time = datetime.now(timezone.utc)
322
+ position = await agent.memory.get("position")
323
+ lng = position["longlat_position"]["longitude"]
324
+ lat = position["longlat_position"]["latitude"]
325
+ if "aoi_position" in position:
326
+ parent_id = position["aoi_position"]["aoi_id"]
327
+ elif "lane_position" in position:
328
+ parent_id = position["lane_position"]["lane_id"]
329
+ else:
330
+ # BUG: 需要处理
331
+ parent_id = -1
332
+ needs = await agent.memory.get("needs")
333
+ action = await agent.memory.get("current_step")
334
+ action = action["intention"]
335
+ _status_dict = {
336
+ "id": agent._uuid,
337
+ "day": await self.simulator.get_simulator_day(),
338
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
339
+ "lng": lng,
340
+ "lat": lat,
341
+ "parent_id": parent_id,
342
+ "action": action,
343
+ "hungry": needs["hungry"],
344
+ "tired": needs["tired"],
345
+ "safe": needs["safe"],
346
+ "social": needs["social"],
347
+ "created_at": _date_time,
348
+ }
349
+ _statuses_time_list.append((_status_dict, _date_time))
350
+ else:
351
+ for agent in self.agents:
352
+ _date_time = datetime.now(timezone.utc)
353
+ _status_dict = {
354
+ "id": agent._uuid,
355
+ "day": await self.simulator.get_simulator_day(),
356
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
357
+ "lng": -1,
358
+ "lat": -1,
359
+ "parent_id": -1,
360
+ "action": "",
361
+ "type": await agent.memory.get("type"),
362
+ "nominal_gdp": await agent.memory.get("nominal_gdp"),
363
+ "real_gdp": await agent.memory.get("real_gdp"),
364
+ "unemployment": await agent.memory.get("unemployment"),
365
+ "wages": await agent.memory.get("wages"),
366
+ "prices": await agent.memory.get("prices"),
367
+ "inventory": await agent.memory.get("inventory"),
368
+ "price": await agent.memory.get("price"),
369
+ "interest_rate": await agent.memory.get("interest_rate"),
370
+ "bracket_cutoffs": await agent.memory.get(
371
+ "bracket_cutoffs"
372
+ ),
373
+ "bracket_rates": await agent.memory.get("bracket_rates"),
374
+ "employees": await agent.memory.get("employees"),
375
+ "customers": await agent.memory.get("customers"),
376
+ "created_at": _date_time,
377
+ }
378
+ _statuses_time_list.append((_status_dict, _date_time))
379
+ to_update_statues: list[tuple] = []
380
+ for _status_dict, _ in _statuses_time_list:
381
+ BASIC_KEYS = [
382
+ "id",
383
+ "day",
384
+ "t",
385
+ "lng",
386
+ "lat",
387
+ "parent_id",
388
+ "action",
389
+ "created_at",
390
+ ]
391
+ _data = [_status_dict[k] for k in BASIC_KEYS if k != "created_at"]
392
+ _other_dict = json.dumps(
393
+ {k: v for k, v in _status_dict.items() if k not in BASIC_KEYS}
394
+ )
395
+ _data.append(_other_dict)
396
+ _data.append(_status_dict["created_at"])
397
+ to_update_statues.append(tuple(_data))
398
+ if self._last_asyncio_pg_task is not None:
399
+ await self._last_asyncio_pg_task
400
+ self._last_asyncio_pg_task = (
401
+ self._pgsql_writer.async_write_status.remote( # type:ignore
402
+ to_update_statues
403
+ )
404
+ )
279
405
 
280
406
  async def step(self):
281
407
  if not self.initialized:
@@ -23,6 +23,7 @@ from ..message.messager import Messager
23
23
  from ..metrics import init_mlflow_connection
24
24
  from ..survey import Survey
25
25
  from .agentgroup import AgentGroup
26
+ from .storage.pg import PgWriter, create_pg_tables
26
27
 
27
28
  logger = logging.getLogger("pycityagent")
28
29
 
@@ -63,6 +64,7 @@ class AgentSimulation:
63
64
  self._user_survey_topics: dict[uuid.UUID, str] = {}
64
65
  self._user_interview_topics: dict[uuid.UUID, str] = {}
65
66
  self._loop = asyncio.get_event_loop()
67
+ # self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
66
68
 
67
69
  self._messager = Messager(
68
70
  hostname=config["simulator_request"]["mqtt"]["server"],
@@ -89,22 +91,13 @@ class AgentSimulation:
89
91
  self._enable_pgsql = _pgsql_config.get("enabled", False)
90
92
  if not self._enable_pgsql:
91
93
  logger.warning("PostgreSQL is not enabled, NO POSTGRESQL DATABASE STORAGE")
92
- self._pgsql_args = ("", "", "", "", "")
94
+ self._pgsql_dsn = ""
93
95
  else:
94
- self._pgsql_host = _pgsql_config["host"]
95
- self._pgsql_port = _pgsql_config["port"]
96
- self._pgsql_database = _pgsql_config["database"]
97
- self._pgsql_user = _pgsql_config.get("user", None)
98
- self._pgsql_password = _pgsql_config.get("password", None)
99
- self._pgsql_args: tuple[str, str, str, str, str] = (
100
- self._pgsql_host,
101
- self._pgsql_port,
102
- self._pgsql_database,
103
- self._pgsql_user,
104
- self._pgsql_password,
105
- )
96
+ self._pgsql_dsn = _pgsql_config["data_source_name"]
106
97
 
107
98
  # 添加实验信息相关的属性
99
+ self._exp_created_time = datetime.now(timezone.utc)
100
+ self._exp_updated_time = datetime.now(timezone.utc)
108
101
  self._exp_info = {
109
102
  "id": self.exp_id,
110
103
  "name": exp_name,
@@ -114,7 +107,8 @@ class AgentSimulation:
114
107
  "cur_t": 0.0,
115
108
  "config": json.dumps(config),
116
109
  "error": "",
117
- "created_at": datetime.now(timezone.utc).isoformat(),
110
+ "created_at": self._exp_created_time.isoformat(),
111
+ "updated_at": self._exp_updated_time.isoformat(),
118
112
  }
119
113
 
120
114
  # 创建异步任务保存实验信息
@@ -168,7 +162,7 @@ class AgentSimulation:
168
162
  enable_avro: bool,
169
163
  avro_path: Path,
170
164
  enable_pgsql: bool,
171
- pgsql_copy_writer: ray.ObjectRef,
165
+ pgsql_writer: ray.ObjectRef,
172
166
  mlflow_run_id: str = None, # type: ignore
173
167
  logging_level: int = logging.WARNING,
174
168
  ):
@@ -181,7 +175,7 @@ class AgentSimulation:
181
175
  enable_avro,
182
176
  avro_path,
183
177
  enable_pgsql,
184
- pgsql_copy_writer,
178
+ pgsql_writer,
185
179
  mlflow_run_id,
186
180
  logging_level,
187
181
  )
@@ -191,6 +185,7 @@ class AgentSimulation:
191
185
  self,
192
186
  agent_count: Union[int, list[int]],
193
187
  group_size: int = 1000,
188
+ pg_sql_writers: int = 32,
194
189
  memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
195
190
  ) -> None:
196
191
  """初始化智能体
@@ -251,8 +246,8 @@ class AgentSimulation:
251
246
  memory=memory,
252
247
  )
253
248
 
254
- self._agents[agent._uuid] = agent
255
- self._agent_uuids.append(agent._uuid)
249
+ self._agents[agent._uuid] = agent # type:ignore
250
+ self._agent_uuids.append(agent._uuid) # type:ignore
256
251
 
257
252
  # 计算需要的组数,向上取整以处理不足一组的情况
258
253
  num_group = (agent_count_i + group_size - 1) // group_size
@@ -282,9 +277,23 @@ class AgentSimulation:
282
277
  )
283
278
  else:
284
279
  mlflow_run_id = None
280
+ # 建表
281
+ if self.enable_pgsql:
282
+ _num_workers = min(1, pg_sql_writers)
283
+ create_pg_tables(
284
+ exp_id=self.exp_id,
285
+ dsn=self._pgsql_dsn,
286
+ )
287
+ self._pgsql_writers = _workers = [
288
+ PgWriter.remote(self.exp_id, self._pgsql_dsn)
289
+ for _ in range(_num_workers)
290
+ ]
291
+ else:
292
+ _num_workers = 1
293
+ self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
285
294
  # 收集所有创建组的参数
286
295
  creation_tasks = []
287
- for group_name, agents in group_creation_params:
296
+ for i, (group_name, agents) in enumerate(group_creation_params):
288
297
  # 直接创建异步任务
289
298
  group = AgentGroup.remote(
290
299
  agents,
@@ -294,10 +303,8 @@ class AgentSimulation:
294
303
  self.enable_avro,
295
304
  self.avro_path,
296
305
  self.enable_pgsql,
297
- # TODO:
298
- # self._pgsql_copy_writer, # type:ignore
299
- None,
300
- mlflow_run_id,
306
+ _workers[i % _num_workers], # type:ignore
307
+ mlflow_run_id, # type:ignore
301
308
  self.logging_level,
302
309
  )
303
310
  creation_tasks.append((group_name, group, agents))
@@ -469,11 +476,13 @@ class AgentSimulation:
469
476
  survey_dict = survey.to_dict()
470
477
  if agent_uuids is None:
471
478
  agent_uuids = self._agent_uuids
479
+ _date_time = datetime.now(timezone.utc)
472
480
  payload = {
473
481
  "from": "none",
474
482
  "survey_id": survey_dict["id"],
475
- "timestamp": int(datetime.now().timestamp() * 1000),
483
+ "timestamp": int(_date_time.timestamp() * 1000),
476
484
  "data": survey_dict,
485
+ "_date_time": _date_time,
477
486
  }
478
487
  for uuid in agent_uuids:
479
488
  topic = self._user_survey_topics[uuid]
@@ -483,10 +492,12 @@ class AgentSimulation:
483
492
  self, content: str, agent_uuids: Union[uuid.UUID, list[uuid.UUID]]
484
493
  ):
485
494
  """发送面试消息"""
495
+ _date_time = datetime.now(timezone.utc)
486
496
  payload = {
487
497
  "from": "none",
488
498
  "content": content,
489
- "timestamp": int(datetime.now().timestamp() * 1000),
499
+ "timestamp": int(_date_time.timestamp() * 1000),
500
+ "_date_time": _date_time,
490
501
  }
491
502
  if not isinstance(agent_uuids, Sequence):
492
503
  agent_uuids = [agent_uuids]
@@ -515,15 +526,29 @@ class AgentSimulation:
515
526
  logger.error(f"Avro保存实验信息失败: {str(e)}")
516
527
  try:
517
528
  if self.enable_pgsql:
518
- # TODO
519
- pass
529
+ worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
530
+ # if self._last_asyncio_pg_task is not None:
531
+ # await self._last_asyncio_pg_task
532
+ # self._last_asyncio_pg_task = (
533
+ # worker.async_update_exp_info.remote( # type:ignore
534
+ # pg_exp_info
535
+ # )
536
+ # )
537
+ pg_exp_info = {k: v for k, v in self._exp_info.items()}
538
+ pg_exp_info["created_at"] = self._exp_created_time
539
+ pg_exp_info["updated_at"] = self._exp_updated_time
540
+ await worker.async_update_exp_info.remote( # type:ignore
541
+ pg_exp_info
542
+ )
520
543
  except Exception as e:
521
544
  logger.error(f"PostgreSQL保存实验信息失败: {str(e)}")
522
545
 
523
546
  async def _update_exp_status(self, status: int, error: str = "") -> None:
547
+ self._exp_updated_time = datetime.now(timezone.utc)
524
548
  """更新实验状态并保存"""
525
549
  self._exp_info["status"] = status
526
550
  self._exp_info["error"] = error
551
+ self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
527
552
  await self._save_exp_info()
528
553
 
529
554
  async def _monitor_exp_status(self, stop_event: asyncio.Event):