pycityagent 2.0.0a19__py3-none-any.whl → 2.0.0a21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,7 @@ import json
3
3
  import logging
4
4
  import os
5
5
  import random
6
+ import time
6
7
  import uuid
7
8
  from collections.abc import Callable, Sequence
8
9
  from concurrent.futures import ThreadPoolExecutor
@@ -11,6 +12,7 @@ from pathlib import Path
11
12
  from typing import Any, Optional, Union
12
13
 
13
14
  import pycityproto.city.economy.v2.economy_pb2 as economyv2
15
+ import ray
14
16
  import yaml
15
17
  from mosstool.map._map_util.const import AOI_START_ID
16
18
 
@@ -18,8 +20,10 @@ from ..agent import Agent, InstitutionAgent
18
20
  from ..environment.simulator import Simulator
19
21
  from ..memory.memory import Memory
20
22
  from ..message.messager import Messager
23
+ from ..metrics import init_mlflow_connection
21
24
  from ..survey import Survey
22
25
  from .agentgroup import AgentGroup
26
+ from .storage.pg import PgWriter, create_pg_tables
23
27
 
24
28
  logger = logging.getLogger("pycityagent")
25
29
 
@@ -60,6 +64,7 @@ class AgentSimulation:
60
64
  self._user_survey_topics: dict[uuid.UUID, str] = {}
61
65
  self._user_interview_topics: dict[uuid.UUID, str] = {}
62
66
  self._loop = asyncio.get_event_loop()
67
+ # self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
63
68
 
64
69
  self._messager = Messager(
65
70
  hostname=config["simulator_request"]["mqtt"]["server"],
@@ -86,22 +91,13 @@ class AgentSimulation:
86
91
  self._enable_pgsql = _pgsql_config.get("enabled", False)
87
92
  if not self._enable_pgsql:
88
93
  logger.warning("PostgreSQL is not enabled, NO POSTGRESQL DATABASE STORAGE")
89
- self._pgsql_args = ("", "", "", "", "")
94
+ self._pgsql_dsn = ""
90
95
  else:
91
- self._pgsql_host = _pgsql_config["host"]
92
- self._pgsql_port = _pgsql_config["port"]
93
- self._pgsql_database = _pgsql_config["database"]
94
- self._pgsql_user = _pgsql_config.get("user", None)
95
- self._pgsql_password = _pgsql_config.get("password", None)
96
- self._pgsql_args: tuple[str, str, str, str, str] = (
97
- self._pgsql_host,
98
- self._pgsql_port,
99
- self._pgsql_database,
100
- self._pgsql_user,
101
- self._pgsql_password,
102
- )
96
+ self._pgsql_dsn = _pgsql_config["data_source_name"]
103
97
 
104
98
  # 添加实验信息相关的属性
99
+ self._exp_created_time = datetime.now(timezone.utc)
100
+ self._exp_updated_time = datetime.now(timezone.utc)
105
101
  self._exp_info = {
106
102
  "id": self.exp_id,
107
103
  "name": exp_name,
@@ -111,7 +107,8 @@ class AgentSimulation:
111
107
  "cur_t": 0.0,
112
108
  "config": json.dumps(config),
113
109
  "error": "",
114
- "created_at": datetime.now(timezone.utc).isoformat(),
110
+ "created_at": self._exp_created_time.isoformat(),
111
+ "updated_at": self._exp_updated_time.isoformat(),
115
112
  }
116
113
 
117
114
  # 创建异步任务保存实验信息
@@ -165,7 +162,8 @@ class AgentSimulation:
165
162
  enable_avro: bool,
166
163
  avro_path: Path,
167
164
  enable_pgsql: bool,
168
- pgsql_args: tuple[str, str, str, str, str],
165
+ pgsql_writer: ray.ObjectRef,
166
+ mlflow_run_id: str = None, # type: ignore
169
167
  logging_level: int = logging.WARNING,
170
168
  ):
171
169
  """创建远程组"""
@@ -177,7 +175,8 @@ class AgentSimulation:
177
175
  enable_avro,
178
176
  avro_path,
179
177
  enable_pgsql,
180
- pgsql_args,
178
+ pgsql_writer,
179
+ mlflow_run_id,
181
180
  logging_level,
182
181
  )
183
182
  return group_name, group, agents
@@ -186,6 +185,7 @@ class AgentSimulation:
186
185
  self,
187
186
  agent_count: Union[int, list[int]],
188
187
  group_size: int = 1000,
188
+ pg_sql_writers: int = 32,
189
189
  memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
190
190
  ) -> None:
191
191
  """初始化智能体
@@ -246,8 +246,8 @@ class AgentSimulation:
246
246
  memory=memory,
247
247
  )
248
248
 
249
- self._agents[agent._uuid] = agent
250
- self._agent_uuids.append(agent._uuid)
249
+ self._agents[agent._uuid] = agent # type:ignore
250
+ self._agent_uuids.append(agent._uuid) # type:ignore
251
251
 
252
252
  # 计算需要的组数,向上取整以处理不足一组的情况
253
253
  num_group = (agent_count_i + group_size - 1) // group_size
@@ -267,9 +267,33 @@ class AgentSimulation:
267
267
 
268
268
  class_init_index += agent_count_i
269
269
 
270
+ # 初始化mlflow连接
271
+ _mlflow_config = self.config.get("metric_request", {}).get("mlflow")
272
+ if _mlflow_config:
273
+ mlflow_run_id, _ = init_mlflow_connection(
274
+ config=_mlflow_config,
275
+ mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
276
+ experiment_name=self.exp_name,
277
+ )
278
+ else:
279
+ mlflow_run_id = None
280
+ # 建表
281
+ if self.enable_pgsql:
282
+ _num_workers = min(1, pg_sql_writers)
283
+ create_pg_tables(
284
+ exp_id=self.exp_id,
285
+ dsn=self._pgsql_dsn,
286
+ )
287
+ self._pgsql_writers = _workers = [
288
+ PgWriter.remote(self.exp_id, self._pgsql_dsn)
289
+ for _ in range(_num_workers)
290
+ ]
291
+ else:
292
+ _num_workers = 1
293
+ self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
270
294
  # 收集所有创建组的参数
271
295
  creation_tasks = []
272
- for group_name, agents in group_creation_params:
296
+ for i, (group_name, agents) in enumerate(group_creation_params):
273
297
  # 直接创建异步任务
274
298
  group = AgentGroup.remote(
275
299
  agents,
@@ -279,7 +303,8 @@ class AgentSimulation:
279
303
  self.enable_avro,
280
304
  self.avro_path,
281
305
  self.enable_pgsql,
282
- self._pgsql_args,
306
+ _workers[i % _num_workers], # type:ignore
307
+ mlflow_run_id, # type:ignore
283
308
  self.logging_level,
284
309
  )
285
310
  creation_tasks.append((group_name, group, agents))
@@ -451,11 +476,13 @@ class AgentSimulation:
451
476
  survey_dict = survey.to_dict()
452
477
  if agent_uuids is None:
453
478
  agent_uuids = self._agent_uuids
479
+ _date_time = datetime.now(timezone.utc)
454
480
  payload = {
455
481
  "from": "none",
456
482
  "survey_id": survey_dict["id"],
457
- "timestamp": int(datetime.now().timestamp() * 1000),
483
+ "timestamp": int(_date_time.timestamp() * 1000),
458
484
  "data": survey_dict,
485
+ "_date_time": _date_time,
459
486
  }
460
487
  for uuid in agent_uuids:
461
488
  topic = self._user_survey_topics[uuid]
@@ -465,10 +492,12 @@ class AgentSimulation:
465
492
  self, content: str, agent_uuids: Union[uuid.UUID, list[uuid.UUID]]
466
493
  ):
467
494
  """发送面试消息"""
495
+ _date_time = datetime.now(timezone.utc)
468
496
  payload = {
469
497
  "from": "none",
470
498
  "content": content,
471
- "timestamp": int(datetime.now().timestamp() * 1000),
499
+ "timestamp": int(_date_time.timestamp() * 1000),
500
+ "_date_time": _date_time,
472
501
  }
473
502
  if not isinstance(agent_uuids, Sequence):
474
503
  agent_uuids = [agent_uuids]
@@ -497,15 +526,29 @@ class AgentSimulation:
497
526
  logger.error(f"Avro保存实验信息失败: {str(e)}")
498
527
  try:
499
528
  if self.enable_pgsql:
500
- # TODO
501
- pass
529
+ worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
530
+ # if self._last_asyncio_pg_task is not None:
531
+ # await self._last_asyncio_pg_task
532
+ # self._last_asyncio_pg_task = (
533
+ # worker.async_update_exp_info.remote( # type:ignore
534
+ # pg_exp_info
535
+ # )
536
+ # )
537
+ pg_exp_info = {k: v for k, v in self._exp_info.items()}
538
+ pg_exp_info["created_at"] = self._exp_created_time
539
+ pg_exp_info["updated_at"] = self._exp_updated_time
540
+ await worker.async_update_exp_info.remote( # type:ignore
541
+ pg_exp_info
542
+ )
502
543
  except Exception as e:
503
544
  logger.error(f"PostgreSQL保存实验信息失败: {str(e)}")
504
545
 
505
546
  async def _update_exp_status(self, status: int, error: str = "") -> None:
547
+ self._exp_updated_time = datetime.now(timezone.utc)
506
548
  """更新实验状态并保存"""
507
549
  self._exp_info["status"] = status
508
550
  self._exp_info["error"] = error
551
+ self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
509
552
  await self._save_exp_info()
510
553
 
511
554
  async def _monitor_exp_status(self, stop_event: asyncio.Event):
@@ -0,0 +1,139 @@
1
+ import asyncio
2
+ from collections import defaultdict
3
+ from typing import Any
4
+
5
+ import psycopg
6
+ import psycopg.sql
7
+ import ray
8
+ from psycopg.rows import dict_row
9
+
10
+ from ...utils.decorators import lock_decorator
11
+ from ...utils.pg_query import PGSQL_DICT
12
+
13
+
14
+ def create_pg_tables(exp_id: str, dsn: str):
15
+ for table_type, exec_strs in PGSQL_DICT.items():
16
+ table_name = f"socialcity_{exp_id.replace('-', '_')}_{table_type}"
17
+ # # debug str
18
+ # for _str in [f"DROP TABLE IF EXISTS {table_name}"] + [
19
+ # _exec_str.format(table_name=table_name) for _exec_str in exec_strs
20
+ # ]:
21
+ # print(_str)
22
+ with psycopg.connect(dsn) as conn:
23
+ with conn.cursor() as cur:
24
+ # delete table
25
+ cur.execute(f"DROP TABLE IF EXISTS {table_name}") # type:ignore
26
+ conn.commit()
27
+ # create table
28
+ for _exec_str in exec_strs:
29
+ cur.execute(_exec_str.format(table_name=table_name))
30
+ conn.commit()
31
+
32
+
33
+ @ray.remote
34
+ class PgWriter:
35
+ def __init__(self, exp_id: str, dsn: str):
36
+ self.exp_id = exp_id
37
+ self._dsn = dsn
38
+ # self._lock = asyncio.Lock()
39
+
40
+ # @lock_decorator
41
+ async def async_write_dialog(self, rows: list[tuple]):
42
+ _tuple_types = [str, int, float, int, str, str, str, None]
43
+ table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_dialog"
44
+ # 将数据插入数据库
45
+ async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
46
+ copy_sql = psycopg.sql.SQL(
47
+ "COPY {} (id, day, t, type, speaker, content, created_at) FROM STDIN"
48
+ ).format(psycopg.sql.Identifier(table_name))
49
+ async with aconn.cursor() as cur:
50
+ async with cur.copy(copy_sql) as copy:
51
+ for row in rows:
52
+ _row = [
53
+ _type(r) if _type is not None else r
54
+ for (_type, r) in zip(_tuple_types, row)
55
+ ]
56
+ await copy.write_row(_row)
57
+
58
+ # @lock_decorator
59
+ async def async_write_status(self, rows: list[tuple]):
60
+ _tuple_types = [str, int, float, float, float, int, str, str, None]
61
+ table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_status"
62
+ async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
63
+ copy_sql = psycopg.sql.SQL(
64
+ "COPY {} (id, day, t, lng, lat, parent_id, action, status, created_at) FROM STDIN"
65
+ ).format(psycopg.sql.Identifier(table_name))
66
+ async with aconn.cursor() as cur:
67
+ async with cur.copy(copy_sql) as copy:
68
+ for row in rows:
69
+ _row = [
70
+ _type(r) if _type is not None else r
71
+ for (_type, r) in zip(_tuple_types, row)
72
+ ]
73
+ await copy.write_row(_row)
74
+
75
+ # @lock_decorator
76
+ async def async_write_profile(self, rows: list[tuple]):
77
+ _tuple_types = [str, str, str]
78
+ table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_profile"
79
+ async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
80
+ copy_sql = psycopg.sql.SQL("COPY {} (id, name, profile) FROM STDIN").format(
81
+ psycopg.sql.Identifier(table_name)
82
+ )
83
+ async with aconn.cursor() as cur:
84
+ async with cur.copy(copy_sql) as copy:
85
+ for row in rows:
86
+ _row = [
87
+ _type(r) if _type is not None else r
88
+ for (_type, r) in zip(_tuple_types, row)
89
+ ]
90
+ await copy.write_row(_row)
91
+
92
+ # @lock_decorator
93
+ async def async_write_survey(self, rows: list[tuple]):
94
+ _tuple_types = [str, int, float, str, str, None]
95
+ table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_survey"
96
+ async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
97
+ copy_sql = psycopg.sql.SQL(
98
+ "COPY {} (id, day, t, survey_id, result, created_at) FROM STDIN"
99
+ ).format(psycopg.sql.Identifier(table_name))
100
+ async with aconn.cursor() as cur:
101
+ async with cur.copy(copy_sql) as copy:
102
+ for row in rows:
103
+ _row = [
104
+ _type(r) if _type is not None else r
105
+ for (_type, r) in zip(_tuple_types, row)
106
+ ]
107
+ await copy.write_row(_row)
108
+
109
+ # @lock_decorator
110
+ async def async_update_exp_info(self, exp_info: dict[str, Any]):
111
+ # timestamp不做类型转换
112
+ TO_UPDATE_EXP_INFO_KEYS_AND_TYPES = [
113
+ ("id", str),
114
+ ("name", str),
115
+ ("num_day", int),
116
+ ("status", int),
117
+ ("cur_day", int),
118
+ ("cur_t", float),
119
+ ("config", str),
120
+ ("error", str),
121
+ ("created_at", None),
122
+ ("updated_at", None),
123
+ ]
124
+ table_name = f"socialcity_{self.exp_id.replace('-', '_')}_experiment"
125
+ async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
126
+ async with aconn.cursor(row_factory=dict_row) as cur:
127
+ # UPDATE
128
+ columns = ", ".join(
129
+ f"{key} = %s" for key, _ in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
130
+ )
131
+ update_sql = psycopg.sql.SQL(
132
+ f"UPDATE {{}} SET {columns} WHERE id = %s" # type:ignore
133
+ ).format(psycopg.sql.Identifier(table_name))
134
+ params = [
135
+ _type(exp_info[key]) if _type is not None else exp_info[key]
136
+ for key, _type in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
137
+ ] + [self.exp_id]
138
+ await cur.execute(update_sql, params)
139
+ await aconn.commit()
@@ -4,7 +4,7 @@ Base class of parser
4
4
 
5
5
  import re
6
6
  from abc import ABC, abstractmethod
7
- from typing import Any, Dict, List, Optional, Tuple, Union
7
+ from typing import Any, Union
8
8
 
9
9
 
10
10
  class ParserBase(ABC):
@@ -0,0 +1,80 @@
1
+ from typing import Any
2
+
3
+ PGSQL_DICT: dict[str, list[Any]] = {
4
+ # Experiment
5
+ "experiment": [
6
+ """
7
+ CREATE TABLE IF NOT EXISTS {table_name} (
8
+ id UUID PRIMARY KEY,
9
+ name TEXT,
10
+ num_day INT4,
11
+ status INT4,
12
+ cur_day INT4,
13
+ cur_t FLOAT,
14
+ config TEXT,
15
+ error TEXT,
16
+ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
17
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
18
+ )
19
+ """,
20
+ ],
21
+ # Agent Profile
22
+ "agent_profile": [
23
+ """
24
+ CREATE TABLE IF NOT EXISTS {table_name} (
25
+ id UUID PRIMARY KEY,
26
+ name TEXT,
27
+ profile JSONB
28
+ )
29
+ """,
30
+ ],
31
+ # Agent Dialog
32
+ "agent_dialog": [
33
+ """
34
+ CREATE TABLE IF NOT EXISTS {table_name} (
35
+ id UUID,
36
+ day INT4,
37
+ t FLOAT,
38
+ type INT4,
39
+ speaker TEXT,
40
+ content TEXT,
41
+ created_at TIMESTAMPTZ
42
+ )
43
+ """,
44
+ "CREATE INDEX {table_name}_id_idx ON {table_name} (id)",
45
+ "CREATE INDEX {table_name}_day_t_idx ON {table_name} (day,t)",
46
+ ],
47
+ # Agent Status
48
+ "agent_status": [
49
+ """
50
+ CREATE TABLE IF NOT EXISTS {table_name} (
51
+ id UUID,
52
+ day INT4,
53
+ t FLOAT,
54
+ lng DOUBLE PRECISION,
55
+ lat DOUBLE PRECISION,
56
+ parent_id INT4,
57
+ action TEXT,
58
+ status JSONB,
59
+ created_at TIMESTAMPTZ
60
+ )
61
+ """,
62
+ "CREATE INDEX {table_name}_id_idx ON {table_name} (id)",
63
+ "CREATE INDEX {table_name}_day_t_idx ON {table_name} (day,t)",
64
+ ],
65
+ # Agent Survey
66
+ "agent_survey": [
67
+ """
68
+ CREATE TABLE IF NOT EXISTS {table_name} (
69
+ id UUID,
70
+ day INT4,
71
+ t FLOAT,
72
+ survey_id UUID,
73
+ result JSONB,
74
+ created_at TIMESTAMPTZ
75
+ )
76
+ """,
77
+ "CREATE INDEX {table_name}_id_idx ON {table_name} (id)",
78
+ "CREATE INDEX {table_name}_day_t_idx ON {table_name} (day,t)",
79
+ ],
80
+ }
@@ -1,4 +1,4 @@
1
- from typing import Any, Callable, Dict, List, Optional, Union
1
+ from typing import Optional, Union
2
2
  import re
3
3
 
4
4
 
@@ -10,7 +10,7 @@ class FormatPrompt:
10
10
  Attributes:
11
11
  template (str): The template string containing placeholders.
12
12
  system_prompt (Optional[str]): An optional system prompt to add to the dialog.
13
- variables (List[str]): A list of variable names extracted from the template.
13
+ variables (list[str]): A list of variable names extracted from the template.
14
14
  formatted_string (str): The formatted string derived from the template and provided variables.
15
15
  """
16
16
 
@@ -27,12 +27,12 @@ class FormatPrompt:
27
27
  self.variables = self._extract_variables()
28
28
  self.formatted_string = "" # To store the formatted string
29
29
 
30
- def _extract_variables(self) -> List[str]:
30
+ def _extract_variables(self) -> list[str]:
31
31
  """
32
32
  Extracts variable names from the template string.
33
33
 
34
34
  Returns:
35
- List[str]: A list of variable names found within the template.
35
+ list[str]: A list of variable names found within the template.
36
36
  """
37
37
  return re.findall(r"\{(\w+)\}", self.template)
38
38
 
@@ -51,12 +51,12 @@ class FormatPrompt:
51
51
  ) # Store the formatted string
52
52
  return self.formatted_string
53
53
 
54
- def to_dialog(self) -> List[Dict[str, str]]:
54
+ def to_dialog(self) -> list[dict[str, str]]:
55
55
  """
56
56
  Converts the formatted prompt and optional system prompt into a dialog format.
57
57
 
58
58
  Returns:
59
- List[Dict[str, str]]: A list representing the dialog with roles and content.
59
+ list[dict[str, str]]: A list representing the dialog with roles and content.
60
60
  """
61
61
  dialog = []
62
62
  if self.system_prompt:
@@ -1,5 +1,7 @@
1
1
  import time
2
- from typing import Any, Callable, Dict, List, Optional, Union
2
+ from collections import defaultdict
3
+ from collections.abc import Callable, Sequence
4
+ from typing import Any, Optional, Union
3
5
 
4
6
  from mlflow.entities import Metric
5
7
 
@@ -76,7 +78,7 @@ class SencePOI(Tool):
76
78
  Attributes:
77
79
  radius (int): The radius within which to search for POIs.
78
80
  category_prefix (str): The prefix for the categories of POIs to consider.
79
- variables (List[str]): A list of variables relevant to the tool's operation.
81
+ variables (list[str]): A list of variables relevant to the tool's operation.
80
82
 
81
83
  Args:
82
84
  radius (int, optional): The circular search radius. Defaults to 100.
@@ -190,33 +192,38 @@ class ResetAgentPosition(Tool):
190
192
  class ExportMlflowMetrics(Tool):
191
193
  def __init__(self, log_batch_size: int = 100) -> None:
192
194
  self._log_batch_size = log_batch_size
193
- # TODO:support other log types
194
- self.metric_log_cache: list[Metric] = []
195
+ # TODO: support other log types
196
+ self.metric_log_cache: dict[str, list[Metric]] = defaultdict(list)
195
197
 
196
198
  async def __call__(
197
199
  self,
198
- metric: Union[Metric, dict],
200
+ metric: Union[Sequence[Union[Metric, dict]], Union[Metric, dict]],
199
201
  clear_cache: bool = False,
200
202
  ):
201
203
  agent = self.agent
202
204
  batch_size = self._log_batch_size
203
- if len(self.metric_log_cache) > batch_size:
204
- client = agent.mlflow_client
205
- await client.log_batch(
206
- metrics=self.metric_log_cache[:batch_size],
207
- )
208
- self.metric_log_cache = self.metric_log_cache[batch_size:]
209
- else:
210
- if isinstance(metric, Metric):
211
- self.metric_log_cache.append(metric)
205
+ if not isinstance(metric, Sequence):
206
+ metric = [metric]
207
+ for _metric in metric:
208
+ if isinstance(_metric, Metric):
209
+ item = _metric
210
+ metric_key = item.key
212
211
  else:
213
- _metric = Metric(
214
- key=metric["key"],
215
- value=metric["value"],
216
- timestamp=metric.get("timestamp", int(1000 * time.time())),
217
- step=metric["step"],
212
+ item = Metric(
213
+ key=_metric["key"],
214
+ value=_metric["value"],
215
+ timestamp=_metric.get("timestamp", int(1000 * time.time())),
216
+ step=_metric["step"],
217
+ )
218
+ metric_key = _metric["key"]
219
+ self.metric_log_cache[metric_key].append(item)
220
+ for metric_key, _cache in self.metric_log_cache.items():
221
+ if len(_cache) > batch_size:
222
+ client = agent.mlflow_client
223
+ await client.log_batch(
224
+ metrics=_cache[:batch_size],
218
225
  )
219
- self.metric_log_cache.append(_metric)
226
+ _cache = _cache[batch_size:]
220
227
  if clear_cache:
221
228
  await self._clear_cache()
222
229
 
@@ -225,8 +232,9 @@ class ExportMlflowMetrics(Tool):
225
232
  ):
226
233
  agent = self.agent
227
234
  client = agent.mlflow_client
228
- if len(self.metric_log_cache) > 0:
229
- await client.log_batch(
230
- metrics=self.metric_log_cache,
231
- )
232
- self.metric_log_cache = []
235
+ for metric_key, _cache in self.metric_log_cache.items():
236
+ if len(_cache) > 0:
237
+ await client.log_batch(
238
+ metrics=_cache,
239
+ )
240
+ _cache = []
@@ -1,5 +1,5 @@
1
1
  import asyncio
2
- from typing import Any, Callable, Dict, List, Optional, Union, Type
2
+ from typing import Optional
3
3
  import socket
4
4
  from ..memory import Memory
5
5
  from ..environment import Simulator
@@ -11,7 +11,7 @@ class EventTrigger:
11
11
  """Base class for event triggers that wait for specific conditions to be met."""
12
12
 
13
13
  # 定义该trigger需要的组件类型
14
- required_components: List[Type] = []
14
+ required_components: list[type] = []
15
15
 
16
16
  def __init__(self, block=None):
17
17
  self.block = block
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pycityagent
3
- Version: 2.0.0a19
3
+ Version: 2.0.0a21
4
4
  Summary: LLM-based城市环境agent构建库
5
5
  License: MIT
6
6
  Author: Yuwei Yan
@@ -34,8 +34,9 @@ Requires-Dist: openai (>=1.58.1,<2.0.0)
34
34
  Requires-Dist: pandavro (>=1.8.0,<2.0.0)
35
35
  Requires-Dist: poetry (>=1.2.2)
36
36
  Requires-Dist: protobuf (<=4.24.0)
37
+ Requires-Dist: psycopg[binary] (>=3.2.3,<4.0.0)
37
38
  Requires-Dist: pycitydata (==1.0.0)
38
- Requires-Dist: pycityproto (>=2.1.4,<3.0.0)
39
+ Requires-Dist: pycityproto (>=2.1.5,<3.0.0)
39
40
  Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
40
41
  Requires-Dist: ray (>=2.40.0,<3.0.0)
41
42
  Requires-Dist: sidecar (==0.7.0)