pycityagent 2.0.0a19__py3-none-any.whl → 2.0.0a21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
@@ -2,7 +2,8 @@ import asyncio
  import logging
  from copy import deepcopy
  from datetime import datetime
- from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
+ from typing import Any, Literal, Optional, Union
+ from collections.abc import Sequence,Callable

  import numpy as np
  from pyparsing import deque
@@ -27,10 +28,10 @@ class Memory:

  def __init__(
  self,
- config: Optional[Dict[Any, Any]] = None,
- profile: Optional[Dict[Any, Any]] = None,
- base: Optional[Dict[Any, Any]] = None,
- motion: Optional[Dict[Any, Any]] = None,
+ config: Optional[dict[Any, Any]] = None,
+ profile: Optional[dict[Any, Any]] = None,
+ base: Optional[dict[Any, Any]] = None,
+ motion: Optional[dict[Any, Any]] = None,
  activate_timestamp: bool = False,
  embedding_model: Any = None,
  ) -> None:
@@ -38,7 +39,7 @@ class Memory:
  Initializes the Memory with optional configuration.

  Args:
- config (Optional[Dict[Any, Any]], optional):
+ config (Optional[dict[Any, Any]], optional):
  A configuration dictionary for dynamic memory. The dictionary format is:
  - Key: The name of the dynamic memory field.
  - Value: Can be one of two formats:
@@ -46,24 +47,24 @@ class Memory:
  2. A callable that returns the default value when invoked (useful for complex default values).
  Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
  Defaults to None.
- profile (Optional[Dict[Any, Any]], optional): profile attribute dict.
- base (Optional[Dict[Any, Any]], optional): base attribute dict from City Simulator.
- motion (Optional[Dict[Any, Any]], optional): motion attribute dict from City Simulator.
+ profile (Optional[dict[Any, Any]], optional): profile attribute dict.
+ base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
+ motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
  activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
  embedding_model (Any): The embedding model for memory search.
  """
- self.watchers: Dict[str, List[Callable]] = {}
+ self.watchers: dict[str, list[Callable]] = {}
  self._lock = asyncio.Lock()
  self.embedding_model = embedding_model

  # Initialize the embedding store
  self._embeddings = {"state": {}, "profile": {}, "dynamic": {}}

- _dynamic_config: Dict[Any, Any] = {}
- _state_config: Dict[Any, Any] = {}
- _profile_config: Dict[Any, Any] = {}
+ _dynamic_config: dict[Any, Any] = {}
+ _state_config: dict[Any, Any] = {}
+ _profile_config: dict[Any, Any] = {}
  # Record which fields need an embedding
- self._embedding_fields: Dict[str, bool] = {}
+ self._embedding_fields: dict[str, bool] = {}

  if config is not None:
  for k, v in config.items():
@@ -303,7 +304,7 @@ class Memory:

  async def update_batch(
  self,
- content: Union[Dict, Sequence[Tuple[Any, Any]]],
+ content: Union[dict, Sequence[tuple[Any, Any]]],
  mode: Union[Literal["replace"], Literal["merge"]] = "replace",
  store_snapshot: bool = False,
  protect_llm_read_only_fields: bool = True,
@@ -312,7 +313,7 @@ class Memory:
  Updates multiple values in the memory at once.

  Args:
- content (Union[Dict, Sequence[Tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
+ content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
  mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
  store_snapshot (bool): Whether to store a snapshot of the memory after the update.
  protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
@@ -321,9 +322,9 @@ class Memory:
  TypeError: If the content type is neither a dictionary nor a sequence of tuples.
  """
  if isinstance(content, dict):
- _list_content: List[Tuple[Any, Any]] = [(k, v) for k, v in content.items()]
+ _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
  elif isinstance(content, Sequence):
- _list_content: List[Tuple[Any, Any]] = [(k, v) for k, v in content]
+ _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
  else:
  raise TypeError(f"Invalid content type `{type(content)}`!")
  for k, v in _list_content[:1]:
@@ -353,12 +354,12 @@ class Memory:
  @lock_decorator
  async def export(
  self,
- ) -> Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]:
+ ) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
  """
  Exports the current state of all memory sections.

  Returns:
- Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
+ tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
  """
  return (
  await self._profile.export(),
@@ -369,14 +370,14 @@ class Memory:
  @lock_decorator
  async def load(
  self,
- snapshots: Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]],
+ snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
  reset_memory: bool = True,
  ) -> None:
  """
  Import the snapshot memories of all sections.

  Args:
- snapshots (Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]): The exported snapshots.
+ snapshots (tuple[Sequence[dict], Sequence[dict], Sequence[dict]]): The exported snapshots.
  reset_memory (bool): Whether to reset previous memory.
  """
  _profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
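Taken together, the hunks above for the `Memory` class are a mechanical typing migration: `Dict`, `List`, and `Tuple` from `typing` are replaced by the builtin `dict`, `list`, and `tuple` generics, and `Sequence`/`Callable` now come from `collections.abc`. A minimal sketch of the pattern follows; it is an editorial illustration rather than package code, and the Python 3.9 floor is inferred from PEP 585, not stated in the diff.

```python
# Editorial illustration of the annotation style adopted in this release.
from collections.abc import Callable, Sequence
from typing import Any, Optional, Union

# Before (2.0.0a19): Dict[str, List[Callable]], Union[Dict, Sequence[Tuple[Any, Any]]]
# After  (2.0.0a21): builtin generics, PEP 585 style (Python >= 3.9)
watchers: dict[str, list[Callable]] = {}
content: Union[dict, Sequence[tuple[Any, Any]]] = [("key", "value")]
snapshot: Optional[tuple[Sequence[dict], Sequence[dict], Sequence[dict]]] = None
```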
@@ -1,5 +1,6 @@
- from .mlflow_client import MlflowClient
+ from .mlflow_client import MlflowClient,init_mlflow_connection

  __all__ = [
  "MlflowClient",
+ "init_mlflow_connection",
  ]
@@ -18,6 +18,55 @@ from ..utils.decorators import lock_decorator
  logger = logging.getLogger("mlflow")


+ def init_mlflow_connection(
+ config: dict,
+ mlflow_run_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ experiment_description: Optional[str] = None,
+ experiment_tags: Optional[dict[str, Any]] = None,
+ ) -> tuple[str, tuple[str, mlflow.MlflowClient, Run, str]]:
+
+ os.environ["MLFLOW_TRACKING_USERNAME"] = config.get("username", None)
+ os.environ["MLFLOW_TRACKING_PASSWORD"] = config.get("password", None)
+
+ run_uuid = str(uuid.uuid4())
+ # run name
+ if mlflow_run_name is None:
+ mlflow_run_name = f"exp_{run_uuid}"
+
+ # exp name
+ if experiment_name is None:
+ experiment_name = f"run_{run_uuid}"
+
+ # tags
+ if experiment_tags is None:
+ experiment_tags = {}
+ if experiment_description is not None:
+ experiment_tags["mlflow.note.content"] = experiment_description
+
+ uri = config["mlflow_uri"]
+ client = mlflow.MlflowClient(tracking_uri=uri)
+
+ # experiment
+ try:
+ experiment_id = client.create_experiment(
+ name=experiment_name,
+ tags=experiment_tags,
+ )
+ except Exception as e:
+ experiment = client.get_experiment_by_name(experiment_name)
+ if experiment is None:
+ raise e
+ experiment_id = experiment.experiment_id
+
+ # run
+ run = client.create_run(experiment_id=experiment_id, run_name=mlflow_run_name)
+
+ run_id = run.info.run_id
+
+ return run_id, (uri, client, run, run_uuid)
+
+
  class MlflowClient:
  """
  - Mlflow client
@@ -30,42 +79,30 @@ class MlflowClient:
  experiment_name: Optional[str] = None,
  experiment_description: Optional[str] = None,
  experiment_tags: Optional[dict[str, Any]] = None,
+ run_id: Optional[str] = None,
  ) -> None:
- os.environ["MLFLOW_TRACKING_USERNAME"] = config.get("username", None)
- os.environ["MLFLOW_TRACKING_PASSWORD"] = config.get("password", None)
- self._mlflow_uri = uri = config["mlflow_uri"]
- self._client = client = mlflow.MlflowClient(tracking_uri=uri)
- self._run_uuid = run_uuid = str(uuid.uuid4())
- self._lock = asyncio.Lock()
- # run name
- if mlflow_run_name is None:
- mlflow_run_name = f"exp_{run_uuid}"
-
- # exp name
- if experiment_name is None:
- experiment_name = f"run_{run_uuid}"
-
- # tags
- if experiment_tags is None:
- experiment_tags = {}
- if experiment_description is not None:
- experiment_tags["mlflow.note.content"] = experiment_description
-
- try:
- self._experiment_id = experiment_id = client.create_experiment(
- name=experiment_name,
- tags=experiment_tags,
+ if run_id is None:
+ self._run_id, (
+ self._mlflow_uri,
+ self._client,
+ self._run,
+ self._run_uuid,
+ ) = init_mlflow_connection(
+ config=config,
+ mlflow_run_name=mlflow_run_name,
+ experiment_name=experiment_name,
+ experiment_description=experiment_description,
+ experiment_tags=experiment_tags,
  )
- except Exception as e:
- experiment = client.get_experiment_by_name(experiment_name)
- if experiment is None:
- raise e
- self._experiment_id = experiment_id = experiment.experiment_id
-
- self._run = run = client.create_run(
- experiment_id=experiment_id, run_name=mlflow_run_name
- )
- self._run_id = run.info.run_id
+ else:
+ self._mlflow_uri = uri = config["mlflow_uri"]
+ os.environ["MLFLOW_TRACKING_USERNAME"] = config.get("username", None)
+ os.environ["MLFLOW_TRACKING_PASSWORD"] = config.get("password", None)
+ self._client = client = mlflow.MlflowClient(tracking_uri=uri)
+ self._run = client.get_run(run_id=run_id)
+ self._run_id = run_id
+ self._run_uuid = run_uuid = str(uuid.uuid4())
+ self._lock = asyncio.Lock()

  @property
  def client(
@@ -77,6 +114,7 @@ class MlflowClient:
  def run_id(
  self,
  ) -> str:
+ assert self._run_id is not None
  return self._run_id

  @lock_decorator
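The new `init_mlflow_connection` separates run creation from client construction, so one MLflow run can be created up front and then attached to by passing `run_id` to `MlflowClient` (this is exactly what `AgentGroup` does further down with `run_id=mlflow_run_id`). Below is a minimal usage sketch inferred from the signatures in this diff; the `pycityagent.metrics` import path and the credential values are assumptions, while the `mlflow_uri`/`username`/`password` config keys are the ones read by the code above.

```python
# Editorial sketch, not package documentation.
# Assumed import path (the diff only shows a relative `.mlflow_client` re-export).
from pycityagent.metrics import MlflowClient, init_mlflow_connection

config = {
    "mlflow_uri": "http://localhost:5000",  # placeholder tracking server
    "username": "mlflow_user",              # placeholder credentials
    "password": "mlflow_password",
}

# Create the experiment/run once...
run_id, (uri, client, run, run_uuid) = init_mlflow_connection(
    config=config,
    experiment_name="pycityagent_experiment",  # placeholder name
)

# ...then let any number of MlflowClient instances attach to the same run.
mlflow_client = MlflowClient(config=config, run_id=run_id)
```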
@@ -3,5 +3,6 @@
  """

  from .simulation import AgentSimulation
+ from .storage.pg import PgWriter, create_pg_tables

- __all__ = ["AgentSimulation"]
+ __all__ = ["AgentSimulation", "PgWriter", "create_pg_tables"]
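The simulation package now also exports `PgWriter` and `create_pg_tables`. The `AgentGroup` hunks below receive the writer as `pgsql_writer: ray.ObjectRef` and call `async_write_profile.remote(...)` / `async_write_status.remote(...)` on it, which suggests `PgWriter` runs as a Ray actor. The sketch below shows how those pieces could plug together; the DSN, the `create_pg_tables` call signature, and the `PgWriter` constructor arguments are hypothetical, since the diff does not show them.

```python
# Editorial sketch only; constructor and create_pg_tables arguments are HYPOTHETICAL.
import ray
from pycityagent.simulation import PgWriter, create_pg_tables  # assumed import path

ray.init()

dsn = "postgresql://user:password@localhost:5432/pycityagent"  # placeholder DSN
create_pg_tables(dsn)  # hypothetical call signature

# If PgWriter is already declared with @ray.remote, instantiate it with
# PgWriter.remote(...); otherwise wrap it as below. Either way the resulting
# handle is what AgentGroup expects as `pgsql_writer` in the hunks that follow.
writer = ray.remote(PgWriter).remote(dsn)  # hypothetical constructor arguments
```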
@@ -3,7 +3,7 @@ import json
  import logging
  import time
  import uuid
- from datetime import datetime
+ from datetime import datetime, timezone
  from pathlib import Path
  from typing import Any
  from uuid import UUID
@@ -35,7 +35,8 @@ class AgentGroup:
  enable_avro: bool,
  avro_path: Path,
  enable_pgsql: bool,
- pgsql_args: tuple[str, str, str, str, str],
+ pgsql_writer: ray.ObjectRef,
+ mlflow_run_id: str,
  logging_level: int,
  ):
  logger.setLevel(logging_level)
@@ -44,6 +45,7 @@ class AgentGroup:
  self.config = config
  self.exp_id = exp_id
  self.enable_avro = enable_avro
+ self.enable_pgsql = enable_pgsql
  if enable_avro:
  self.avro_path = avro_path / f"{self._uuid}"
  self.avro_path.mkdir(parents=True, exist_ok=True)
@@ -53,6 +55,8 @@ class AgentGroup:
  "status": self.avro_path / f"status.avro",
  "survey": self.avro_path / f"survey.avro",
  }
+ if self.enable_pgsql:
+ pass

  self.messager = Messager(
  hostname=config["simulator_request"]["mqtt"]["server"],
@@ -60,6 +64,8 @@ class AgentGroup:
  username=config["simulator_request"]["mqtt"].get("username", None),
  password=config["simulator_request"]["mqtt"].get("password", None),
  )
+ self._pgsql_writer = pgsql_writer
+ self._last_asyncio_pg_task = None # hide the SQL write IO behind the compute task
  self.initialized = False
  self.id2agent = {}
  # Step:1 prepare LLM client
@@ -88,6 +94,7 @@ class AgentGroup:
  config=_mlflow_config,
  mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
  experiment_name=exp_name,
+ run_id=mlflow_run_id,
  )
  else:
  self.mlflow_client = None
@@ -103,6 +110,8 @@ class AgentGroup:
  agent.set_messager(self.messager)
  if self.enable_avro:
  agent.set_avro_file(self.avro_file) # type: ignore
+ if self.enable_pgsql:
+ agent.set_pgsql_writer(self._pgsql_writer)

  async def init_agents(self):
  logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
@@ -159,6 +168,20 @@ class AgentGroup:
  with open(filename, "wb") as f:
  surveys = []
  fastavro.writer(f, SURVEY_SCHEMA, surveys)
+
+ if self.enable_pgsql:
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
+ profiles: list[Any] = []
+ for agent in self.agents:
+ profile = await agent.memory._profile.export()
+ profile = profile[0]
+ profile["id"] = agent._uuid
+ profiles.append(
+ (agent._uuid, profile.get("name", ""), json.dumps(profile))
+ )
+ await self._pgsql_writer.async_write_profile.remote( # type:ignore
+ profiles
+ )
  self.initialized = True
  logger.debug(f"-----AgentGroup {self._uuid} initialized")

@@ -216,11 +239,13 @@ class AgentGroup:
  await asyncio.sleep(0.5)

  async def save_status(self):
+ _statuses_time_list: list[tuple[dict, datetime]] = []
  if self.enable_avro:
  logger.debug(f"-----Saving status for group {self._uuid}")
  avros = []
  if not issubclass(type(self.agents[0]), InstitutionAgent):
  for agent in self.agents:
+ _date_time = datetime.now(timezone.utc)
  position = await agent.memory.get("position")
  lng = position["longlat_position"]["longitude"]
  lat = position["longlat_position"]["latitude"]
@@ -246,13 +271,15 @@ class AgentGroup:
  "tired": needs["tired"],
  "safe": needs["safe"],
  "social": needs["social"],
- "created_at": int(datetime.now().timestamp() * 1000),
+ "created_at": int(_date_time.timestamp() * 1000),
  }
  avros.append(avro)
+ _statuses_time_list.append((avro, _date_time))
  with open(self.avro_file["status"], "a+b") as f:
  fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
  else:
  for agent in self.agents:
+ _date_time = datetime.now(timezone.utc)
  avro = {
  "id": agent._uuid,
  "day": await self.simulator.get_simulator_day(),
@@ -272,8 +299,109 @@ class AgentGroup:
  "customers": await agent.memory.get("customers"),
  }
  avros.append(avro)
+ _statuses_time_list.append((avro, _date_time))
  with open(self.avro_file["status"], "a+b") as f:
  fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
+ if self.enable_pgsql:
+ # data already acquired from Avro part
+ if len(_statuses_time_list) > 0:
+ for _status_dict, _date_time in _statuses_time_list:
+ for key in ["lng", "lat", "parent_id"]:
+ if key not in _status_dict:
+ _status_dict[key] = -1
+ for key in [
+ "action",
+ ]:
+ if key not in _status_dict:
+ _status_dict[key] = ""
+ _status_dict["created_at"] = _date_time
+ else:
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
+ for agent in self.agents:
+ _date_time = datetime.now(timezone.utc)
+ position = await agent.memory.get("position")
+ lng = position["longlat_position"]["longitude"]
+ lat = position["longlat_position"]["latitude"]
+ if "aoi_position" in position:
+ parent_id = position["aoi_position"]["aoi_id"]
+ elif "lane_position" in position:
+ parent_id = position["lane_position"]["lane_id"]
+ else:
+ # BUG: needs to be handled
+ parent_id = -1
+ needs = await agent.memory.get("needs")
+ action = await agent.memory.get("current_step")
+ action = action["intention"]
+ _status_dict = {
+ "id": agent._uuid,
+ "day": await self.simulator.get_simulator_day(),
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
+ "lng": lng,
+ "lat": lat,
+ "parent_id": parent_id,
+ "action": action,
+ "hungry": needs["hungry"],
+ "tired": needs["tired"],
+ "safe": needs["safe"],
+ "social": needs["social"],
+ "created_at": _date_time,
+ }
+ _statuses_time_list.append((_status_dict, _date_time))
+ else:
+ for agent in self.agents:
+ _date_time = datetime.now(timezone.utc)
+ _status_dict = {
+ "id": agent._uuid,
+ "day": await self.simulator.get_simulator_day(),
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
+ "lng": -1,
+ "lat": -1,
+ "parent_id": -1,
+ "action": "",
+ "type": await agent.memory.get("type"),
+ "nominal_gdp": await agent.memory.get("nominal_gdp"),
+ "real_gdp": await agent.memory.get("real_gdp"),
+ "unemployment": await agent.memory.get("unemployment"),
+ "wages": await agent.memory.get("wages"),
+ "prices": await agent.memory.get("prices"),
+ "inventory": await agent.memory.get("inventory"),
+ "price": await agent.memory.get("price"),
+ "interest_rate": await agent.memory.get("interest_rate"),
+ "bracket_cutoffs": await agent.memory.get(
+ "bracket_cutoffs"
+ ),
+ "bracket_rates": await agent.memory.get("bracket_rates"),
+ "employees": await agent.memory.get("employees"),
+ "customers": await agent.memory.get("customers"),
+ "created_at": _date_time,
+ }
+ _statuses_time_list.append((_status_dict, _date_time))
+ to_update_statues: list[tuple] = []
+ for _status_dict, _ in _statuses_time_list:
+ BASIC_KEYS = [
+ "id",
+ "day",
+ "t",
+ "lng",
+ "lat",
+ "parent_id",
+ "action",
+ "created_at",
+ ]
+ _data = [_status_dict[k] for k in BASIC_KEYS if k != "created_at"]
+ _other_dict = json.dumps(
+ {k: v for k, v in _status_dict.items() if k not in BASIC_KEYS}
+ )
+ _data.append(_other_dict)
+ _data.append(_status_dict["created_at"])
+ to_update_statues.append(tuple(_data))
+ if self._last_asyncio_pg_task is not None:
+ await self._last_asyncio_pg_task
+ self._last_asyncio_pg_task = (
+ self._pgsql_writer.async_write_status.remote( # type:ignore
+ to_update_statues
+ )
+ )

  async def step(self):
  if not self.initialized:
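A note on the `save_status` tail added above: the group keeps the Ray reference of the previous PostgreSQL write in `_last_asyncio_pg_task` and awaits it only right before submitting the next batch, so the database IO overlaps with the following simulation work instead of blocking it (matching the comment on `_last_asyncio_pg_task` in `__init__`). Below is a self-contained sketch of that pattern; the writer actor and its payloads are hypothetical stand-ins for `PgWriter`.

```python
# Editorial sketch of the "hide the write IO behind compute" pattern used above.
# FakeStatusWriter and its payloads are hypothetical stand-ins for PgWriter.
import asyncio
import ray


@ray.remote
class FakeStatusWriter:
    def async_write_status(self, rows: list[tuple]) -> int:
        # Pretend to write rows to PostgreSQL; return the row count.
        return len(rows)


async def main() -> None:
    ray.init()
    writer = FakeStatusWriter.remote()
    last_task = None
    for step in range(3):
        rows = [(step, "status")]  # stand-in for to_update_statues
        # Await the *previous* write only once the next batch is ready,
        # so the write overlaps with the work that produced this batch.
        if last_task is not None:
            await last_task
        last_task = writer.async_write_status.remote(rows)
    if last_task is not None:
        await last_task  # drain the final in-flight write
    ray.shutdown()


asyncio.run(main())
```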