pycityagent 2.0.0a20__tar.gz → 2.0.0a21__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/PKG-INFO +1 -1
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/agent.py +165 -62
- pycityagent-2.0.0a21/pycityagent/simulation/__init__.py +8 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/simulation/agentgroup.py +129 -3
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/simulation/simulation.py +52 -27
- pycityagent-2.0.0a21/pycityagent/simulation/storage/pg.py +139 -0
- pycityagent-2.0.0a21/pycityagent/utils/pg_query.py +80 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/workflow/tool.py +32 -24
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pyproject.toml +1 -1
- pycityagent-2.0.0a20/pycityagent/simulation/__init__.py +0 -7
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/README.md +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/economy/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/economy/econ_client.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/interact/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/interact/interact.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/message/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sence/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sence/static.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sidecar/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sidecar/sidecarv2.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/aoi_service.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/client.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/clock_service.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/economy_services.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/lane_service.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/light_service.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/person_service.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/road_service.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/sim_env.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/social_service.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/simulator.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/base64.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/const.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/geojson.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/grpc.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/map_utils.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/port.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/protobuf.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/llm/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/llm/embedding.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/llm/llm.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/llm/llmconfig.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/llm/utils.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/const.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/memory.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/memory_base.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/profile.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/self_define.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/state.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/utils.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/message/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/message/messager.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/metrics/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/metrics/mlflow_client.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/metrics/utils/const.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/survey/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/survey/manager.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/survey/models.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/avro_schema.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/decorators.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/parsers/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/parsers/code_block_parser.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/parsers/json_parser.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/parsers/parser_base.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/survey_util.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/workflow/__init__.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/workflow/block.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/workflow/prompt.py +0 -0
- {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/workflow/trigger.py +0 -0
@@ -1,17 +1,19 @@
|
|
1
1
|
"""智能体模板类及其定义"""
|
2
2
|
|
3
3
|
import asyncio
|
4
|
+
import json
|
4
5
|
import logging
|
5
6
|
import random
|
6
7
|
import uuid
|
7
8
|
from abc import ABC, abstractmethod
|
8
9
|
from copy import deepcopy
|
9
|
-
from datetime import datetime
|
10
|
+
from datetime import datetime, timezone
|
10
11
|
from enum import Enum
|
11
12
|
from typing import Any, Optional
|
12
13
|
from uuid import UUID
|
13
14
|
|
14
15
|
import fastavro
|
16
|
+
import ray
|
15
17
|
from mosstool.util.format_converter import dict2pb
|
16
18
|
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
17
19
|
|
@@ -56,6 +58,7 @@ class Agent(ABC):
|
|
56
58
|
mlflow_client: Optional[MlflowClient] = None,
|
57
59
|
memory: Optional[Memory] = None,
|
58
60
|
avro_file: Optional[dict[str, str]] = None,
|
61
|
+
copy_writer: Optional[ray.ObjectRef] = None,
|
59
62
|
) -> None:
|
60
63
|
"""
|
61
64
|
Initialize the Agent.
|
@@ -70,6 +73,7 @@ class Agent(ABC):
|
|
70
73
|
mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
|
71
74
|
memory (Memory, optional): The memory of the agent. Defaults to None.
|
72
75
|
avro_file (dict[str, str], optional): The avro file of the agent. Defaults to None.
|
76
|
+
copy_writer (ray.ObjectRef): The copy_writer of the agent. Defaults to None.
|
73
77
|
"""
|
74
78
|
self._name = name
|
75
79
|
self._type = type
|
@@ -88,6 +92,8 @@ class Agent(ABC):
|
|
88
92
|
self._interview_history: list[dict] = [] # 存储采访历史
|
89
93
|
self._person_template = PersonService.default_dict_person()
|
90
94
|
self._avro_file = avro_file
|
95
|
+
self._pgsql_writer = copy_writer
|
96
|
+
self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
91
97
|
|
92
98
|
def __getstate__(self):
|
93
99
|
state = self.__dict__.copy()
|
@@ -143,6 +149,12 @@ class Agent(ABC):
|
|
143
149
|
"""
|
144
150
|
self._avro_file = avro_file
|
145
151
|
|
152
|
+
def set_pgsql_writer(self, pgsql_writer: ray.ObjectRef):
|
153
|
+
"""
|
154
|
+
Set the PostgreSQL copy writer of the agent.
|
155
|
+
"""
|
156
|
+
self._pgsql_writer = pgsql_writer
|
157
|
+
|
146
158
|
@property
|
147
159
|
def uuid(self):
|
148
160
|
"""The Agent's UUID"""
|
@@ -198,6 +210,15 @@ class Agent(ABC):
|
|
198
210
|
)
|
199
211
|
return self._simulator
|
200
212
|
|
213
|
+
@property
|
214
|
+
def copy_writer(self):
|
215
|
+
"""Pg Copy Writer"""
|
216
|
+
if self._pgsql_writer is None:
|
217
|
+
raise RuntimeError(
|
218
|
+
f"Copy Writer access before assignment, please `set_pgsql_writer` first!"
|
219
|
+
)
|
220
|
+
return self._pgsql_writer
|
221
|
+
|
201
222
|
async def generate_user_survey_response(self, survey: dict) -> str:
|
202
223
|
"""生成回答 —— 可重写
|
203
224
|
基于智能体的记忆和当前状态,生成对问卷调查的回答。
|
@@ -237,8 +258,8 @@ class Agent(ABC):
|
|
237
258
|
|
238
259
|
async def _process_survey(self, survey: dict):
|
239
260
|
survey_response = await self.generate_user_survey_response(survey)
|
240
|
-
|
241
|
-
|
261
|
+
_date_time = datetime.now(timezone.utc)
|
262
|
+
# Avro
|
242
263
|
response_to_avro = [
|
243
264
|
{
|
244
265
|
"id": self._uuid,
|
@@ -246,11 +267,41 @@ class Agent(ABC):
|
|
246
267
|
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
247
268
|
"survey_id": survey["id"],
|
248
269
|
"result": survey_response,
|
249
|
-
"created_at": int(
|
270
|
+
"created_at": int(_date_time.timestamp() * 1000),
|
250
271
|
}
|
251
272
|
]
|
252
|
-
|
253
|
-
|
273
|
+
if self._avro_file is not None:
|
274
|
+
with open(self._avro_file["survey"], "a+b") as f:
|
275
|
+
fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
|
276
|
+
# Pg
|
277
|
+
if self._pgsql_writer is not None:
|
278
|
+
if self._last_asyncio_pg_task is not None:
|
279
|
+
await self._last_asyncio_pg_task
|
280
|
+
_keys = [
|
281
|
+
"id",
|
282
|
+
"day",
|
283
|
+
"t",
|
284
|
+
"survey_id",
|
285
|
+
"result",
|
286
|
+
]
|
287
|
+
_data_tuples: list[tuple] = []
|
288
|
+
# str to json
|
289
|
+
for _dict in response_to_avro:
|
290
|
+
res = _dict["result"]
|
291
|
+
_dict["result"] = json.dumps(
|
292
|
+
{
|
293
|
+
"result": res,
|
294
|
+
}
|
295
|
+
)
|
296
|
+
_data_list = [_dict[k] for k in _keys]
|
297
|
+
# created_at
|
298
|
+
_data_list.append(_date_time)
|
299
|
+
_data_tuples.append(tuple(_data_list))
|
300
|
+
self._last_asyncio_pg_task = (
|
301
|
+
self._pgsql_writer.async_write_survey.remote( # type:ignore
|
302
|
+
_data_tuples
|
303
|
+
)
|
304
|
+
)
|
254
305
|
|
255
306
|
async def generate_user_chat_response(self, question: str) -> str:
|
256
307
|
"""生成回答 —— 可重写
|
@@ -290,34 +341,52 @@ class Agent(ABC):
|
|
290
341
|
return response # type:ignore
|
291
342
|
|
292
343
|
async def _process_interview(self, payload: dict):
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
344
|
+
pg_list: list[tuple[dict, datetime]] = []
|
345
|
+
auros: list[dict] = []
|
346
|
+
_date_time = datetime.now(timezone.utc)
|
347
|
+
_interview_dict = {
|
348
|
+
"id": self._uuid,
|
349
|
+
"day": await self.simulator.get_simulator_day(),
|
350
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
351
|
+
"type": 2,
|
352
|
+
"speaker": "user",
|
353
|
+
"content": payload["content"],
|
354
|
+
"created_at": int(_date_time.timestamp() * 1000),
|
355
|
+
}
|
356
|
+
auros.append(_interview_dict)
|
357
|
+
pg_list.append((_interview_dict, _date_time))
|
304
358
|
question = payload["content"]
|
305
359
|
response = await self.generate_user_chat_response(question)
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
)
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
360
|
+
_date_time = datetime.now(timezone.utc)
|
361
|
+
_interview_dict = {
|
362
|
+
"id": self._uuid,
|
363
|
+
"day": await self.simulator.get_simulator_day(),
|
364
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
365
|
+
"type": 2,
|
366
|
+
"speaker": "",
|
367
|
+
"content": response,
|
368
|
+
"created_at": int(_date_time.timestamp() * 1000),
|
369
|
+
}
|
370
|
+
auros.append(_interview_dict)
|
371
|
+
pg_list.append((_interview_dict, _date_time))
|
372
|
+
# Avro
|
373
|
+
if self._avro_file is not None:
|
374
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
375
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
376
|
+
# Pg
|
377
|
+
if self._pgsql_writer is not None:
|
378
|
+
if self._last_asyncio_pg_task is not None:
|
379
|
+
await self._last_asyncio_pg_task
|
380
|
+
_keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
|
381
|
+
_data = [
|
382
|
+
tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
|
383
|
+
for _dict, _date_time in pg_list
|
384
|
+
]
|
385
|
+
self._last_asyncio_pg_task = (
|
386
|
+
self._pgsql_writer.async_write_dialog.remote( # type:ignore
|
387
|
+
_data
|
388
|
+
)
|
389
|
+
)
|
321
390
|
|
322
391
|
async def process_agent_chat_response(self, payload: dict) -> str:
|
323
392
|
resp = f"Agent {self._uuid} received agent chat response: {payload}"
|
@@ -325,22 +394,39 @@ class Agent(ABC):
|
|
325
394
|
return resp
|
326
395
|
|
327
396
|
async def _process_agent_chat(self, payload: dict):
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
397
|
+
pg_list: list[tuple[dict, datetime]] = []
|
398
|
+
auros: list[dict] = []
|
399
|
+
_date_time = datetime.now(timezone.utc)
|
400
|
+
_chat_dict = {
|
401
|
+
"id": self._uuid,
|
402
|
+
"day": payload["day"],
|
403
|
+
"t": payload["t"],
|
404
|
+
"type": 1,
|
405
|
+
"speaker": payload["from"],
|
406
|
+
"content": payload["content"],
|
407
|
+
"created_at": int(_date_time.timestamp() * 1000),
|
408
|
+
}
|
409
|
+
auros.append(_chat_dict)
|
410
|
+
pg_list.append((_chat_dict, _date_time))
|
339
411
|
asyncio.create_task(self.process_agent_chat_response(payload))
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
412
|
+
# Avro
|
413
|
+
if self._avro_file is not None:
|
414
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
415
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
416
|
+
# Pg
|
417
|
+
if self._pgsql_writer is not None:
|
418
|
+
if self._last_asyncio_pg_task is not None:
|
419
|
+
await self._last_asyncio_pg_task
|
420
|
+
_keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
|
421
|
+
_data = [
|
422
|
+
tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
|
423
|
+
for _dict, _date_time in pg_list
|
424
|
+
]
|
425
|
+
self._last_asyncio_pg_task = (
|
426
|
+
self._pgsql_writer.async_write_dialog.remote( # type:ignore
|
427
|
+
_data
|
428
|
+
)
|
429
|
+
)
|
344
430
|
|
345
431
|
# Callback functions for MQTT message
|
346
432
|
async def handle_agent_chat_message(self, payload: dict):
|
@@ -384,21 +470,38 @@ class Agent(ABC):
|
|
384
470
|
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
385
471
|
}
|
386
472
|
await self._send_message(to_agent_uuid, payload, "agent-chat")
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
473
|
+
pg_list: list[tuple[dict, datetime]] = []
|
474
|
+
auros: list[dict] = []
|
475
|
+
_date_time = datetime.now(timezone.utc)
|
476
|
+
_message_dict = {
|
477
|
+
"id": self._uuid,
|
478
|
+
"day": await self.simulator.get_simulator_day(),
|
479
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
480
|
+
"type": 1,
|
481
|
+
"speaker": self._uuid,
|
482
|
+
"content": content,
|
483
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
484
|
+
}
|
485
|
+
auros.append(_message_dict)
|
486
|
+
pg_list.append((_message_dict, _date_time))
|
487
|
+
# Avro
|
488
|
+
if self._avro_file is not None:
|
489
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
490
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
491
|
+
# Pg
|
492
|
+
if self._pgsql_writer is not None:
|
493
|
+
if self._last_asyncio_pg_task is not None:
|
494
|
+
await self._last_asyncio_pg_task
|
495
|
+
_keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
|
496
|
+
_data = [
|
497
|
+
tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
|
498
|
+
for _dict, _date_time in pg_list
|
499
|
+
]
|
500
|
+
self._last_asyncio_pg_task = (
|
501
|
+
self._pgsql_writer.async_write_dialog.remote( # type:ignore
|
502
|
+
_data
|
503
|
+
)
|
504
|
+
)
|
402
505
|
|
403
506
|
# Agent logic
|
404
507
|
@abstractmethod
|
@@ -3,7 +3,7 @@ import json
|
|
3
3
|
import logging
|
4
4
|
import time
|
5
5
|
import uuid
|
6
|
-
from datetime import datetime
|
6
|
+
from datetime import datetime, timezone
|
7
7
|
from pathlib import Path
|
8
8
|
from typing import Any
|
9
9
|
from uuid import UUID
|
@@ -35,7 +35,7 @@ class AgentGroup:
|
|
35
35
|
enable_avro: bool,
|
36
36
|
avro_path: Path,
|
37
37
|
enable_pgsql: bool,
|
38
|
-
|
38
|
+
pgsql_writer: ray.ObjectRef,
|
39
39
|
mlflow_run_id: str,
|
40
40
|
logging_level: int,
|
41
41
|
):
|
@@ -45,6 +45,7 @@ class AgentGroup:
|
|
45
45
|
self.config = config
|
46
46
|
self.exp_id = exp_id
|
47
47
|
self.enable_avro = enable_avro
|
48
|
+
self.enable_pgsql = enable_pgsql
|
48
49
|
if enable_avro:
|
49
50
|
self.avro_path = avro_path / f"{self._uuid}"
|
50
51
|
self.avro_path.mkdir(parents=True, exist_ok=True)
|
@@ -54,6 +55,8 @@ class AgentGroup:
|
|
54
55
|
"status": self.avro_path / f"status.avro",
|
55
56
|
"survey": self.avro_path / f"survey.avro",
|
56
57
|
}
|
58
|
+
if self.enable_pgsql:
|
59
|
+
pass
|
57
60
|
|
58
61
|
self.messager = Messager(
|
59
62
|
hostname=config["simulator_request"]["mqtt"]["server"],
|
@@ -61,6 +64,8 @@ class AgentGroup:
|
|
61
64
|
username=config["simulator_request"]["mqtt"].get("username", None),
|
62
65
|
password=config["simulator_request"]["mqtt"].get("password", None),
|
63
66
|
)
|
67
|
+
self._pgsql_writer = pgsql_writer
|
68
|
+
self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
64
69
|
self.initialized = False
|
65
70
|
self.id2agent = {}
|
66
71
|
# Step:1 prepare LLM client
|
@@ -105,6 +110,8 @@ class AgentGroup:
|
|
105
110
|
agent.set_messager(self.messager)
|
106
111
|
if self.enable_avro:
|
107
112
|
agent.set_avro_file(self.avro_file) # type: ignore
|
113
|
+
if self.enable_pgsql:
|
114
|
+
agent.set_pgsql_writer(self._pgsql_writer)
|
108
115
|
|
109
116
|
async def init_agents(self):
|
110
117
|
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
|
@@ -161,6 +168,20 @@ class AgentGroup:
|
|
161
168
|
with open(filename, "wb") as f:
|
162
169
|
surveys = []
|
163
170
|
fastavro.writer(f, SURVEY_SCHEMA, surveys)
|
171
|
+
|
172
|
+
if self.enable_pgsql:
|
173
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
174
|
+
profiles: list[Any] = []
|
175
|
+
for agent in self.agents:
|
176
|
+
profile = await agent.memory._profile.export()
|
177
|
+
profile = profile[0]
|
178
|
+
profile["id"] = agent._uuid
|
179
|
+
profiles.append(
|
180
|
+
(agent._uuid, profile.get("name", ""), json.dumps(profile))
|
181
|
+
)
|
182
|
+
await self._pgsql_writer.async_write_profile.remote( # type:ignore
|
183
|
+
profiles
|
184
|
+
)
|
164
185
|
self.initialized = True
|
165
186
|
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
166
187
|
|
@@ -218,11 +239,13 @@ class AgentGroup:
|
|
218
239
|
await asyncio.sleep(0.5)
|
219
240
|
|
220
241
|
async def save_status(self):
|
242
|
+
_statuses_time_list: list[tuple[dict, datetime]] = []
|
221
243
|
if self.enable_avro:
|
222
244
|
logger.debug(f"-----Saving status for group {self._uuid}")
|
223
245
|
avros = []
|
224
246
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
225
247
|
for agent in self.agents:
|
248
|
+
_date_time = datetime.now(timezone.utc)
|
226
249
|
position = await agent.memory.get("position")
|
227
250
|
lng = position["longlat_position"]["longitude"]
|
228
251
|
lat = position["longlat_position"]["latitude"]
|
@@ -248,13 +271,15 @@ class AgentGroup:
|
|
248
271
|
"tired": needs["tired"],
|
249
272
|
"safe": needs["safe"],
|
250
273
|
"social": needs["social"],
|
251
|
-
"created_at": int(
|
274
|
+
"created_at": int(_date_time.timestamp() * 1000),
|
252
275
|
}
|
253
276
|
avros.append(avro)
|
277
|
+
_statuses_time_list.append((avro, _date_time))
|
254
278
|
with open(self.avro_file["status"], "a+b") as f:
|
255
279
|
fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
|
256
280
|
else:
|
257
281
|
for agent in self.agents:
|
282
|
+
_date_time = datetime.now(timezone.utc)
|
258
283
|
avro = {
|
259
284
|
"id": agent._uuid,
|
260
285
|
"day": await self.simulator.get_simulator_day(),
|
@@ -274,8 +299,109 @@ class AgentGroup:
|
|
274
299
|
"customers": await agent.memory.get("customers"),
|
275
300
|
}
|
276
301
|
avros.append(avro)
|
302
|
+
_statuses_time_list.append((avro, _date_time))
|
277
303
|
with open(self.avro_file["status"], "a+b") as f:
|
278
304
|
fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
|
305
|
+
if self.enable_pgsql:
|
306
|
+
# data already acquired from Avro part
|
307
|
+
if len(_statuses_time_list) > 0:
|
308
|
+
for _status_dict, _date_time in _statuses_time_list:
|
309
|
+
for key in ["lng", "lat", "parent_id"]:
|
310
|
+
if key not in _status_dict:
|
311
|
+
_status_dict[key] = -1
|
312
|
+
for key in [
|
313
|
+
"action",
|
314
|
+
]:
|
315
|
+
if key not in _status_dict:
|
316
|
+
_status_dict[key] = ""
|
317
|
+
_status_dict["created_at"] = _date_time
|
318
|
+
else:
|
319
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
320
|
+
for agent in self.agents:
|
321
|
+
_date_time = datetime.now(timezone.utc)
|
322
|
+
position = await agent.memory.get("position")
|
323
|
+
lng = position["longlat_position"]["longitude"]
|
324
|
+
lat = position["longlat_position"]["latitude"]
|
325
|
+
if "aoi_position" in position:
|
326
|
+
parent_id = position["aoi_position"]["aoi_id"]
|
327
|
+
elif "lane_position" in position:
|
328
|
+
parent_id = position["lane_position"]["lane_id"]
|
329
|
+
else:
|
330
|
+
# BUG: 需要处理
|
331
|
+
parent_id = -1
|
332
|
+
needs = await agent.memory.get("needs")
|
333
|
+
action = await agent.memory.get("current_step")
|
334
|
+
action = action["intention"]
|
335
|
+
_status_dict = {
|
336
|
+
"id": agent._uuid,
|
337
|
+
"day": await self.simulator.get_simulator_day(),
|
338
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
339
|
+
"lng": lng,
|
340
|
+
"lat": lat,
|
341
|
+
"parent_id": parent_id,
|
342
|
+
"action": action,
|
343
|
+
"hungry": needs["hungry"],
|
344
|
+
"tired": needs["tired"],
|
345
|
+
"safe": needs["safe"],
|
346
|
+
"social": needs["social"],
|
347
|
+
"created_at": _date_time,
|
348
|
+
}
|
349
|
+
_statuses_time_list.append((_status_dict, _date_time))
|
350
|
+
else:
|
351
|
+
for agent in self.agents:
|
352
|
+
_date_time = datetime.now(timezone.utc)
|
353
|
+
_status_dict = {
|
354
|
+
"id": agent._uuid,
|
355
|
+
"day": await self.simulator.get_simulator_day(),
|
356
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
357
|
+
"lng": -1,
|
358
|
+
"lat": -1,
|
359
|
+
"parent_id": -1,
|
360
|
+
"action": "",
|
361
|
+
"type": await agent.memory.get("type"),
|
362
|
+
"nominal_gdp": await agent.memory.get("nominal_gdp"),
|
363
|
+
"real_gdp": await agent.memory.get("real_gdp"),
|
364
|
+
"unemployment": await agent.memory.get("unemployment"),
|
365
|
+
"wages": await agent.memory.get("wages"),
|
366
|
+
"prices": await agent.memory.get("prices"),
|
367
|
+
"inventory": await agent.memory.get("inventory"),
|
368
|
+
"price": await agent.memory.get("price"),
|
369
|
+
"interest_rate": await agent.memory.get("interest_rate"),
|
370
|
+
"bracket_cutoffs": await agent.memory.get(
|
371
|
+
"bracket_cutoffs"
|
372
|
+
),
|
373
|
+
"bracket_rates": await agent.memory.get("bracket_rates"),
|
374
|
+
"employees": await agent.memory.get("employees"),
|
375
|
+
"customers": await agent.memory.get("customers"),
|
376
|
+
"created_at": _date_time,
|
377
|
+
}
|
378
|
+
_statuses_time_list.append((_status_dict, _date_time))
|
379
|
+
to_update_statues: list[tuple] = []
|
380
|
+
for _status_dict, _ in _statuses_time_list:
|
381
|
+
BASIC_KEYS = [
|
382
|
+
"id",
|
383
|
+
"day",
|
384
|
+
"t",
|
385
|
+
"lng",
|
386
|
+
"lat",
|
387
|
+
"parent_id",
|
388
|
+
"action",
|
389
|
+
"created_at",
|
390
|
+
]
|
391
|
+
_data = [_status_dict[k] for k in BASIC_KEYS if k != "created_at"]
|
392
|
+
_other_dict = json.dumps(
|
393
|
+
{k: v for k, v in _status_dict.items() if k not in BASIC_KEYS}
|
394
|
+
)
|
395
|
+
_data.append(_other_dict)
|
396
|
+
_data.append(_status_dict["created_at"])
|
397
|
+
to_update_statues.append(tuple(_data))
|
398
|
+
if self._last_asyncio_pg_task is not None:
|
399
|
+
await self._last_asyncio_pg_task
|
400
|
+
self._last_asyncio_pg_task = (
|
401
|
+
self._pgsql_writer.async_write_status.remote( # type:ignore
|
402
|
+
to_update_statues
|
403
|
+
)
|
404
|
+
)
|
279
405
|
|
280
406
|
async def step(self):
|
281
407
|
if not self.initialized:
|