pycityagent 2.0.0a20__tar.gz → 2.0.0a21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/PKG-INFO +1 -1
  2. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/agent.py +165 -62
  3. pycityagent-2.0.0a21/pycityagent/simulation/__init__.py +8 -0
  4. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/simulation/agentgroup.py +129 -3
  5. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/simulation/simulation.py +52 -27
  6. pycityagent-2.0.0a21/pycityagent/simulation/storage/pg.py +139 -0
  7. pycityagent-2.0.0a21/pycityagent/utils/pg_query.py +80 -0
  8. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/workflow/tool.py +32 -24
  9. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pyproject.toml +1 -1
  10. pycityagent-2.0.0a20/pycityagent/simulation/__init__.py +0 -7
  11. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/README.md +0 -0
  12. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/__init__.py +0 -0
  13. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/economy/__init__.py +0 -0
  14. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/economy/econ_client.py +0 -0
  15. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/__init__.py +0 -0
  16. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/interact/__init__.py +0 -0
  17. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/interact/interact.py +0 -0
  18. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/message/__init__.py +0 -0
  19. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sence/__init__.py +0 -0
  20. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sence/static.py +0 -0
  21. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sidecar/__init__.py +0 -0
  22. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sidecar/sidecarv2.py +0 -0
  23. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/__init__.py +0 -0
  24. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/aoi_service.py +0 -0
  25. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/client.py +0 -0
  26. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/clock_service.py +0 -0
  27. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/economy_services.py +0 -0
  28. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/lane_service.py +0 -0
  29. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/light_service.py +0 -0
  30. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/person_service.py +0 -0
  31. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/road_service.py +0 -0
  32. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/sim_env.py +0 -0
  33. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/sim/social_service.py +0 -0
  34. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/simulator.py +0 -0
  35. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/__init__.py +0 -0
  36. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/base64.py +0 -0
  37. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/const.py +0 -0
  38. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/geojson.py +0 -0
  39. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/grpc.py +0 -0
  40. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/map_utils.py +0 -0
  41. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/port.py +0 -0
  42. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/environment/utils/protobuf.py +0 -0
  43. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/llm/__init__.py +0 -0
  44. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/llm/embedding.py +0 -0
  45. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/llm/llm.py +0 -0
  46. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/llm/llmconfig.py +0 -0
  47. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/llm/utils.py +0 -0
  48. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/__init__.py +0 -0
  49. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/const.py +0 -0
  50. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/memory.py +0 -0
  51. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/memory_base.py +0 -0
  52. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/profile.py +0 -0
  53. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/self_define.py +0 -0
  54. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/state.py +0 -0
  55. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/memory/utils.py +0 -0
  56. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/message/__init__.py +0 -0
  57. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/message/messager.py +0 -0
  58. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/metrics/__init__.py +0 -0
  59. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/metrics/mlflow_client.py +0 -0
  60. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/metrics/utils/const.py +0 -0
  61. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/survey/__init__.py +0 -0
  62. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/survey/manager.py +0 -0
  63. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/survey/models.py +0 -0
  64. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/__init__.py +0 -0
  65. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/avro_schema.py +0 -0
  66. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/decorators.py +0 -0
  67. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/parsers/__init__.py +0 -0
  68. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/parsers/code_block_parser.py +0 -0
  69. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/parsers/json_parser.py +0 -0
  70. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/parsers/parser_base.py +0 -0
  71. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/utils/survey_util.py +0 -0
  72. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/workflow/__init__.py +0 -0
  73. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/workflow/block.py +0 -0
  74. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/workflow/prompt.py +0 -0
  75. {pycityagent-2.0.0a20 → pycityagent-2.0.0a21}/pycityagent/workflow/trigger.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pycityagent
3
- Version: 2.0.0a20
3
+ Version: 2.0.0a21
4
4
  Summary: LLM-based城市环境agent构建库
5
5
  License: MIT
6
6
  Author: Yuwei Yan
@@ -1,17 +1,19 @@
1
1
  """智能体模板类及其定义"""
2
2
 
3
3
  import asyncio
4
+ import json
4
5
  import logging
5
6
  import random
6
7
  import uuid
7
8
  from abc import ABC, abstractmethod
8
9
  from copy import deepcopy
9
- from datetime import datetime
10
+ from datetime import datetime, timezone
10
11
  from enum import Enum
11
12
  from typing import Any, Optional
12
13
  from uuid import UUID
13
14
 
14
15
  import fastavro
16
+ import ray
15
17
  from mosstool.util.format_converter import dict2pb
16
18
  from pycityproto.city.person.v2 import person_pb2 as person_pb2
17
19
 
@@ -56,6 +58,7 @@ class Agent(ABC):
56
58
  mlflow_client: Optional[MlflowClient] = None,
57
59
  memory: Optional[Memory] = None,
58
60
  avro_file: Optional[dict[str, str]] = None,
61
+ copy_writer: Optional[ray.ObjectRef] = None,
59
62
  ) -> None:
60
63
  """
61
64
  Initialize the Agent.
@@ -70,6 +73,7 @@ class Agent(ABC):
70
73
  mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
71
74
  memory (Memory, optional): The memory of the agent. Defaults to None.
72
75
  avro_file (dict[str, str], optional): The avro file of the agent. Defaults to None.
76
+ copy_writer (ray.ObjectRef): The copy_writer of the agent. Defaults to None.
73
77
  """
74
78
  self._name = name
75
79
  self._type = type
@@ -88,6 +92,8 @@ class Agent(ABC):
88
92
  self._interview_history: list[dict] = [] # 存储采访历史
89
93
  self._person_template = PersonService.default_dict_person()
90
94
  self._avro_file = avro_file
95
+ self._pgsql_writer = copy_writer
96
+ self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
91
97
 
92
98
  def __getstate__(self):
93
99
  state = self.__dict__.copy()
@@ -143,6 +149,12 @@ class Agent(ABC):
143
149
  """
144
150
  self._avro_file = avro_file
145
151
 
152
+ def set_pgsql_writer(self, pgsql_writer: ray.ObjectRef):
153
+ """
154
+ Set the PostgreSQL copy writer of the agent.
155
+ """
156
+ self._pgsql_writer = pgsql_writer
157
+
146
158
  @property
147
159
  def uuid(self):
148
160
  """The Agent's UUID"""
@@ -198,6 +210,15 @@ class Agent(ABC):
198
210
  )
199
211
  return self._simulator
200
212
 
213
+ @property
214
+ def copy_writer(self):
215
+ """Pg Copy Writer"""
216
+ if self._pgsql_writer is None:
217
+ raise RuntimeError(
218
+ f"Copy Writer access before assignment, please `set_pgsql_writer` first!"
219
+ )
220
+ return self._pgsql_writer
221
+
201
222
  async def generate_user_survey_response(self, survey: dict) -> str:
202
223
  """生成回答 —— 可重写
203
224
  基于智能体的记忆和当前状态,生成对问卷调查的回答。
@@ -237,8 +258,8 @@ class Agent(ABC):
237
258
 
238
259
  async def _process_survey(self, survey: dict):
239
260
  survey_response = await self.generate_user_survey_response(survey)
240
- if self._avro_file is None:
241
- return
261
+ _date_time = datetime.now(timezone.utc)
262
+ # Avro
242
263
  response_to_avro = [
243
264
  {
244
265
  "id": self._uuid,
@@ -246,11 +267,41 @@ class Agent(ABC):
246
267
  "t": await self.simulator.get_simulator_second_from_start_of_day(),
247
268
  "survey_id": survey["id"],
248
269
  "result": survey_response,
249
- "created_at": int(datetime.now().timestamp() * 1000),
270
+ "created_at": int(_date_time.timestamp() * 1000),
250
271
  }
251
272
  ]
252
- with open(self._avro_file["survey"], "a+b") as f:
253
- fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
273
+ if self._avro_file is not None:
274
+ with open(self._avro_file["survey"], "a+b") as f:
275
+ fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
276
+ # Pg
277
+ if self._pgsql_writer is not None:
278
+ if self._last_asyncio_pg_task is not None:
279
+ await self._last_asyncio_pg_task
280
+ _keys = [
281
+ "id",
282
+ "day",
283
+ "t",
284
+ "survey_id",
285
+ "result",
286
+ ]
287
+ _data_tuples: list[tuple] = []
288
+ # str to json
289
+ for _dict in response_to_avro:
290
+ res = _dict["result"]
291
+ _dict["result"] = json.dumps(
292
+ {
293
+ "result": res,
294
+ }
295
+ )
296
+ _data_list = [_dict[k] for k in _keys]
297
+ # created_at
298
+ _data_list.append(_date_time)
299
+ _data_tuples.append(tuple(_data_list))
300
+ self._last_asyncio_pg_task = (
301
+ self._pgsql_writer.async_write_survey.remote( # type:ignore
302
+ _data_tuples
303
+ )
304
+ )
254
305
 
255
306
  async def generate_user_chat_response(self, question: str) -> str:
256
307
  """生成回答 —— 可重写
@@ -290,34 +341,52 @@ class Agent(ABC):
290
341
  return response # type:ignore
291
342
 
292
343
  async def _process_interview(self, payload: dict):
293
- auros = [
294
- {
295
- "id": self._uuid,
296
- "day": await self.simulator.get_simulator_day(),
297
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
298
- "type": 2,
299
- "speaker": "user",
300
- "content": payload["content"],
301
- "created_at": int(datetime.now().timestamp() * 1000),
302
- }
303
- ]
344
+ pg_list: list[tuple[dict, datetime]] = []
345
+ auros: list[dict] = []
346
+ _date_time = datetime.now(timezone.utc)
347
+ _interview_dict = {
348
+ "id": self._uuid,
349
+ "day": await self.simulator.get_simulator_day(),
350
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
351
+ "type": 2,
352
+ "speaker": "user",
353
+ "content": payload["content"],
354
+ "created_at": int(_date_time.timestamp() * 1000),
355
+ }
356
+ auros.append(_interview_dict)
357
+ pg_list.append((_interview_dict, _date_time))
304
358
  question = payload["content"]
305
359
  response = await self.generate_user_chat_response(question)
306
- auros.append(
307
- {
308
- "id": self._uuid,
309
- "day": await self.simulator.get_simulator_day(),
310
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
311
- "type": 2,
312
- "speaker": "",
313
- "content": response,
314
- "created_at": int(datetime.now().timestamp() * 1000),
315
- }
316
- )
317
- if self._avro_file is None:
318
- return
319
- with open(self._avro_file["dialog"], "a+b") as f:
320
- fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
360
+ _date_time = datetime.now(timezone.utc)
361
+ _interview_dict = {
362
+ "id": self._uuid,
363
+ "day": await self.simulator.get_simulator_day(),
364
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
365
+ "type": 2,
366
+ "speaker": "",
367
+ "content": response,
368
+ "created_at": int(_date_time.timestamp() * 1000),
369
+ }
370
+ auros.append(_interview_dict)
371
+ pg_list.append((_interview_dict, _date_time))
372
+ # Avro
373
+ if self._avro_file is not None:
374
+ with open(self._avro_file["dialog"], "a+b") as f:
375
+ fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
376
+ # Pg
377
+ if self._pgsql_writer is not None:
378
+ if self._last_asyncio_pg_task is not None:
379
+ await self._last_asyncio_pg_task
380
+ _keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
381
+ _data = [
382
+ tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
383
+ for _dict, _date_time in pg_list
384
+ ]
385
+ self._last_asyncio_pg_task = (
386
+ self._pgsql_writer.async_write_dialog.remote( # type:ignore
387
+ _data
388
+ )
389
+ )
321
390
 
322
391
  async def process_agent_chat_response(self, payload: dict) -> str:
323
392
  resp = f"Agent {self._uuid} received agent chat response: {payload}"
@@ -325,22 +394,39 @@ class Agent(ABC):
325
394
  return resp
326
395
 
327
396
  async def _process_agent_chat(self, payload: dict):
328
- auros = [
329
- {
330
- "id": self._uuid,
331
- "day": payload["day"],
332
- "t": payload["t"],
333
- "type": 1,
334
- "speaker": payload["from"],
335
- "content": payload["content"],
336
- "created_at": int(datetime.now().timestamp() * 1000),
337
- }
338
- ]
397
+ pg_list: list[tuple[dict, datetime]] = []
398
+ auros: list[dict] = []
399
+ _date_time = datetime.now(timezone.utc)
400
+ _chat_dict = {
401
+ "id": self._uuid,
402
+ "day": payload["day"],
403
+ "t": payload["t"],
404
+ "type": 1,
405
+ "speaker": payload["from"],
406
+ "content": payload["content"],
407
+ "created_at": int(_date_time.timestamp() * 1000),
408
+ }
409
+ auros.append(_chat_dict)
410
+ pg_list.append((_chat_dict, _date_time))
339
411
  asyncio.create_task(self.process_agent_chat_response(payload))
340
- if self._avro_file is None:
341
- return
342
- with open(self._avro_file["dialog"], "a+b") as f:
343
- fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
412
+ # Avro
413
+ if self._avro_file is not None:
414
+ with open(self._avro_file["dialog"], "a+b") as f:
415
+ fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
416
+ # Pg
417
+ if self._pgsql_writer is not None:
418
+ if self._last_asyncio_pg_task is not None:
419
+ await self._last_asyncio_pg_task
420
+ _keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
421
+ _data = [
422
+ tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
423
+ for _dict, _date_time in pg_list
424
+ ]
425
+ self._last_asyncio_pg_task = (
426
+ self._pgsql_writer.async_write_dialog.remote( # type:ignore
427
+ _data
428
+ )
429
+ )
344
430
 
345
431
  # Callback functions for MQTT message
346
432
  async def handle_agent_chat_message(self, payload: dict):
@@ -384,21 +470,38 @@ class Agent(ABC):
384
470
  "t": await self.simulator.get_simulator_second_from_start_of_day(),
385
471
  }
386
472
  await self._send_message(to_agent_uuid, payload, "agent-chat")
387
- auros = [
388
- {
389
- "id": self._uuid,
390
- "day": await self.simulator.get_simulator_day(),
391
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
392
- "type": 1,
393
- "speaker": self._uuid,
394
- "content": content,
395
- "created_at": int(datetime.now().timestamp() * 1000),
396
- }
397
- ]
398
- if self._avro_file is None:
399
- return
400
- with open(self._avro_file["dialog"], "a+b") as f:
401
- fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
473
+ pg_list: list[tuple[dict, datetime]] = []
474
+ auros: list[dict] = []
475
+ _date_time = datetime.now(timezone.utc)
476
+ _message_dict = {
477
+ "id": self._uuid,
478
+ "day": await self.simulator.get_simulator_day(),
479
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
480
+ "type": 1,
481
+ "speaker": self._uuid,
482
+ "content": content,
483
+ "created_at": int(datetime.now().timestamp() * 1000),
484
+ }
485
+ auros.append(_message_dict)
486
+ pg_list.append((_message_dict, _date_time))
487
+ # Avro
488
+ if self._avro_file is not None:
489
+ with open(self._avro_file["dialog"], "a+b") as f:
490
+ fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
491
+ # Pg
492
+ if self._pgsql_writer is not None:
493
+ if self._last_asyncio_pg_task is not None:
494
+ await self._last_asyncio_pg_task
495
+ _keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
496
+ _data = [
497
+ tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
498
+ for _dict, _date_time in pg_list
499
+ ]
500
+ self._last_asyncio_pg_task = (
501
+ self._pgsql_writer.async_write_dialog.remote( # type:ignore
502
+ _data
503
+ )
504
+ )
402
505
 
403
506
  # Agent logic
404
507
  @abstractmethod
@@ -0,0 +1,8 @@
1
+ """
2
+ 城市智能体模拟器模块
3
+ """
4
+
5
+ from .simulation import AgentSimulation
6
+ from .storage.pg import PgWriter, create_pg_tables
7
+
8
+ __all__ = ["AgentSimulation", "PgWriter", "create_pg_tables"]
@@ -3,7 +3,7 @@ import json
3
3
  import logging
4
4
  import time
5
5
  import uuid
6
- from datetime import datetime
6
+ from datetime import datetime, timezone
7
7
  from pathlib import Path
8
8
  from typing import Any
9
9
  from uuid import UUID
@@ -35,7 +35,7 @@ class AgentGroup:
35
35
  enable_avro: bool,
36
36
  avro_path: Path,
37
37
  enable_pgsql: bool,
38
- pgsql_copy_writer: ray.ObjectRef,
38
+ pgsql_writer: ray.ObjectRef,
39
39
  mlflow_run_id: str,
40
40
  logging_level: int,
41
41
  ):
@@ -45,6 +45,7 @@ class AgentGroup:
45
45
  self.config = config
46
46
  self.exp_id = exp_id
47
47
  self.enable_avro = enable_avro
48
+ self.enable_pgsql = enable_pgsql
48
49
  if enable_avro:
49
50
  self.avro_path = avro_path / f"{self._uuid}"
50
51
  self.avro_path.mkdir(parents=True, exist_ok=True)
@@ -54,6 +55,8 @@ class AgentGroup:
54
55
  "status": self.avro_path / f"status.avro",
55
56
  "survey": self.avro_path / f"survey.avro",
56
57
  }
58
+ if self.enable_pgsql:
59
+ pass
57
60
 
58
61
  self.messager = Messager(
59
62
  hostname=config["simulator_request"]["mqtt"]["server"],
@@ -61,6 +64,8 @@ class AgentGroup:
61
64
  username=config["simulator_request"]["mqtt"].get("username", None),
62
65
  password=config["simulator_request"]["mqtt"].get("password", None),
63
66
  )
67
+ self._pgsql_writer = pgsql_writer
68
+ self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
64
69
  self.initialized = False
65
70
  self.id2agent = {}
66
71
  # Step:1 prepare LLM client
@@ -105,6 +110,8 @@ class AgentGroup:
105
110
  agent.set_messager(self.messager)
106
111
  if self.enable_avro:
107
112
  agent.set_avro_file(self.avro_file) # type: ignore
113
+ if self.enable_pgsql:
114
+ agent.set_pgsql_writer(self._pgsql_writer)
108
115
 
109
116
  async def init_agents(self):
110
117
  logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
@@ -161,6 +168,20 @@ class AgentGroup:
161
168
  with open(filename, "wb") as f:
162
169
  surveys = []
163
170
  fastavro.writer(f, SURVEY_SCHEMA, surveys)
171
+
172
+ if self.enable_pgsql:
173
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
174
+ profiles: list[Any] = []
175
+ for agent in self.agents:
176
+ profile = await agent.memory._profile.export()
177
+ profile = profile[0]
178
+ profile["id"] = agent._uuid
179
+ profiles.append(
180
+ (agent._uuid, profile.get("name", ""), json.dumps(profile))
181
+ )
182
+ await self._pgsql_writer.async_write_profile.remote( # type:ignore
183
+ profiles
184
+ )
164
185
  self.initialized = True
165
186
  logger.debug(f"-----AgentGroup {self._uuid} initialized")
166
187
 
@@ -218,11 +239,13 @@ class AgentGroup:
218
239
  await asyncio.sleep(0.5)
219
240
 
220
241
  async def save_status(self):
242
+ _statuses_time_list: list[tuple[dict, datetime]] = []
221
243
  if self.enable_avro:
222
244
  logger.debug(f"-----Saving status for group {self._uuid}")
223
245
  avros = []
224
246
  if not issubclass(type(self.agents[0]), InstitutionAgent):
225
247
  for agent in self.agents:
248
+ _date_time = datetime.now(timezone.utc)
226
249
  position = await agent.memory.get("position")
227
250
  lng = position["longlat_position"]["longitude"]
228
251
  lat = position["longlat_position"]["latitude"]
@@ -248,13 +271,15 @@ class AgentGroup:
248
271
  "tired": needs["tired"],
249
272
  "safe": needs["safe"],
250
273
  "social": needs["social"],
251
- "created_at": int(datetime.now().timestamp() * 1000),
274
+ "created_at": int(_date_time.timestamp() * 1000),
252
275
  }
253
276
  avros.append(avro)
277
+ _statuses_time_list.append((avro, _date_time))
254
278
  with open(self.avro_file["status"], "a+b") as f:
255
279
  fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
256
280
  else:
257
281
  for agent in self.agents:
282
+ _date_time = datetime.now(timezone.utc)
258
283
  avro = {
259
284
  "id": agent._uuid,
260
285
  "day": await self.simulator.get_simulator_day(),
@@ -274,8 +299,109 @@ class AgentGroup:
274
299
  "customers": await agent.memory.get("customers"),
275
300
  }
276
301
  avros.append(avro)
302
+ _statuses_time_list.append((avro, _date_time))
277
303
  with open(self.avro_file["status"], "a+b") as f:
278
304
  fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
305
+ if self.enable_pgsql:
306
+ # data already acquired from Avro part
307
+ if len(_statuses_time_list) > 0:
308
+ for _status_dict, _date_time in _statuses_time_list:
309
+ for key in ["lng", "lat", "parent_id"]:
310
+ if key not in _status_dict:
311
+ _status_dict[key] = -1
312
+ for key in [
313
+ "action",
314
+ ]:
315
+ if key not in _status_dict:
316
+ _status_dict[key] = ""
317
+ _status_dict["created_at"] = _date_time
318
+ else:
319
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
320
+ for agent in self.agents:
321
+ _date_time = datetime.now(timezone.utc)
322
+ position = await agent.memory.get("position")
323
+ lng = position["longlat_position"]["longitude"]
324
+ lat = position["longlat_position"]["latitude"]
325
+ if "aoi_position" in position:
326
+ parent_id = position["aoi_position"]["aoi_id"]
327
+ elif "lane_position" in position:
328
+ parent_id = position["lane_position"]["lane_id"]
329
+ else:
330
+ # BUG: 需要处理
331
+ parent_id = -1
332
+ needs = await agent.memory.get("needs")
333
+ action = await agent.memory.get("current_step")
334
+ action = action["intention"]
335
+ _status_dict = {
336
+ "id": agent._uuid,
337
+ "day": await self.simulator.get_simulator_day(),
338
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
339
+ "lng": lng,
340
+ "lat": lat,
341
+ "parent_id": parent_id,
342
+ "action": action,
343
+ "hungry": needs["hungry"],
344
+ "tired": needs["tired"],
345
+ "safe": needs["safe"],
346
+ "social": needs["social"],
347
+ "created_at": _date_time,
348
+ }
349
+ _statuses_time_list.append((_status_dict, _date_time))
350
+ else:
351
+ for agent in self.agents:
352
+ _date_time = datetime.now(timezone.utc)
353
+ _status_dict = {
354
+ "id": agent._uuid,
355
+ "day": await self.simulator.get_simulator_day(),
356
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
357
+ "lng": -1,
358
+ "lat": -1,
359
+ "parent_id": -1,
360
+ "action": "",
361
+ "type": await agent.memory.get("type"),
362
+ "nominal_gdp": await agent.memory.get("nominal_gdp"),
363
+ "real_gdp": await agent.memory.get("real_gdp"),
364
+ "unemployment": await agent.memory.get("unemployment"),
365
+ "wages": await agent.memory.get("wages"),
366
+ "prices": await agent.memory.get("prices"),
367
+ "inventory": await agent.memory.get("inventory"),
368
+ "price": await agent.memory.get("price"),
369
+ "interest_rate": await agent.memory.get("interest_rate"),
370
+ "bracket_cutoffs": await agent.memory.get(
371
+ "bracket_cutoffs"
372
+ ),
373
+ "bracket_rates": await agent.memory.get("bracket_rates"),
374
+ "employees": await agent.memory.get("employees"),
375
+ "customers": await agent.memory.get("customers"),
376
+ "created_at": _date_time,
377
+ }
378
+ _statuses_time_list.append((_status_dict, _date_time))
379
+ to_update_statues: list[tuple] = []
380
+ for _status_dict, _ in _statuses_time_list:
381
+ BASIC_KEYS = [
382
+ "id",
383
+ "day",
384
+ "t",
385
+ "lng",
386
+ "lat",
387
+ "parent_id",
388
+ "action",
389
+ "created_at",
390
+ ]
391
+ _data = [_status_dict[k] for k in BASIC_KEYS if k != "created_at"]
392
+ _other_dict = json.dumps(
393
+ {k: v for k, v in _status_dict.items() if k not in BASIC_KEYS}
394
+ )
395
+ _data.append(_other_dict)
396
+ _data.append(_status_dict["created_at"])
397
+ to_update_statues.append(tuple(_data))
398
+ if self._last_asyncio_pg_task is not None:
399
+ await self._last_asyncio_pg_task
400
+ self._last_asyncio_pg_task = (
401
+ self._pgsql_writer.async_write_status.remote( # type:ignore
402
+ to_update_statues
403
+ )
404
+ )
279
405
 
280
406
  async def step(self):
281
407
  if not self.initialized: