pycityagent 2.0.0a20__py3-none-any.whl → 2.0.0a22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +169 -63
- pycityagent/environment/sim/aoi_service.py +2 -1
- pycityagent/environment/sim/clock_service.py +2 -1
- pycityagent/environment/sim/economy_services.py +9 -8
- pycityagent/environment/sim/lane_service.py +6 -5
- pycityagent/environment/sim/light_service.py +10 -8
- pycityagent/environment/sim/person_service.py +12 -11
- pycityagent/environment/sim/road_service.py +3 -2
- pycityagent/environment/sim/social_service.py +4 -3
- pycityagent/environment/utils/protobuf.py +6 -4
- pycityagent/memory/memory_base.py +7 -6
- pycityagent/memory/profile.py +7 -6
- pycityagent/memory/self_define.py +8 -7
- pycityagent/memory/state.py +7 -6
- pycityagent/memory/utils.py +2 -1
- pycityagent/simulation/__init__.py +2 -1
- pycityagent/simulation/agentgroup.py +129 -3
- pycityagent/simulation/simulation.py +52 -27
- pycityagent/simulation/storage/pg.py +139 -0
- pycityagent/utils/parsers/json_parser.py +3 -3
- pycityagent/utils/pg_query.py +80 -0
- pycityagent/workflow/block.py +2 -1
- pycityagent/workflow/tool.py +32 -24
- {pycityagent-2.0.0a20.dist-info → pycityagent-2.0.0a22.dist-info}/METADATA +1 -1
- {pycityagent-2.0.0a20.dist-info → pycityagent-2.0.0a22.dist-info}/RECORD +26 -24
- {pycityagent-2.0.0a20.dist-info → pycityagent-2.0.0a22.dist-info}/WHEEL +0 -0
@@ -1,13 +1,15 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from collections.abc import Awaitable
|
2
|
+
from typing import Any, TypeVar, Union
|
3
|
+
|
3
4
|
from google.protobuf.json_format import MessageToDict
|
5
|
+
from google.protobuf.message import Message
|
4
6
|
|
5
7
|
__all__ = ["parse", "async_parse"]
|
6
8
|
|
7
9
|
T = TypeVar("T", bound=Message)
|
8
10
|
|
9
11
|
|
10
|
-
def parse(res: T, dict_return: bool) -> Union[
|
12
|
+
def parse(res: T, dict_return: bool) -> Union[dict[str, Any], T]:
|
11
13
|
"""
|
12
14
|
将Protobuf返回值转换为dict或者原始值
|
13
15
|
Convert Protobuf return value to dict or original value
|
@@ -23,7 +25,7 @@ def parse(res: T, dict_return: bool) -> Union[Dict[str, Any], T]:
|
|
23
25
|
return res
|
24
26
|
|
25
27
|
|
26
|
-
async def async_parse(res: Awaitable[T], dict_return: bool) -> Union[
|
28
|
+
async def async_parse(res: Awaitable[T], dict_return: bool) -> Union[dict[str, Any], T]:
|
27
29
|
"""
|
28
30
|
将Protobuf await返回值转换为dict或者原始值
|
29
31
|
Convert Protobuf await return value to dict or original value
|
@@ -6,7 +6,8 @@ import asyncio
|
|
6
6
|
import logging
|
7
7
|
import time
|
8
8
|
from abc import ABC, abstractmethod
|
9
|
-
from
|
9
|
+
from collections.abc import Callable, Sequence
|
10
|
+
from typing import Any, Optional, Union
|
10
11
|
|
11
12
|
from .const import *
|
12
13
|
|
@@ -16,8 +17,8 @@ logger = logging.getLogger("pycityagent")
|
|
16
17
|
class MemoryUnit:
|
17
18
|
def __init__(
|
18
19
|
self,
|
19
|
-
content: Optional[
|
20
|
-
required_attributes: Optional[
|
20
|
+
content: Optional[dict] = None,
|
21
|
+
required_attributes: Optional[dict] = None,
|
21
22
|
activate_timestamp: bool = False,
|
22
23
|
) -> None:
|
23
24
|
self._content = {}
|
@@ -52,7 +53,7 @@ class MemoryUnit:
|
|
52
53
|
else:
|
53
54
|
setattr(self, f"{SELF_DEFINE_PREFIX}{property_name}", property_value)
|
54
55
|
|
55
|
-
async def update(self, content:
|
56
|
+
async def update(self, content: dict) -> None:
|
56
57
|
await self._lock.acquire()
|
57
58
|
for k, v in content.items():
|
58
59
|
if k in self._content:
|
@@ -111,14 +112,14 @@ class MemoryUnit:
|
|
111
112
|
|
112
113
|
async def dict_values(
|
113
114
|
self,
|
114
|
-
) ->
|
115
|
+
) -> dict[Any, Any]:
|
115
116
|
return self._content
|
116
117
|
|
117
118
|
|
118
119
|
class MemoryBase(ABC):
|
119
120
|
|
120
121
|
def __init__(self) -> None:
|
121
|
-
self._memories:
|
122
|
+
self._memories: dict[Any, dict] = {}
|
122
123
|
self._lock = asyncio.Lock()
|
123
124
|
|
124
125
|
@abstractmethod
|
pycityagent/memory/profile.py
CHANGED
@@ -2,8 +2,9 @@
|
|
2
2
|
Agent Profile
|
3
3
|
"""
|
4
4
|
|
5
|
+
from collections.abc import Callable, Sequence
|
5
6
|
from copy import deepcopy
|
6
|
-
from typing import Any,
|
7
|
+
from typing import Any, Optional, Union, cast
|
7
8
|
|
8
9
|
from ..utils.decorators import lock_decorator
|
9
10
|
from .const import *
|
@@ -14,7 +15,7 @@ from .utils import convert_msg_to_sequence
|
|
14
15
|
class ProfileMemoryUnit(MemoryUnit):
|
15
16
|
def __init__(
|
16
17
|
self,
|
17
|
-
content: Optional[
|
18
|
+
content: Optional[dict] = None,
|
18
19
|
activate_timestamp: bool = False,
|
19
20
|
) -> None:
|
20
21
|
super().__init__(
|
@@ -28,7 +29,7 @@ class ProfileMemory(MemoryBase):
|
|
28
29
|
def __init__(
|
29
30
|
self,
|
30
31
|
msg: Optional[
|
31
|
-
Union[ProfileMemoryUnit, Sequence[ProfileMemoryUnit],
|
32
|
+
Union[ProfileMemoryUnit, Sequence[ProfileMemoryUnit], dict, Sequence[dict]]
|
32
33
|
] = None,
|
33
34
|
activate_timestamp: bool = False,
|
34
35
|
) -> None:
|
@@ -74,7 +75,7 @@ class ProfileMemory(MemoryBase):
|
|
74
75
|
@lock_decorator
|
75
76
|
async def load(
|
76
77
|
self,
|
77
|
-
snapshots: Union[
|
78
|
+
snapshots: Union[dict, Sequence[dict]],
|
78
79
|
reset_memory: bool = False,
|
79
80
|
) -> None:
|
80
81
|
if reset_memory:
|
@@ -91,7 +92,7 @@ class ProfileMemory(MemoryBase):
|
|
91
92
|
@lock_decorator
|
92
93
|
async def export(
|
93
94
|
self,
|
94
|
-
) -> Sequence[
|
95
|
+
) -> Sequence[dict]:
|
95
96
|
_res = []
|
96
97
|
for m in self._memories.keys():
|
97
98
|
m = cast(ProfileMemoryUnit, m)
|
@@ -145,7 +146,7 @@ class ProfileMemory(MemoryBase):
|
|
145
146
|
self._memories[unit] = {}
|
146
147
|
|
147
148
|
@lock_decorator
|
148
|
-
async def update_dict(self, to_update_dict:
|
149
|
+
async def update_dict(self, to_update_dict: dict, store_snapshot: bool = False):
|
149
150
|
_latest_memories = self._fetch_recent_memory()
|
150
151
|
_latest_memory: ProfileMemoryUnit = _latest_memories[-1]
|
151
152
|
if not store_snapshot:
|
@@ -2,8 +2,9 @@
|
|
2
2
|
Self Define Data
|
3
3
|
"""
|
4
4
|
|
5
|
+
from collections.abc import Callable, Sequence
|
5
6
|
from copy import deepcopy
|
6
|
-
from typing import Any,
|
7
|
+
from typing import Any, Optional, Union, cast
|
7
8
|
|
8
9
|
from ..utils.decorators import lock_decorator
|
9
10
|
from .const import *
|
@@ -14,8 +15,8 @@ from .utils import convert_msg_to_sequence
|
|
14
15
|
class DynamicMemoryUnit(MemoryUnit):
|
15
16
|
def __init__(
|
16
17
|
self,
|
17
|
-
content: Optional[
|
18
|
-
required_attributes: Optional[
|
18
|
+
content: Optional[dict] = None,
|
19
|
+
required_attributes: Optional[dict] = None,
|
19
20
|
activate_timestamp: bool = False,
|
20
21
|
) -> None:
|
21
22
|
super().__init__(
|
@@ -29,7 +30,7 @@ class DynamicMemory(MemoryBase):
|
|
29
30
|
|
30
31
|
def __init__(
|
31
32
|
self,
|
32
|
-
required_attributes:
|
33
|
+
required_attributes: dict[Any, Any],
|
33
34
|
activate_timestamp: bool = False,
|
34
35
|
) -> None:
|
35
36
|
super().__init__()
|
@@ -69,7 +70,7 @@ class DynamicMemory(MemoryBase):
|
|
69
70
|
@lock_decorator
|
70
71
|
async def load(
|
71
72
|
self,
|
72
|
-
snapshots: Union[
|
73
|
+
snapshots: Union[dict, Sequence[dict]],
|
73
74
|
reset_memory: bool = False,
|
74
75
|
) -> None:
|
75
76
|
if reset_memory:
|
@@ -86,7 +87,7 @@ class DynamicMemory(MemoryBase):
|
|
86
87
|
@lock_decorator
|
87
88
|
async def export(
|
88
89
|
self,
|
89
|
-
) -> Sequence[
|
90
|
+
) -> Sequence[dict]:
|
90
91
|
_res = []
|
91
92
|
for m in self._memories.keys():
|
92
93
|
m = cast(DynamicMemoryUnit, m)
|
@@ -143,7 +144,7 @@ class DynamicMemory(MemoryBase):
|
|
143
144
|
self._memories[unit] = {}
|
144
145
|
|
145
146
|
@lock_decorator
|
146
|
-
async def update_dict(self, to_update_dict:
|
147
|
+
async def update_dict(self, to_update_dict: dict, store_snapshot: bool = False):
|
147
148
|
_latest_memories = self._fetch_recent_memory()
|
148
149
|
_latest_memory: DynamicMemoryUnit = _latest_memories[-1]
|
149
150
|
if not store_snapshot:
|
pycityagent/memory/state.py
CHANGED
@@ -2,8 +2,9 @@
|
|
2
2
|
Agent State
|
3
3
|
"""
|
4
4
|
|
5
|
+
from collections.abc import Callable, Sequence
|
5
6
|
from copy import deepcopy
|
6
|
-
from typing import Any,
|
7
|
+
from typing import Any, Optional, Union, cast
|
7
8
|
|
8
9
|
from ..utils.decorators import lock_decorator
|
9
10
|
from .const import *
|
@@ -14,7 +15,7 @@ from .utils import convert_msg_to_sequence
|
|
14
15
|
class StateMemoryUnit(MemoryUnit):
|
15
16
|
def __init__(
|
16
17
|
self,
|
17
|
-
content: Optional[
|
18
|
+
content: Optional[dict] = None,
|
18
19
|
activate_timestamp: bool = False,
|
19
20
|
) -> None:
|
20
21
|
super().__init__(
|
@@ -28,7 +29,7 @@ class StateMemory(MemoryBase):
|
|
28
29
|
def __init__(
|
29
30
|
self,
|
30
31
|
msg: Optional[
|
31
|
-
Union[MemoryUnit, Sequence[MemoryUnit],
|
32
|
+
Union[MemoryUnit, Sequence[MemoryUnit], dict, Sequence[dict]]
|
32
33
|
] = None,
|
33
34
|
activate_timestamp: bool = False,
|
34
35
|
) -> None:
|
@@ -73,7 +74,7 @@ class StateMemory(MemoryBase):
|
|
73
74
|
@lock_decorator
|
74
75
|
async def load(
|
75
76
|
self,
|
76
|
-
snapshots: Union[
|
77
|
+
snapshots: Union[dict, Sequence[dict]],
|
77
78
|
reset_memory: bool = False,
|
78
79
|
) -> None:
|
79
80
|
|
@@ -91,7 +92,7 @@ class StateMemory(MemoryBase):
|
|
91
92
|
@lock_decorator
|
92
93
|
async def export(
|
93
94
|
self,
|
94
|
-
) -> Sequence[
|
95
|
+
) -> Sequence[dict]:
|
95
96
|
|
96
97
|
_res = []
|
97
98
|
for m in self._memories.keys():
|
@@ -151,7 +152,7 @@ class StateMemory(MemoryBase):
|
|
151
152
|
self._memories[unit] = {}
|
152
153
|
|
153
154
|
@lock_decorator
|
154
|
-
async def update_dict(self, to_update_dict:
|
155
|
+
async def update_dict(self, to_update_dict: dict, store_snapshot: bool = False):
|
155
156
|
|
156
157
|
_latest_memories = self._fetch_recent_memory()
|
157
158
|
_latest_memory: StateMemoryUnit = _latest_memories[-1]
|
pycityagent/memory/utils.py
CHANGED
@@ -3,7 +3,7 @@ import json
|
|
3
3
|
import logging
|
4
4
|
import time
|
5
5
|
import uuid
|
6
|
-
from datetime import datetime
|
6
|
+
from datetime import datetime, timezone
|
7
7
|
from pathlib import Path
|
8
8
|
from typing import Any
|
9
9
|
from uuid import UUID
|
@@ -35,7 +35,7 @@ class AgentGroup:
|
|
35
35
|
enable_avro: bool,
|
36
36
|
avro_path: Path,
|
37
37
|
enable_pgsql: bool,
|
38
|
-
|
38
|
+
pgsql_writer: ray.ObjectRef,
|
39
39
|
mlflow_run_id: str,
|
40
40
|
logging_level: int,
|
41
41
|
):
|
@@ -45,6 +45,7 @@ class AgentGroup:
|
|
45
45
|
self.config = config
|
46
46
|
self.exp_id = exp_id
|
47
47
|
self.enable_avro = enable_avro
|
48
|
+
self.enable_pgsql = enable_pgsql
|
48
49
|
if enable_avro:
|
49
50
|
self.avro_path = avro_path / f"{self._uuid}"
|
50
51
|
self.avro_path.mkdir(parents=True, exist_ok=True)
|
@@ -54,6 +55,8 @@ class AgentGroup:
|
|
54
55
|
"status": self.avro_path / f"status.avro",
|
55
56
|
"survey": self.avro_path / f"survey.avro",
|
56
57
|
}
|
58
|
+
if self.enable_pgsql:
|
59
|
+
pass
|
57
60
|
|
58
61
|
self.messager = Messager(
|
59
62
|
hostname=config["simulator_request"]["mqtt"]["server"],
|
@@ -61,6 +64,8 @@ class AgentGroup:
|
|
61
64
|
username=config["simulator_request"]["mqtt"].get("username", None),
|
62
65
|
password=config["simulator_request"]["mqtt"].get("password", None),
|
63
66
|
)
|
67
|
+
self._pgsql_writer = pgsql_writer
|
68
|
+
self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
64
69
|
self.initialized = False
|
65
70
|
self.id2agent = {}
|
66
71
|
# Step:1 prepare LLM client
|
@@ -105,6 +110,8 @@ class AgentGroup:
|
|
105
110
|
agent.set_messager(self.messager)
|
106
111
|
if self.enable_avro:
|
107
112
|
agent.set_avro_file(self.avro_file) # type: ignore
|
113
|
+
if self.enable_pgsql:
|
114
|
+
agent.set_pgsql_writer(self._pgsql_writer)
|
108
115
|
|
109
116
|
async def init_agents(self):
|
110
117
|
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
|
@@ -161,6 +168,20 @@ class AgentGroup:
|
|
161
168
|
with open(filename, "wb") as f:
|
162
169
|
surveys = []
|
163
170
|
fastavro.writer(f, SURVEY_SCHEMA, surveys)
|
171
|
+
|
172
|
+
if self.enable_pgsql:
|
173
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
174
|
+
profiles: list[Any] = []
|
175
|
+
for agent in self.agents:
|
176
|
+
profile = await agent.memory._profile.export()
|
177
|
+
profile = profile[0]
|
178
|
+
profile["id"] = agent._uuid
|
179
|
+
profiles.append(
|
180
|
+
(agent._uuid, profile.get("name", ""), json.dumps(profile))
|
181
|
+
)
|
182
|
+
await self._pgsql_writer.async_write_profile.remote( # type:ignore
|
183
|
+
profiles
|
184
|
+
)
|
164
185
|
self.initialized = True
|
165
186
|
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
166
187
|
|
@@ -218,11 +239,13 @@ class AgentGroup:
|
|
218
239
|
await asyncio.sleep(0.5)
|
219
240
|
|
220
241
|
async def save_status(self):
|
242
|
+
_statuses_time_list: list[tuple[dict, datetime]] = []
|
221
243
|
if self.enable_avro:
|
222
244
|
logger.debug(f"-----Saving status for group {self._uuid}")
|
223
245
|
avros = []
|
224
246
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
225
247
|
for agent in self.agents:
|
248
|
+
_date_time = datetime.now(timezone.utc)
|
226
249
|
position = await agent.memory.get("position")
|
227
250
|
lng = position["longlat_position"]["longitude"]
|
228
251
|
lat = position["longlat_position"]["latitude"]
|
@@ -248,13 +271,15 @@ class AgentGroup:
|
|
248
271
|
"tired": needs["tired"],
|
249
272
|
"safe": needs["safe"],
|
250
273
|
"social": needs["social"],
|
251
|
-
"created_at": int(
|
274
|
+
"created_at": int(_date_time.timestamp() * 1000),
|
252
275
|
}
|
253
276
|
avros.append(avro)
|
277
|
+
_statuses_time_list.append((avro, _date_time))
|
254
278
|
with open(self.avro_file["status"], "a+b") as f:
|
255
279
|
fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
|
256
280
|
else:
|
257
281
|
for agent in self.agents:
|
282
|
+
_date_time = datetime.now(timezone.utc)
|
258
283
|
avro = {
|
259
284
|
"id": agent._uuid,
|
260
285
|
"day": await self.simulator.get_simulator_day(),
|
@@ -274,8 +299,109 @@ class AgentGroup:
|
|
274
299
|
"customers": await agent.memory.get("customers"),
|
275
300
|
}
|
276
301
|
avros.append(avro)
|
302
|
+
_statuses_time_list.append((avro, _date_time))
|
277
303
|
with open(self.avro_file["status"], "a+b") as f:
|
278
304
|
fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
|
305
|
+
if self.enable_pgsql:
|
306
|
+
# data already acquired from Avro part
|
307
|
+
if len(_statuses_time_list) > 0:
|
308
|
+
for _status_dict, _date_time in _statuses_time_list:
|
309
|
+
for key in ["lng", "lat", "parent_id"]:
|
310
|
+
if key not in _status_dict:
|
311
|
+
_status_dict[key] = -1
|
312
|
+
for key in [
|
313
|
+
"action",
|
314
|
+
]:
|
315
|
+
if key not in _status_dict:
|
316
|
+
_status_dict[key] = ""
|
317
|
+
_status_dict["created_at"] = _date_time
|
318
|
+
else:
|
319
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
320
|
+
for agent in self.agents:
|
321
|
+
_date_time = datetime.now(timezone.utc)
|
322
|
+
position = await agent.memory.get("position")
|
323
|
+
lng = position["longlat_position"]["longitude"]
|
324
|
+
lat = position["longlat_position"]["latitude"]
|
325
|
+
if "aoi_position" in position:
|
326
|
+
parent_id = position["aoi_position"]["aoi_id"]
|
327
|
+
elif "lane_position" in position:
|
328
|
+
parent_id = position["lane_position"]["lane_id"]
|
329
|
+
else:
|
330
|
+
# BUG: 需要处理
|
331
|
+
parent_id = -1
|
332
|
+
needs = await agent.memory.get("needs")
|
333
|
+
action = await agent.memory.get("current_step")
|
334
|
+
action = action["intention"]
|
335
|
+
_status_dict = {
|
336
|
+
"id": agent._uuid,
|
337
|
+
"day": await self.simulator.get_simulator_day(),
|
338
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
339
|
+
"lng": lng,
|
340
|
+
"lat": lat,
|
341
|
+
"parent_id": parent_id,
|
342
|
+
"action": action,
|
343
|
+
"hungry": needs["hungry"],
|
344
|
+
"tired": needs["tired"],
|
345
|
+
"safe": needs["safe"],
|
346
|
+
"social": needs["social"],
|
347
|
+
"created_at": _date_time,
|
348
|
+
}
|
349
|
+
_statuses_time_list.append((_status_dict, _date_time))
|
350
|
+
else:
|
351
|
+
for agent in self.agents:
|
352
|
+
_date_time = datetime.now(timezone.utc)
|
353
|
+
_status_dict = {
|
354
|
+
"id": agent._uuid,
|
355
|
+
"day": await self.simulator.get_simulator_day(),
|
356
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
357
|
+
"lng": -1,
|
358
|
+
"lat": -1,
|
359
|
+
"parent_id": -1,
|
360
|
+
"action": "",
|
361
|
+
"type": await agent.memory.get("type"),
|
362
|
+
"nominal_gdp": await agent.memory.get("nominal_gdp"),
|
363
|
+
"real_gdp": await agent.memory.get("real_gdp"),
|
364
|
+
"unemployment": await agent.memory.get("unemployment"),
|
365
|
+
"wages": await agent.memory.get("wages"),
|
366
|
+
"prices": await agent.memory.get("prices"),
|
367
|
+
"inventory": await agent.memory.get("inventory"),
|
368
|
+
"price": await agent.memory.get("price"),
|
369
|
+
"interest_rate": await agent.memory.get("interest_rate"),
|
370
|
+
"bracket_cutoffs": await agent.memory.get(
|
371
|
+
"bracket_cutoffs"
|
372
|
+
),
|
373
|
+
"bracket_rates": await agent.memory.get("bracket_rates"),
|
374
|
+
"employees": await agent.memory.get("employees"),
|
375
|
+
"customers": await agent.memory.get("customers"),
|
376
|
+
"created_at": _date_time,
|
377
|
+
}
|
378
|
+
_statuses_time_list.append((_status_dict, _date_time))
|
379
|
+
to_update_statues: list[tuple] = []
|
380
|
+
for _status_dict, _ in _statuses_time_list:
|
381
|
+
BASIC_KEYS = [
|
382
|
+
"id",
|
383
|
+
"day",
|
384
|
+
"t",
|
385
|
+
"lng",
|
386
|
+
"lat",
|
387
|
+
"parent_id",
|
388
|
+
"action",
|
389
|
+
"created_at",
|
390
|
+
]
|
391
|
+
_data = [_status_dict[k] for k in BASIC_KEYS if k != "created_at"]
|
392
|
+
_other_dict = json.dumps(
|
393
|
+
{k: v for k, v in _status_dict.items() if k not in BASIC_KEYS}
|
394
|
+
)
|
395
|
+
_data.append(_other_dict)
|
396
|
+
_data.append(_status_dict["created_at"])
|
397
|
+
to_update_statues.append(tuple(_data))
|
398
|
+
if self._last_asyncio_pg_task is not None:
|
399
|
+
await self._last_asyncio_pg_task
|
400
|
+
self._last_asyncio_pg_task = (
|
401
|
+
self._pgsql_writer.async_write_status.remote( # type:ignore
|
402
|
+
to_update_statues
|
403
|
+
)
|
404
|
+
)
|
279
405
|
|
280
406
|
async def step(self):
|
281
407
|
if not self.initialized:
|
@@ -23,6 +23,7 @@ from ..message.messager import Messager
|
|
23
23
|
from ..metrics import init_mlflow_connection
|
24
24
|
from ..survey import Survey
|
25
25
|
from .agentgroup import AgentGroup
|
26
|
+
from .storage.pg import PgWriter, create_pg_tables
|
26
27
|
|
27
28
|
logger = logging.getLogger("pycityagent")
|
28
29
|
|
@@ -63,6 +64,7 @@ class AgentSimulation:
|
|
63
64
|
self._user_survey_topics: dict[uuid.UUID, str] = {}
|
64
65
|
self._user_interview_topics: dict[uuid.UUID, str] = {}
|
65
66
|
self._loop = asyncio.get_event_loop()
|
67
|
+
# self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
66
68
|
|
67
69
|
self._messager = Messager(
|
68
70
|
hostname=config["simulator_request"]["mqtt"]["server"],
|
@@ -89,22 +91,13 @@ class AgentSimulation:
|
|
89
91
|
self._enable_pgsql = _pgsql_config.get("enabled", False)
|
90
92
|
if not self._enable_pgsql:
|
91
93
|
logger.warning("PostgreSQL is not enabled, NO POSTGRESQL DATABASE STORAGE")
|
92
|
-
self.
|
94
|
+
self._pgsql_dsn = ""
|
93
95
|
else:
|
94
|
-
self.
|
95
|
-
self._pgsql_port = _pgsql_config["port"]
|
96
|
-
self._pgsql_database = _pgsql_config["database"]
|
97
|
-
self._pgsql_user = _pgsql_config.get("user", None)
|
98
|
-
self._pgsql_password = _pgsql_config.get("password", None)
|
99
|
-
self._pgsql_args: tuple[str, str, str, str, str] = (
|
100
|
-
self._pgsql_host,
|
101
|
-
self._pgsql_port,
|
102
|
-
self._pgsql_database,
|
103
|
-
self._pgsql_user,
|
104
|
-
self._pgsql_password,
|
105
|
-
)
|
96
|
+
self._pgsql_dsn = _pgsql_config["data_source_name"]
|
106
97
|
|
107
98
|
# 添加实验信息相关的属性
|
99
|
+
self._exp_created_time = datetime.now(timezone.utc)
|
100
|
+
self._exp_updated_time = datetime.now(timezone.utc)
|
108
101
|
self._exp_info = {
|
109
102
|
"id": self.exp_id,
|
110
103
|
"name": exp_name,
|
@@ -114,7 +107,8 @@ class AgentSimulation:
|
|
114
107
|
"cur_t": 0.0,
|
115
108
|
"config": json.dumps(config),
|
116
109
|
"error": "",
|
117
|
-
"created_at":
|
110
|
+
"created_at": self._exp_created_time.isoformat(),
|
111
|
+
"updated_at": self._exp_updated_time.isoformat(),
|
118
112
|
}
|
119
113
|
|
120
114
|
# 创建异步任务保存实验信息
|
@@ -168,7 +162,7 @@ class AgentSimulation:
|
|
168
162
|
enable_avro: bool,
|
169
163
|
avro_path: Path,
|
170
164
|
enable_pgsql: bool,
|
171
|
-
|
165
|
+
pgsql_writer: ray.ObjectRef,
|
172
166
|
mlflow_run_id: str = None, # type: ignore
|
173
167
|
logging_level: int = logging.WARNING,
|
174
168
|
):
|
@@ -181,7 +175,7 @@ class AgentSimulation:
|
|
181
175
|
enable_avro,
|
182
176
|
avro_path,
|
183
177
|
enable_pgsql,
|
184
|
-
|
178
|
+
pgsql_writer,
|
185
179
|
mlflow_run_id,
|
186
180
|
logging_level,
|
187
181
|
)
|
@@ -191,6 +185,7 @@ class AgentSimulation:
|
|
191
185
|
self,
|
192
186
|
agent_count: Union[int, list[int]],
|
193
187
|
group_size: int = 1000,
|
188
|
+
pg_sql_writers: int = 32,
|
194
189
|
memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
|
195
190
|
) -> None:
|
196
191
|
"""初始化智能体
|
@@ -251,8 +246,8 @@ class AgentSimulation:
|
|
251
246
|
memory=memory,
|
252
247
|
)
|
253
248
|
|
254
|
-
self._agents[agent._uuid] = agent
|
255
|
-
self._agent_uuids.append(agent._uuid)
|
249
|
+
self._agents[agent._uuid] = agent # type:ignore
|
250
|
+
self._agent_uuids.append(agent._uuid) # type:ignore
|
256
251
|
|
257
252
|
# 计算需要的组数,向上取整以处理不足一组的情况
|
258
253
|
num_group = (agent_count_i + group_size - 1) // group_size
|
@@ -282,9 +277,23 @@ class AgentSimulation:
|
|
282
277
|
)
|
283
278
|
else:
|
284
279
|
mlflow_run_id = None
|
280
|
+
# 建表
|
281
|
+
if self.enable_pgsql:
|
282
|
+
_num_workers = min(1, pg_sql_writers)
|
283
|
+
create_pg_tables(
|
284
|
+
exp_id=self.exp_id,
|
285
|
+
dsn=self._pgsql_dsn,
|
286
|
+
)
|
287
|
+
self._pgsql_writers = _workers = [
|
288
|
+
PgWriter.remote(self.exp_id, self._pgsql_dsn)
|
289
|
+
for _ in range(_num_workers)
|
290
|
+
]
|
291
|
+
else:
|
292
|
+
_num_workers = 1
|
293
|
+
self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
|
285
294
|
# 收集所有创建组的参数
|
286
295
|
creation_tasks = []
|
287
|
-
for group_name, agents in group_creation_params:
|
296
|
+
for i, (group_name, agents) in enumerate(group_creation_params):
|
288
297
|
# 直接创建异步任务
|
289
298
|
group = AgentGroup.remote(
|
290
299
|
agents,
|
@@ -294,10 +303,8 @@ class AgentSimulation:
|
|
294
303
|
self.enable_avro,
|
295
304
|
self.avro_path,
|
296
305
|
self.enable_pgsql,
|
297
|
-
#
|
298
|
-
|
299
|
-
None,
|
300
|
-
mlflow_run_id,
|
306
|
+
_workers[i % _num_workers], # type:ignore
|
307
|
+
mlflow_run_id, # type:ignore
|
301
308
|
self.logging_level,
|
302
309
|
)
|
303
310
|
creation_tasks.append((group_name, group, agents))
|
@@ -469,11 +476,13 @@ class AgentSimulation:
|
|
469
476
|
survey_dict = survey.to_dict()
|
470
477
|
if agent_uuids is None:
|
471
478
|
agent_uuids = self._agent_uuids
|
479
|
+
_date_time = datetime.now(timezone.utc)
|
472
480
|
payload = {
|
473
481
|
"from": "none",
|
474
482
|
"survey_id": survey_dict["id"],
|
475
|
-
"timestamp": int(
|
483
|
+
"timestamp": int(_date_time.timestamp() * 1000),
|
476
484
|
"data": survey_dict,
|
485
|
+
"_date_time": _date_time,
|
477
486
|
}
|
478
487
|
for uuid in agent_uuids:
|
479
488
|
topic = self._user_survey_topics[uuid]
|
@@ -483,10 +492,12 @@ class AgentSimulation:
|
|
483
492
|
self, content: str, agent_uuids: Union[uuid.UUID, list[uuid.UUID]]
|
484
493
|
):
|
485
494
|
"""发送面试消息"""
|
495
|
+
_date_time = datetime.now(timezone.utc)
|
486
496
|
payload = {
|
487
497
|
"from": "none",
|
488
498
|
"content": content,
|
489
|
-
"timestamp": int(
|
499
|
+
"timestamp": int(_date_time.timestamp() * 1000),
|
500
|
+
"_date_time": _date_time,
|
490
501
|
}
|
491
502
|
if not isinstance(agent_uuids, Sequence):
|
492
503
|
agent_uuids = [agent_uuids]
|
@@ -515,15 +526,29 @@ class AgentSimulation:
|
|
515
526
|
logger.error(f"Avro保存实验信息失败: {str(e)}")
|
516
527
|
try:
|
517
528
|
if self.enable_pgsql:
|
518
|
-
#
|
519
|
-
|
529
|
+
worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
|
530
|
+
# if self._last_asyncio_pg_task is not None:
|
531
|
+
# await self._last_asyncio_pg_task
|
532
|
+
# self._last_asyncio_pg_task = (
|
533
|
+
# worker.async_update_exp_info.remote( # type:ignore
|
534
|
+
# pg_exp_info
|
535
|
+
# )
|
536
|
+
# )
|
537
|
+
pg_exp_info = {k: v for k, v in self._exp_info.items()}
|
538
|
+
pg_exp_info["created_at"] = self._exp_created_time
|
539
|
+
pg_exp_info["updated_at"] = self._exp_updated_time
|
540
|
+
await worker.async_update_exp_info.remote( # type:ignore
|
541
|
+
pg_exp_info
|
542
|
+
)
|
520
543
|
except Exception as e:
|
521
544
|
logger.error(f"PostgreSQL保存实验信息失败: {str(e)}")
|
522
545
|
|
523
546
|
async def _update_exp_status(self, status: int, error: str = "") -> None:
|
547
|
+
self._exp_updated_time = datetime.now(timezone.utc)
|
524
548
|
"""更新实验状态并保存"""
|
525
549
|
self._exp_info["status"] = status
|
526
550
|
self._exp_info["error"] = error
|
551
|
+
self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
|
527
552
|
await self._save_exp_info()
|
528
553
|
|
529
554
|
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|