pycityagent 2.0.0a19__py3-none-any.whl → 2.0.0a21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +173 -70
- pycityagent/economy/econ_client.py +37 -0
- pycityagent/environment/utils/geojson.py +1 -3
- pycityagent/environment/utils/map_utils.py +15 -15
- pycityagent/llm/embedding.py +8 -9
- pycityagent/llm/llm.py +5 -5
- pycityagent/memory/memory.py +23 -22
- pycityagent/metrics/__init__.py +2 -1
- pycityagent/metrics/mlflow_client.py +72 -34
- pycityagent/simulation/__init__.py +2 -1
- pycityagent/simulation/agentgroup.py +131 -3
- pycityagent/simulation/simulation.py +67 -24
- pycityagent/simulation/storage/pg.py +139 -0
- pycityagent/utils/parsers/parser_base.py +1 -1
- pycityagent/utils/pg_query.py +80 -0
- pycityagent/workflow/prompt.py +6 -6
- pycityagent/workflow/tool.py +33 -25
- pycityagent/workflow/trigger.py +2 -2
- {pycityagent-2.0.0a19.dist-info → pycityagent-2.0.0a21.dist-info}/METADATA +3 -2
- {pycityagent-2.0.0a19.dist-info → pycityagent-2.0.0a21.dist-info}/RECORD +21 -19
- {pycityagent-2.0.0a19.dist-info → pycityagent-2.0.0a21.dist-info}/WHEEL +0 -0
@@ -3,6 +3,7 @@ import json
|
|
3
3
|
import logging
|
4
4
|
import os
|
5
5
|
import random
|
6
|
+
import time
|
6
7
|
import uuid
|
7
8
|
from collections.abc import Callable, Sequence
|
8
9
|
from concurrent.futures import ThreadPoolExecutor
|
@@ -11,6 +12,7 @@ from pathlib import Path
|
|
11
12
|
from typing import Any, Optional, Union
|
12
13
|
|
13
14
|
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
15
|
+
import ray
|
14
16
|
import yaml
|
15
17
|
from mosstool.map._map_util.const import AOI_START_ID
|
16
18
|
|
@@ -18,8 +20,10 @@ from ..agent import Agent, InstitutionAgent
|
|
18
20
|
from ..environment.simulator import Simulator
|
19
21
|
from ..memory.memory import Memory
|
20
22
|
from ..message.messager import Messager
|
23
|
+
from ..metrics import init_mlflow_connection
|
21
24
|
from ..survey import Survey
|
22
25
|
from .agentgroup import AgentGroup
|
26
|
+
from .storage.pg import PgWriter, create_pg_tables
|
23
27
|
|
24
28
|
logger = logging.getLogger("pycityagent")
|
25
29
|
|
@@ -60,6 +64,7 @@ class AgentSimulation:
|
|
60
64
|
self._user_survey_topics: dict[uuid.UUID, str] = {}
|
61
65
|
self._user_interview_topics: dict[uuid.UUID, str] = {}
|
62
66
|
self._loop = asyncio.get_event_loop()
|
67
|
+
# self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
63
68
|
|
64
69
|
self._messager = Messager(
|
65
70
|
hostname=config["simulator_request"]["mqtt"]["server"],
|
@@ -86,22 +91,13 @@ class AgentSimulation:
|
|
86
91
|
self._enable_pgsql = _pgsql_config.get("enabled", False)
|
87
92
|
if not self._enable_pgsql:
|
88
93
|
logger.warning("PostgreSQL is not enabled, NO POSTGRESQL DATABASE STORAGE")
|
89
|
-
self.
|
94
|
+
self._pgsql_dsn = ""
|
90
95
|
else:
|
91
|
-
self.
|
92
|
-
self._pgsql_port = _pgsql_config["port"]
|
93
|
-
self._pgsql_database = _pgsql_config["database"]
|
94
|
-
self._pgsql_user = _pgsql_config.get("user", None)
|
95
|
-
self._pgsql_password = _pgsql_config.get("password", None)
|
96
|
-
self._pgsql_args: tuple[str, str, str, str, str] = (
|
97
|
-
self._pgsql_host,
|
98
|
-
self._pgsql_port,
|
99
|
-
self._pgsql_database,
|
100
|
-
self._pgsql_user,
|
101
|
-
self._pgsql_password,
|
102
|
-
)
|
96
|
+
self._pgsql_dsn = _pgsql_config["data_source_name"]
|
103
97
|
|
104
98
|
# 添加实验信息相关的属性
|
99
|
+
self._exp_created_time = datetime.now(timezone.utc)
|
100
|
+
self._exp_updated_time = datetime.now(timezone.utc)
|
105
101
|
self._exp_info = {
|
106
102
|
"id": self.exp_id,
|
107
103
|
"name": exp_name,
|
@@ -111,7 +107,8 @@ class AgentSimulation:
|
|
111
107
|
"cur_t": 0.0,
|
112
108
|
"config": json.dumps(config),
|
113
109
|
"error": "",
|
114
|
-
"created_at":
|
110
|
+
"created_at": self._exp_created_time.isoformat(),
|
111
|
+
"updated_at": self._exp_updated_time.isoformat(),
|
115
112
|
}
|
116
113
|
|
117
114
|
# 创建异步任务保存实验信息
|
@@ -165,7 +162,8 @@ class AgentSimulation:
|
|
165
162
|
enable_avro: bool,
|
166
163
|
avro_path: Path,
|
167
164
|
enable_pgsql: bool,
|
168
|
-
|
165
|
+
pgsql_writer: ray.ObjectRef,
|
166
|
+
mlflow_run_id: str = None, # type: ignore
|
169
167
|
logging_level: int = logging.WARNING,
|
170
168
|
):
|
171
169
|
"""创建远程组"""
|
@@ -177,7 +175,8 @@ class AgentSimulation:
|
|
177
175
|
enable_avro,
|
178
176
|
avro_path,
|
179
177
|
enable_pgsql,
|
180
|
-
|
178
|
+
pgsql_writer,
|
179
|
+
mlflow_run_id,
|
181
180
|
logging_level,
|
182
181
|
)
|
183
182
|
return group_name, group, agents
|
@@ -186,6 +185,7 @@ class AgentSimulation:
|
|
186
185
|
self,
|
187
186
|
agent_count: Union[int, list[int]],
|
188
187
|
group_size: int = 1000,
|
188
|
+
pg_sql_writers: int = 32,
|
189
189
|
memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
|
190
190
|
) -> None:
|
191
191
|
"""初始化智能体
|
@@ -246,8 +246,8 @@ class AgentSimulation:
|
|
246
246
|
memory=memory,
|
247
247
|
)
|
248
248
|
|
249
|
-
self._agents[agent._uuid] = agent
|
250
|
-
self._agent_uuids.append(agent._uuid)
|
249
|
+
self._agents[agent._uuid] = agent # type:ignore
|
250
|
+
self._agent_uuids.append(agent._uuid) # type:ignore
|
251
251
|
|
252
252
|
# 计算需要的组数,向上取整以处理不足一组的情况
|
253
253
|
num_group = (agent_count_i + group_size - 1) // group_size
|
@@ -267,9 +267,33 @@ class AgentSimulation:
|
|
267
267
|
|
268
268
|
class_init_index += agent_count_i
|
269
269
|
|
270
|
+
# 初始化mlflow连接
|
271
|
+
_mlflow_config = self.config.get("metric_request", {}).get("mlflow")
|
272
|
+
if _mlflow_config:
|
273
|
+
mlflow_run_id, _ = init_mlflow_connection(
|
274
|
+
config=_mlflow_config,
|
275
|
+
mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
|
276
|
+
experiment_name=self.exp_name,
|
277
|
+
)
|
278
|
+
else:
|
279
|
+
mlflow_run_id = None
|
280
|
+
# 建表
|
281
|
+
if self.enable_pgsql:
|
282
|
+
_num_workers = min(1, pg_sql_writers)
|
283
|
+
create_pg_tables(
|
284
|
+
exp_id=self.exp_id,
|
285
|
+
dsn=self._pgsql_dsn,
|
286
|
+
)
|
287
|
+
self._pgsql_writers = _workers = [
|
288
|
+
PgWriter.remote(self.exp_id, self._pgsql_dsn)
|
289
|
+
for _ in range(_num_workers)
|
290
|
+
]
|
291
|
+
else:
|
292
|
+
_num_workers = 1
|
293
|
+
self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
|
270
294
|
# 收集所有创建组的参数
|
271
295
|
creation_tasks = []
|
272
|
-
for group_name, agents in group_creation_params:
|
296
|
+
for i, (group_name, agents) in enumerate(group_creation_params):
|
273
297
|
# 直接创建异步任务
|
274
298
|
group = AgentGroup.remote(
|
275
299
|
agents,
|
@@ -279,7 +303,8 @@ class AgentSimulation:
|
|
279
303
|
self.enable_avro,
|
280
304
|
self.avro_path,
|
281
305
|
self.enable_pgsql,
|
282
|
-
|
306
|
+
_workers[i % _num_workers], # type:ignore
|
307
|
+
mlflow_run_id, # type:ignore
|
283
308
|
self.logging_level,
|
284
309
|
)
|
285
310
|
creation_tasks.append((group_name, group, agents))
|
@@ -451,11 +476,13 @@ class AgentSimulation:
|
|
451
476
|
survey_dict = survey.to_dict()
|
452
477
|
if agent_uuids is None:
|
453
478
|
agent_uuids = self._agent_uuids
|
479
|
+
_date_time = datetime.now(timezone.utc)
|
454
480
|
payload = {
|
455
481
|
"from": "none",
|
456
482
|
"survey_id": survey_dict["id"],
|
457
|
-
"timestamp": int(
|
483
|
+
"timestamp": int(_date_time.timestamp() * 1000),
|
458
484
|
"data": survey_dict,
|
485
|
+
"_date_time": _date_time,
|
459
486
|
}
|
460
487
|
for uuid in agent_uuids:
|
461
488
|
topic = self._user_survey_topics[uuid]
|
@@ -465,10 +492,12 @@ class AgentSimulation:
|
|
465
492
|
self, content: str, agent_uuids: Union[uuid.UUID, list[uuid.UUID]]
|
466
493
|
):
|
467
494
|
"""发送面试消息"""
|
495
|
+
_date_time = datetime.now(timezone.utc)
|
468
496
|
payload = {
|
469
497
|
"from": "none",
|
470
498
|
"content": content,
|
471
|
-
"timestamp": int(
|
499
|
+
"timestamp": int(_date_time.timestamp() * 1000),
|
500
|
+
"_date_time": _date_time,
|
472
501
|
}
|
473
502
|
if not isinstance(agent_uuids, Sequence):
|
474
503
|
agent_uuids = [agent_uuids]
|
@@ -497,15 +526,29 @@ class AgentSimulation:
|
|
497
526
|
logger.error(f"Avro保存实验信息失败: {str(e)}")
|
498
527
|
try:
|
499
528
|
if self.enable_pgsql:
|
500
|
-
#
|
501
|
-
|
529
|
+
worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
|
530
|
+
# if self._last_asyncio_pg_task is not None:
|
531
|
+
# await self._last_asyncio_pg_task
|
532
|
+
# self._last_asyncio_pg_task = (
|
533
|
+
# worker.async_update_exp_info.remote( # type:ignore
|
534
|
+
# pg_exp_info
|
535
|
+
# )
|
536
|
+
# )
|
537
|
+
pg_exp_info = {k: v for k, v in self._exp_info.items()}
|
538
|
+
pg_exp_info["created_at"] = self._exp_created_time
|
539
|
+
pg_exp_info["updated_at"] = self._exp_updated_time
|
540
|
+
await worker.async_update_exp_info.remote( # type:ignore
|
541
|
+
pg_exp_info
|
542
|
+
)
|
502
543
|
except Exception as e:
|
503
544
|
logger.error(f"PostgreSQL保存实验信息失败: {str(e)}")
|
504
545
|
|
505
546
|
async def _update_exp_status(self, status: int, error: str = "") -> None:
    """Record a new experiment status/error and persist the experiment info.

    Refreshes ``self._exp_updated_time`` with a timezone-aware UTC timestamp,
    stores ``status``, ``error`` and the ISO-8601 update time in
    ``self._exp_info``, then awaits ``self._save_exp_info()``.

    Args:
        status: New status code, stored under ``"status"``.
        error: Error message, stored under ``"error"``; empty by default.
    """
    # The docstring must be the first statement; previously this assignment
    # preceded it, turning the docstring into a no-op string expression.
    self._exp_updated_time = datetime.now(timezone.utc)
    self._exp_info["status"] = status
    self._exp_info["error"] = error
    self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
    await self._save_exp_info()
|
510
553
|
|
511
554
|
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|
@@ -0,0 +1,139 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections import defaultdict
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
import psycopg
|
6
|
+
import psycopg.sql
|
7
|
+
import ray
|
8
|
+
from psycopg.rows import dict_row
|
9
|
+
|
10
|
+
from ...utils.decorators import lock_decorator
|
11
|
+
from ...utils.pg_query import PGSQL_DICT
|
12
|
+
|
13
|
+
|
14
|
+
def create_pg_tables(exp_id: str, dsn: str):
    """Drop and recreate every per-experiment table described by ``PGSQL_DICT``.

    For each ``table_type -> statements`` entry, the table
    ``socialcity_<exp_id with '-' replaced by '_'>_<table_type>`` is dropped if
    it exists, then each statement is executed with ``{table_name}`` filled in.
    Any data already stored for this experiment is therefore discarded.

    Args:
        exp_id: Experiment id embedded in the table names.
        dsn: PostgreSQL connection string passed to ``psycopg.connect``.
    """
    # Reuse a single connection for all tables instead of reconnecting for
    # every table type.
    with psycopg.connect(dsn) as conn, conn.cursor() as cur:
        for table_type, exec_strs in PGSQL_DICT.items():
            table_name = f"socialcity_{exp_id.replace('-', '_')}_{table_type}"
            # NOTE: table_name is spliced into raw SQL because the index
            # templates also use it as a name prefix; exp_id must be trusted.
            cur.execute(f"DROP TABLE IF EXISTS {table_name}")  # type:ignore
            conn.commit()
            for _exec_str in exec_strs:
                cur.execute(_exec_str.format(table_name=table_name))  # type:ignore
            conn.commit()
|
31
|
+
|
32
|
+
|
33
|
+
@ray.remote
class PgWriter:
    """Ray actor that writes one experiment's records to PostgreSQL.

    Every public method opens its own async connection from ``dsn``. Table
    names follow ``socialcity_<exp_id with '-' replaced by '_'>_<suffix>``,
    the same scheme ``create_pg_tables`` uses.
    """

    def __init__(self, exp_id: str, dsn: str):
        self.exp_id = exp_id
        self._dsn = dsn

    def _table_name(self, suffix: str) -> str:
        """Return this experiment's table name for ``suffix``."""
        return f"socialcity_{self.exp_id.replace('-', '_')}_{suffix}"

    async def _copy_rows(
        self,
        suffix: str,
        columns: list[str],
        tuple_types: list,
        rows: list[tuple],
    ) -> None:
        """Bulk-insert ``rows`` into the ``suffix`` table via ``COPY ... FROM STDIN``.

        Each value is passed through the converter at the same position in
        ``tuple_types`` (``None`` leaves the value unchanged). Because the
        pairing uses ``zip``, values past the end of ``tuple_types`` are dropped.
        """
        copy_sql = psycopg.sql.SQL("COPY {} ({}) FROM STDIN").format(
            psycopg.sql.Identifier(self._table_name(suffix)),
            psycopg.sql.SQL(", ").join(map(psycopg.sql.Identifier, columns)),
        )
        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
            async with aconn.cursor() as cur:
                async with cur.copy(copy_sql) as copy:
                    for row in rows:
                        await copy.write_row(
                            [
                                _type(r) if _type is not None else r
                                for _type, r in zip(tuple_types, row)
                            ]
                        )

    async def async_write_dialog(self, rows: list[tuple]):
        """Append rows ``(id, day, t, type, speaker, content, created_at)``."""
        # One converter per COPY column. The old list had an extra ``str``,
        # which made ``created_at`` go through ``str()``.
        await self._copy_rows(
            "agent_dialog",
            ["id", "day", "t", "type", "speaker", "content", "created_at"],
            [str, int, float, int, str, str, None],
            rows,
        )

    async def async_write_status(self, rows: list[tuple]):
        """Append rows ``(id, day, t, lng, lat, parent_id, action, status, created_at)``."""
        await self._copy_rows(
            "agent_status",
            [
                "id",
                "day",
                "t",
                "lng",
                "lat",
                "parent_id",
                "action",
                "status",
                "created_at",
            ],
            [str, int, float, float, float, int, str, str, None],
            rows,
        )

    async def async_write_profile(self, rows: list[tuple]):
        """Append rows ``(id, name, profile)``."""
        await self._copy_rows(
            "agent_profile",
            ["id", "name", "profile"],
            [str, str, str],
            rows,
        )

    async def async_write_survey(self, rows: list[tuple]):
        """Append rows ``(id, day, t, survey_id, result, created_at)``."""
        await self._copy_rows(
            "agent_survey",
            ["id", "day", "t", "survey_id", "result", "created_at"],
            [str, int, float, str, str, None],
            rows,
        )

    async def async_update_exp_info(self, exp_info: dict[str, Any]):
        """Insert or update this experiment's row in the experiment table.

        Args:
            exp_info: Must contain every key listed below. ``created_at`` and
                ``updated_at`` are passed through without conversion.
                NOTE(review): ``exp_info["id"]`` is assumed to equal
                ``self.exp_id``.
        """
        # Timestamps are not converted.
        TO_UPDATE_EXP_INFO_KEYS_AND_TYPES = [
            ("id", str),
            ("name", str),
            ("num_day", int),
            ("status", int),
            ("cur_day", int),
            ("cur_t", float),
            ("config", str),
            ("error", str),
            ("created_at", None),
            ("updated_at", None),
        ]
        keys = [key for key, _ in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES]
        params = [
            _type(exp_info[key]) if _type is not None else exp_info[key]
            for key, _type in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
        ]
        # create_pg_tables leaves the experiment table empty, so a plain UPDATE
        # matched no row and nothing was stored. Upsert on the primary key.
        upsert_sql = psycopg.sql.SQL(
            "INSERT INTO {table} ({fields}) VALUES ({values}) "
            "ON CONFLICT (id) DO UPDATE SET {updates}"
        ).format(
            table=psycopg.sql.Identifier(self._table_name("experiment")),
            fields=psycopg.sql.SQL(", ").join(map(psycopg.sql.Identifier, keys)),
            values=psycopg.sql.SQL(", ").join(psycopg.sql.Placeholder() * len(keys)),
            updates=psycopg.sql.SQL(", ").join(
                psycopg.sql.SQL("{col} = EXCLUDED.{col}").format(
                    col=psycopg.sql.Identifier(key)
                )
                for key in keys
                if key != "id"
            ),
        )
        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
            async with aconn.cursor() as cur:
                await cur.execute(upsert_sql, params)
            await aconn.commit()
|
@@ -0,0 +1,80 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
3
|
+
# DDL per table type. Each statement is a template with a ``{table_name}``
# placeholder that ``str.format`` fills in. Every statement uses IF NOT EXISTS,
# so re-running it against an existing table is harmless.
# NOTE: PostgreSQL truncates identifiers longer than 63 bytes, so the derived
# index names may be shortened (with a NOTICE). They remain distinct.
PGSQL_DICT: dict[str, list[Any]] = {
    # Experiment
    "experiment": [
        """
    CREATE TABLE IF NOT EXISTS {table_name} (
        id UUID PRIMARY KEY,
        name TEXT,
        num_day INT4,
        status INT4,
        cur_day INT4,
        cur_t FLOAT,
        config TEXT,
        error TEXT,
        created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
    )
""",
    ],
    # Agent Profile
    "agent_profile": [
        """
    CREATE TABLE IF NOT EXISTS {table_name} (
        id UUID PRIMARY KEY,
        name TEXT,
        profile JSONB
    )
""",
    ],
    # Agent Dialog
    "agent_dialog": [
        """
    CREATE TABLE IF NOT EXISTS {table_name} (
        id UUID,
        day INT4,
        t FLOAT,
        type INT4,
        speaker TEXT,
        content TEXT,
        created_at TIMESTAMPTZ
    )
""",
        "CREATE INDEX IF NOT EXISTS {table_name}_id_idx ON {table_name} (id)",
        "CREATE INDEX IF NOT EXISTS {table_name}_day_t_idx ON {table_name} (day,t)",
    ],
    # Agent Status
    "agent_status": [
        """
    CREATE TABLE IF NOT EXISTS {table_name} (
        id UUID,
        day INT4,
        t FLOAT,
        lng DOUBLE PRECISION,
        lat DOUBLE PRECISION,
        parent_id INT4,
        action TEXT,
        status JSONB,
        created_at TIMESTAMPTZ
    )
""",
        "CREATE INDEX IF NOT EXISTS {table_name}_id_idx ON {table_name} (id)",
        "CREATE INDEX IF NOT EXISTS {table_name}_day_t_idx ON {table_name} (day,t)",
    ],
    # Agent Survey
    "agent_survey": [
        """
    CREATE TABLE IF NOT EXISTS {table_name} (
        id UUID,
        day INT4,
        t FLOAT,
        survey_id UUID,
        result JSONB,
        created_at TIMESTAMPTZ
    )
""",
        "CREATE INDEX IF NOT EXISTS {table_name}_id_idx ON {table_name} (id)",
        "CREATE INDEX IF NOT EXISTS {table_name}_day_t_idx ON {table_name} (day,t)",
    ],
}
|
pycityagent/workflow/prompt.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional, Union
|
2
2
|
import re
|
3
3
|
|
4
4
|
|
@@ -10,7 +10,7 @@ class FormatPrompt:
|
|
10
10
|
Attributes:
|
11
11
|
template (str): The template string containing placeholders.
|
12
12
|
system_prompt (Optional[str]): An optional system prompt to add to the dialog.
|
13
|
-
variables (
|
13
|
+
variables (list[str]): A list of variable names extracted from the template.
|
14
14
|
formatted_string (str): The formatted string derived from the template and provided variables.
|
15
15
|
"""
|
16
16
|
|
@@ -27,12 +27,12 @@ class FormatPrompt:
|
|
27
27
|
self.variables = self._extract_variables()
|
28
28
|
self.formatted_string = "" # To store the formatted string
|
29
29
|
|
30
|
-
def _extract_variables(self) -> list[str]:
    """
    Collects the placeholder names that appear in the template.

    A placeholder is a ``{name}`` group whose name is made only of word
    characters (letters, digits, underscore).

    Returns:
        list[str]: Placeholder names in order of appearance, duplicates kept.
    """
    placeholder = re.compile(r"\{(\w+)\}")
    return placeholder.findall(self.template)
|
38
38
|
|
@@ -51,12 +51,12 @@ class FormatPrompt:
|
|
51
51
|
) # Store the formatted string
|
52
52
|
return self.formatted_string
|
53
53
|
|
54
|
-
def to_dialog(self) ->
|
54
|
+
def to_dialog(self) -> list[dict[str, str]]:
|
55
55
|
"""
|
56
56
|
Converts the formatted prompt and optional system prompt into a dialog format.
|
57
57
|
|
58
58
|
Returns:
|
59
|
-
|
59
|
+
list[dict[str, str]]: A list representing the dialog with roles and content.
|
60
60
|
"""
|
61
61
|
dialog = []
|
62
62
|
if self.system_prompt:
|
pycityagent/workflow/tool.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
import time
|
2
|
-
from
|
2
|
+
from collections import defaultdict
|
3
|
+
from collections.abc import Callable, Sequence
|
4
|
+
from typing import Any, Optional, Union
|
3
5
|
|
4
6
|
from mlflow.entities import Metric
|
5
7
|
|
@@ -76,7 +78,7 @@ class SencePOI(Tool):
|
|
76
78
|
Attributes:
|
77
79
|
radius (int): The radius within which to search for POIs.
|
78
80
|
category_prefix (str): The prefix for the categories of POIs to consider.
|
79
|
-
variables (
|
81
|
+
variables (list[str]): A list of variables relevant to the tool's operation.
|
80
82
|
|
81
83
|
Args:
|
82
84
|
radius (int, optional): The circular search radius. Defaults to 100.
|
@@ -190,33 +192,38 @@ class ResetAgentPosition(Tool):
|
|
190
192
|
class ExportMlflowMetrics(Tool):
|
191
193
|
def __init__(self, log_batch_size: int = 100) -> None:
    """Create an exporter that buffers MLflow metrics grouped by metric key.

    Args:
        log_batch_size: Batch size ``__call__`` uses when uploading buffered
            metrics.
    """
    # TODO: support other log types
    per_key_cache: dict[str, list[Metric]] = defaultdict(list)
    self.metric_log_cache = per_key_cache
    self._log_batch_size = log_batch_size
|
195
197
|
|
196
198
|
async def __call__(
    self,
    metric: Union[Sequence[Union[Metric, dict]], Union[Metric, dict]],
    clear_cache: bool = False,
):
    """Buffer metrics per key and upload every complete batch to MLflow.

    Args:
        metric: A ``Metric``, a metric dict, or a sequence of either. A dict
            needs ``"key"``, ``"value"`` and ``"step"``; ``"timestamp"``
            defaults to the current time in milliseconds.
        clear_cache: If True, ``self._clear_cache()`` is awaited afterwards.
    """
    agent = self.agent
    # A non-positive batch size would never shrink the buffer; clamp to 1.
    batch_size = max(1, self._log_batch_size)
    if not isinstance(metric, Sequence):
        metric = [metric]
    for _metric in metric:
        if isinstance(_metric, Metric):
            item = _metric
            metric_key = item.key
        else:
            item = Metric(
                key=_metric["key"],
                value=_metric["value"],
                timestamp=_metric.get("timestamp", int(1000 * time.time())),
                step=_metric["step"],
            )
            metric_key = _metric["key"]
        self.metric_log_cache[metric_key].append(item)
    # Snapshot the buffers: we await inside the loop, and a concurrent call
    # could add a new key and break iteration over the live dict.
    for _cache in list(self.metric_log_cache.values()):
        # Flush every full batch, not just one per call.
        while len(_cache) >= batch_size:
            await agent.mlflow_client.log_batch(
                metrics=_cache[:batch_size],
            )
            # Trim the stored list in place. The old code only rebound a local
            # name, so uploaded metrics stayed cached and were re-sent every call.
            del _cache[:batch_size]
    if clear_cache:
        await self._clear_cache()
|
222
229
|
|
@@ -225,8 +232,9 @@ class ExportMlflowMetrics(Tool):
|
|
225
232
|
):
|
226
233
|
agent = self.agent
|
227
234
|
client = agent.mlflow_client
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
235
|
+
for metric_key, _cache in self.metric_log_cache.items():
|
236
|
+
if len(_cache) > 0:
|
237
|
+
await client.log_batch(
|
238
|
+
metrics=_cache,
|
239
|
+
)
|
240
|
+
_cache = []
|
pycityagent/workflow/trigger.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import asyncio
|
2
|
-
from typing import
|
2
|
+
from typing import Optional
|
3
3
|
import socket
|
4
4
|
from ..memory import Memory
|
5
5
|
from ..environment import Simulator
|
@@ -11,7 +11,7 @@ class EventTrigger:
|
|
11
11
|
"""Base class for event triggers that wait for specific conditions to be met."""
|
12
12
|
|
13
13
|
# 定义该trigger需要的组件类型
|
14
|
-
required_components:
|
14
|
+
required_components: list[type] = []
|
15
15
|
|
16
16
|
def __init__(self, block=None):
|
17
17
|
self.block = block
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: pycityagent
|
3
|
-
Version: 2.0.
|
3
|
+
Version: 2.0.0a21
|
4
4
|
Summary: LLM-based城市环境agent构建库
|
5
5
|
License: MIT
|
6
6
|
Author: Yuwei Yan
|
@@ -34,8 +34,9 @@ Requires-Dist: openai (>=1.58.1,<2.0.0)
|
|
34
34
|
Requires-Dist: pandavro (>=1.8.0,<2.0.0)
|
35
35
|
Requires-Dist: poetry (>=1.2.2)
|
36
36
|
Requires-Dist: protobuf (<=4.24.0)
|
37
|
+
Requires-Dist: psycopg[binary] (>=3.2.3,<4.0.0)
|
37
38
|
Requires-Dist: pycitydata (==1.0.0)
|
38
|
-
Requires-Dist: pycityproto (>=2.1.
|
39
|
+
Requires-Dist: pycityproto (>=2.1.5,<3.0.0)
|
39
40
|
Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
|
40
41
|
Requires-Dist: ray (>=2.40.0,<3.0.0)
|
41
42
|
Requires-Dist: sidecar (==0.7.0)
|