pycityagent 2.0.0a20__py3-none-any.whl → 2.0.0a21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +165 -62
- pycityagent/simulation/__init__.py +2 -1
- pycityagent/simulation/agentgroup.py +129 -3
- pycityagent/simulation/simulation.py +52 -27
- pycityagent/simulation/storage/pg.py +139 -0
- pycityagent/utils/pg_query.py +80 -0
- pycityagent/workflow/tool.py +32 -24
- {pycityagent-2.0.0a20.dist-info → pycityagent-2.0.0a21.dist-info}/METADATA +1 -1
- {pycityagent-2.0.0a20.dist-info → pycityagent-2.0.0a21.dist-info}/RECORD +10 -8
- {pycityagent-2.0.0a20.dist-info → pycityagent-2.0.0a21.dist-info}/WHEEL +0 -0
pycityagent/agent.py
CHANGED
@@ -1,17 +1,19 @@
 """Agent template class and its definitions"""
 
 import asyncio
+import json
 import logging
 import random
 import uuid
 from abc import ABC, abstractmethod
 from copy import deepcopy
-from datetime import datetime
+from datetime import datetime, timezone
 from enum import Enum
 from typing import Any, Optional
 from uuid import UUID
 
 import fastavro
+import ray
 from mosstool.util.format_converter import dict2pb
 from pycityproto.city.person.v2 import person_pb2 as person_pb2
 
@@ -56,6 +58,7 @@ class Agent(ABC):
         mlflow_client: Optional[MlflowClient] = None,
         memory: Optional[Memory] = None,
         avro_file: Optional[dict[str, str]] = None,
+        copy_writer: Optional[ray.ObjectRef] = None,
     ) -> None:
         """
         Initialize the Agent.
@@ -70,6 +73,7 @@ class Agent(ABC):
             mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
             memory (Memory, optional): The memory of the agent. Defaults to None.
             avro_file (dict[str, str], optional): The avro file of the agent. Defaults to None.
+            copy_writer (ray.ObjectRef, optional): The copy writer of the agent. Defaults to None.
         """
         self._name = name
         self._type = type
@@ -88,6 +92,8 @@ class Agent(ABC):
         self._interview_history: list[dict] = []  # stores the interview history
         self._person_template = PersonService.default_dict_person()
         self._avro_file = avro_file
+        self._pgsql_writer = copy_writer
+        self._last_asyncio_pg_task = None  # hide the SQL-write IO behind compute tasks
 
     def __getstate__(self):
         state = self.__dict__.copy()
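The `_last_asyncio_pg_task` slot added here implements a simple latency-hiding pattern: at most one PostgreSQL write is in flight at a time, and it is awaited only right before the next write is launched, so the write IO overlaps with the computation done in between. A minimal sketch of the pattern; the `write_rows` coroutine and the row payloads are stand-ins, not part of the package:

```python
import asyncio


async def write_rows(rows: list[tuple]) -> None:
    await asyncio.sleep(0.1)  # stand-in for the real PostgreSQL COPY


async def main() -> None:
    last_task = None  # mirrors self._last_asyncio_pg_task
    for step in range(3):
        rows = [(step, "payload")]  # stand-in for rows computed this step
        if last_task is not None:
            await last_task  # the previous write finished while this step computed
        last_task = asyncio.create_task(write_rows(rows))
    if last_task is not None:
        await last_task  # drain the final in-flight write


asyncio.run(main())
```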
@@ -143,6 +149,12 @@ class Agent(ABC):
         """
         self._avro_file = avro_file
 
+    def set_pgsql_writer(self, pgsql_writer: ray.ObjectRef):
+        """
+        Set the PostgreSQL copy writer of the agent.
+        """
+        self._pgsql_writer = pgsql_writer
+
     @property
     def uuid(self):
         """The Agent's UUID"""
@@ -198,6 +210,15 @@ class Agent(ABC):
         )
         return self._simulator
 
+    @property
+    def copy_writer(self):
+        """Pg Copy Writer"""
+        if self._pgsql_writer is None:
+            raise RuntimeError(
+                f"Copy Writer access before assignment, please `set_pgsql_writer` first!"
+            )
+        return self._pgsql_writer
+
     async def generate_user_survey_response(self, survey: dict) -> str:
         """Generate a response (overridable)
         Generates a response to the survey based on the agent's memory and current state.
@@ -237,8 +258,8 @@ class Agent(ABC):
 
     async def _process_survey(self, survey: dict):
         survey_response = await self.generate_user_survey_response(survey)
-
-
+        _date_time = datetime.now(timezone.utc)
+        # Avro
         response_to_avro = [
             {
                 "id": self._uuid,
@@ -246,11 +267,41 @@ class Agent(ABC):
                 "t": await self.simulator.get_simulator_second_from_start_of_day(),
                 "survey_id": survey["id"],
                 "result": survey_response,
-                "created_at": int(
+                "created_at": int(_date_time.timestamp() * 1000),
             }
         ]
-
-
+        if self._avro_file is not None:
+            with open(self._avro_file["survey"], "a+b") as f:
+                fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
+        # Pg
+        if self._pgsql_writer is not None:
+            if self._last_asyncio_pg_task is not None:
+                await self._last_asyncio_pg_task
+            _keys = [
+                "id",
+                "day",
+                "t",
+                "survey_id",
+                "result",
+            ]
+            _data_tuples: list[tuple] = []
+            # str to json
+            for _dict in response_to_avro:
+                res = _dict["result"]
+                _dict["result"] = json.dumps(
+                    {
+                        "result": res,
+                    }
+                )
+                _data_list = [_dict[k] for k in _keys]
+                # created_at
+                _data_list.append(_date_time)
+                _data_tuples.append(tuple(_data_list))
+            self._last_asyncio_pg_task = (
+                self._pgsql_writer.async_write_survey.remote(  # type:ignore
+                    _data_tuples
+                )
+            )
 
     async def generate_user_chat_response(self, question: str) -> str:
         """Generate a response (overridable)
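Note that `_process_survey` wraps the plain-text answer in a JSON object before the write, since the `result` column created in `pg_query.py` is JSONB. A hedged sketch of one row as it ends up in `_data_tuples`; the field values are hypothetical:

```python
import json
from datetime import datetime, timezone

_date_time = datetime.now(timezone.utc)
record = {
    "id": "0b8d...",                    # hypothetical agent uuid
    "day": 1,
    "t": 3600.0,
    "survey_id": "5e2f...",             # hypothetical survey uuid
    "result": "I would take the bus.",  # hypothetical free-text answer
}
record["result"] = json.dumps({"result": record["result"]})  # str -> JSON for the JSONB column
row = tuple([record[k] for k in ["id", "day", "t", "survey_id", "result"]] + [_date_time])
# row matches what async_write_survey receives: (id, day, t, survey_id, result, created_at)
```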
@@ -290,34 +341,52 @@ class Agent(ABC):
         return response  # type:ignore
 
     async def _process_interview(self, payload: dict):
-
-
-
-
-
-
-
-
-
-
-
+        pg_list: list[tuple[dict, datetime]] = []
+        auros: list[dict] = []
+        _date_time = datetime.now(timezone.utc)
+        _interview_dict = {
+            "id": self._uuid,
+            "day": await self.simulator.get_simulator_day(),
+            "t": await self.simulator.get_simulator_second_from_start_of_day(),
+            "type": 2,
+            "speaker": "user",
+            "content": payload["content"],
+            "created_at": int(_date_time.timestamp() * 1000),
+        }
+        auros.append(_interview_dict)
+        pg_list.append((_interview_dict, _date_time))
         question = payload["content"]
         response = await self.generate_user_chat_response(question)
-
-
-
-
-
-
-
-
-
-
-        )
-
-
-
+        _date_time = datetime.now(timezone.utc)
+        _interview_dict = {
+            "id": self._uuid,
+            "day": await self.simulator.get_simulator_day(),
+            "t": await self.simulator.get_simulator_second_from_start_of_day(),
+            "type": 2,
+            "speaker": "",
+            "content": response,
+            "created_at": int(_date_time.timestamp() * 1000),
+        }
+        auros.append(_interview_dict)
+        pg_list.append((_interview_dict, _date_time))
+        # Avro
+        if self._avro_file is not None:
+            with open(self._avro_file["dialog"], "a+b") as f:
+                fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
+        # Pg
+        if self._pgsql_writer is not None:
+            if self._last_asyncio_pg_task is not None:
+                await self._last_asyncio_pg_task
+            _keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
+            _data = [
+                tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
+                for _dict, _date_time in pg_list
+            ]
+            self._last_asyncio_pg_task = (
+                self._pgsql_writer.async_write_dialog.remote(  # type:ignore
+                    _data
+                )
+            )
 
     async def process_agent_chat_response(self, payload: dict) -> str:
         resp = f"Agent {self._uuid} received agent chat response: {payload}"
@@ -325,22 +394,39 @@ class Agent(ABC):
         return resp
 
     async def _process_agent_chat(self, payload: dict):
-
-
-
-
-
-
-
-
-
-
-
+        pg_list: list[tuple[dict, datetime]] = []
+        auros: list[dict] = []
+        _date_time = datetime.now(timezone.utc)
+        _chat_dict = {
+            "id": self._uuid,
+            "day": payload["day"],
+            "t": payload["t"],
+            "type": 1,
+            "speaker": payload["from"],
+            "content": payload["content"],
+            "created_at": int(_date_time.timestamp() * 1000),
+        }
+        auros.append(_chat_dict)
+        pg_list.append((_chat_dict, _date_time))
         asyncio.create_task(self.process_agent_chat_response(payload))
-
-
-
-
+        # Avro
+        if self._avro_file is not None:
+            with open(self._avro_file["dialog"], "a+b") as f:
+                fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
+        # Pg
+        if self._pgsql_writer is not None:
+            if self._last_asyncio_pg_task is not None:
+                await self._last_asyncio_pg_task
+            _keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
+            _data = [
+                tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
+                for _dict, _date_time in pg_list
+            ]
+            self._last_asyncio_pg_task = (
+                self._pgsql_writer.async_write_dialog.remote(  # type:ignore
+                    _data
+                )
+            )
 
     # Callback functions for MQTT message
     async def handle_agent_chat_message(self, payload: dict):
@@ -384,21 +470,38 @@ class Agent(ABC):
         "t": await self.simulator.get_simulator_second_from_start_of_day(),
         }
         await self._send_message(to_agent_uuid, payload, "agent-chat")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        pg_list: list[tuple[dict, datetime]] = []
+        auros: list[dict] = []
+        _date_time = datetime.now(timezone.utc)
+        _message_dict = {
+            "id": self._uuid,
+            "day": await self.simulator.get_simulator_day(),
+            "t": await self.simulator.get_simulator_second_from_start_of_day(),
+            "type": 1,
+            "speaker": self._uuid,
+            "content": content,
+            "created_at": int(datetime.now().timestamp() * 1000),
+        }
+        auros.append(_message_dict)
+        pg_list.append((_message_dict, _date_time))
+        # Avro
+        if self._avro_file is not None:
+            with open(self._avro_file["dialog"], "a+b") as f:
+                fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
+        # Pg
+        if self._pgsql_writer is not None:
+            if self._last_asyncio_pg_task is not None:
+                await self._last_asyncio_pg_task
+            _keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
+            _data = [
+                tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
+                for _dict, _date_time in pg_list
+            ]
+            self._last_asyncio_pg_task = (
+                self._pgsql_writer.async_write_dialog.remote(  # type:ignore
+                    _data
+                )
+            )
 
     # Agent logic
     @abstractmethod
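All three dialog paths above (interview, incoming agent chat, outgoing message) funnel into the same seven-column tuple for `async_write_dialog`. A small sketch with hypothetical values, showing how the millisecond integer used for Avro is swapped for the `datetime` object that the TIMESTAMPTZ column expects:

```python
from datetime import datetime, timezone

_keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
# (dict, datetime) pairs as collected into pg_list; the values are hypothetical.
pg_list = [
    (
        {"id": "0b8d...", "day": 1, "t": 60.0, "type": 2, "speaker": "user",
         "content": "hello", "created_at": 1735000000000},
        datetime.now(timezone.utc),
    ),
]
_data = [
    tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
    for _dict, _date_time in pg_list
]
```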
pycityagent/simulation/agentgroup.py
CHANGED
@@ -3,7 +3,7 @@ import json
 import logging
 import time
 import uuid
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any
 from uuid import UUID
@@ -35,7 +35,7 @@ class AgentGroup:
         enable_avro: bool,
         avro_path: Path,
         enable_pgsql: bool,
-
+        pgsql_writer: ray.ObjectRef,
         mlflow_run_id: str,
         logging_level: int,
     ):
@@ -45,6 +45,7 @@ class AgentGroup:
         self.config = config
         self.exp_id = exp_id
         self.enable_avro = enable_avro
+        self.enable_pgsql = enable_pgsql
         if enable_avro:
             self.avro_path = avro_path / f"{self._uuid}"
             self.avro_path.mkdir(parents=True, exist_ok=True)
@@ -54,6 +55,8 @@ class AgentGroup:
             "status": self.avro_path / f"status.avro",
             "survey": self.avro_path / f"survey.avro",
         }
+        if self.enable_pgsql:
+            pass
 
         self.messager = Messager(
             hostname=config["simulator_request"]["mqtt"]["server"],
@@ -61,6 +64,8 @@ class AgentGroup:
             username=config["simulator_request"]["mqtt"].get("username", None),
             password=config["simulator_request"]["mqtt"].get("password", None),
         )
+        self._pgsql_writer = pgsql_writer
+        self._last_asyncio_pg_task = None  # hide the SQL-write IO behind compute tasks
         self.initialized = False
         self.id2agent = {}
         # Step:1 prepare LLM client
@@ -105,6 +110,8 @@ class AgentGroup:
             agent.set_messager(self.messager)
             if self.enable_avro:
                 agent.set_avro_file(self.avro_file)  # type: ignore
+            if self.enable_pgsql:
+                agent.set_pgsql_writer(self._pgsql_writer)
 
     async def init_agents(self):
         logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
@@ -161,6 +168,20 @@ class AgentGroup:
             with open(filename, "wb") as f:
                 surveys = []
                 fastavro.writer(f, SURVEY_SCHEMA, surveys)
+
+        if self.enable_pgsql:
+            if not issubclass(type(self.agents[0]), InstitutionAgent):
+                profiles: list[Any] = []
+                for agent in self.agents:
+                    profile = await agent.memory._profile.export()
+                    profile = profile[0]
+                    profile["id"] = agent._uuid
+                    profiles.append(
+                        (agent._uuid, profile.get("name", ""), json.dumps(profile))
+                    )
+                await self._pgsql_writer.async_write_profile.remote(  # type:ignore
+                    profiles
+                )
         self.initialized = True
         logger.debug(f"-----AgentGroup {self._uuid} initialized")
 
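`async_write_profile` receives one `(id, name, profile)` tuple per agent, with the whole exported profile serialized to JSON for the JSONB column. A hedged sketch with made-up values:

```python
import json

profile = {"name": "Alice", "age": 31}  # hypothetical exported profile
profile["id"] = "0b8d..."               # hypothetical agent uuid
row = (profile["id"], profile.get("name", ""), json.dumps(profile))
# rows like this are batched and shipped via async_write_profile.remote(...)
```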
@@ -218,11 +239,13 @@ class AgentGroup:
             await asyncio.sleep(0.5)
 
     async def save_status(self):
+        _statuses_time_list: list[tuple[dict, datetime]] = []
         if self.enable_avro:
             logger.debug(f"-----Saving status for group {self._uuid}")
             avros = []
             if not issubclass(type(self.agents[0]), InstitutionAgent):
                 for agent in self.agents:
+                    _date_time = datetime.now(timezone.utc)
                     position = await agent.memory.get("position")
                     lng = position["longlat_position"]["longitude"]
                     lat = position["longlat_position"]["latitude"]
@@ -248,13 +271,15 @@ class AgentGroup:
                         "tired": needs["tired"],
                         "safe": needs["safe"],
                         "social": needs["social"],
-                        "created_at": int(
+                        "created_at": int(_date_time.timestamp() * 1000),
                     }
                     avros.append(avro)
+                    _statuses_time_list.append((avro, _date_time))
                 with open(self.avro_file["status"], "a+b") as f:
                     fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
             else:
                 for agent in self.agents:
+                    _date_time = datetime.now(timezone.utc)
                     avro = {
                         "id": agent._uuid,
                         "day": await self.simulator.get_simulator_day(),
@@ -274,8 +299,109 @@ class AgentGroup:
                         "customers": await agent.memory.get("customers"),
                     }
                     avros.append(avro)
+                    _statuses_time_list.append((avro, _date_time))
                 with open(self.avro_file["status"], "a+b") as f:
                     fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
+        if self.enable_pgsql:
+            # data already acquired from Avro part
+            if len(_statuses_time_list) > 0:
+                for _status_dict, _date_time in _statuses_time_list:
+                    for key in ["lng", "lat", "parent_id"]:
+                        if key not in _status_dict:
+                            _status_dict[key] = -1
+                    for key in [
+                        "action",
+                    ]:
+                        if key not in _status_dict:
+                            _status_dict[key] = ""
+                    _status_dict["created_at"] = _date_time
+            else:
+                if not issubclass(type(self.agents[0]), InstitutionAgent):
+                    for agent in self.agents:
+                        _date_time = datetime.now(timezone.utc)
+                        position = await agent.memory.get("position")
+                        lng = position["longlat_position"]["longitude"]
+                        lat = position["longlat_position"]["latitude"]
+                        if "aoi_position" in position:
+                            parent_id = position["aoi_position"]["aoi_id"]
+                        elif "lane_position" in position:
+                            parent_id = position["lane_position"]["lane_id"]
+                        else:
+                            # BUG: needs to be handled
+                            parent_id = -1
+                        needs = await agent.memory.get("needs")
+                        action = await agent.memory.get("current_step")
+                        action = action["intention"]
+                        _status_dict = {
+                            "id": agent._uuid,
+                            "day": await self.simulator.get_simulator_day(),
+                            "t": await self.simulator.get_simulator_second_from_start_of_day(),
+                            "lng": lng,
+                            "lat": lat,
+                            "parent_id": parent_id,
+                            "action": action,
+                            "hungry": needs["hungry"],
+                            "tired": needs["tired"],
+                            "safe": needs["safe"],
+                            "social": needs["social"],
+                            "created_at": _date_time,
+                        }
+                        _statuses_time_list.append((_status_dict, _date_time))
+                else:
+                    for agent in self.agents:
+                        _date_time = datetime.now(timezone.utc)
+                        _status_dict = {
+                            "id": agent._uuid,
+                            "day": await self.simulator.get_simulator_day(),
+                            "t": await self.simulator.get_simulator_second_from_start_of_day(),
+                            "lng": -1,
+                            "lat": -1,
+                            "parent_id": -1,
+                            "action": "",
+                            "type": await agent.memory.get("type"),
+                            "nominal_gdp": await agent.memory.get("nominal_gdp"),
+                            "real_gdp": await agent.memory.get("real_gdp"),
+                            "unemployment": await agent.memory.get("unemployment"),
+                            "wages": await agent.memory.get("wages"),
+                            "prices": await agent.memory.get("prices"),
+                            "inventory": await agent.memory.get("inventory"),
+                            "price": await agent.memory.get("price"),
+                            "interest_rate": await agent.memory.get("interest_rate"),
+                            "bracket_cutoffs": await agent.memory.get(
+                                "bracket_cutoffs"
+                            ),
+                            "bracket_rates": await agent.memory.get("bracket_rates"),
+                            "employees": await agent.memory.get("employees"),
+                            "customers": await agent.memory.get("customers"),
+                            "created_at": _date_time,
+                        }
+                        _statuses_time_list.append((_status_dict, _date_time))
+            to_update_statues: list[tuple] = []
+            for _status_dict, _ in _statuses_time_list:
+                BASIC_KEYS = [
+                    "id",
+                    "day",
+                    "t",
+                    "lng",
+                    "lat",
+                    "parent_id",
+                    "action",
+                    "created_at",
+                ]
+                _data = [_status_dict[k] for k in BASIC_KEYS if k != "created_at"]
+                _other_dict = json.dumps(
+                    {k: v for k, v in _status_dict.items() if k not in BASIC_KEYS}
+                )
+                _data.append(_other_dict)
+                _data.append(_status_dict["created_at"])
+                to_update_statues.append(tuple(_data))
+            if self._last_asyncio_pg_task is not None:
+                await self._last_asyncio_pg_task
+            self._last_asyncio_pg_task = (
+                self._pgsql_writer.async_write_status.remote(  # type:ignore
+                    to_update_statues
+                )
+            )
 
     async def step(self):
         if not self.initialized:
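The PostgreSQL status row keeps only the keys shared by both agent kinds as dedicated columns; everything else (needs for citizens, economic indicators for institutions) is folded into the JSONB `status` column. A sketch of the packing with hypothetical values:

```python
import json
from datetime import datetime, timezone

BASIC_KEYS = ["id", "day", "t", "lng", "lat", "parent_id", "action", "created_at"]
_status_dict = {  # hypothetical citizen status
    "id": "0b8d...", "day": 1, "t": 120.0, "lng": 116.39, "lat": 39.91,
    "parent_id": 500000001, "action": "walking",
    "hungry": 0.2, "tired": 0.1, "safe": 0.9, "social": 0.5,
    "created_at": datetime.now(timezone.utc),
}
_data = [_status_dict[k] for k in BASIC_KEYS if k != "created_at"]
_data.append(json.dumps({k: v for k, v in _status_dict.items() if k not in BASIC_KEYS}))
_data.append(_status_dict["created_at"])
row = tuple(_data)  # (id, day, t, lng, lat, parent_id, action, status, created_at)
```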
pycityagent/simulation/simulation.py
CHANGED
@@ -23,6 +23,7 @@ from ..message.messager import Messager
 from ..metrics import init_mlflow_connection
 from ..survey import Survey
 from .agentgroup import AgentGroup
+from .storage.pg import PgWriter, create_pg_tables
 
 logger = logging.getLogger("pycityagent")
 
@@ -63,6 +64,7 @@ class AgentSimulation:
         self._user_survey_topics: dict[uuid.UUID, str] = {}
         self._user_interview_topics: dict[uuid.UUID, str] = {}
         self._loop = asyncio.get_event_loop()
+        # self._last_asyncio_pg_task = None  # hide the SQL-write IO behind compute tasks
 
         self._messager = Messager(
             hostname=config["simulator_request"]["mqtt"]["server"],
@@ -89,22 +91,13 @@ class AgentSimulation:
         self._enable_pgsql = _pgsql_config.get("enabled", False)
         if not self._enable_pgsql:
             logger.warning("PostgreSQL is not enabled, NO POSTGRESQL DATABASE STORAGE")
-            self.
+            self._pgsql_dsn = ""
         else:
-            self.
-            self._pgsql_port = _pgsql_config["port"]
-            self._pgsql_database = _pgsql_config["database"]
-            self._pgsql_user = _pgsql_config.get("user", None)
-            self._pgsql_password = _pgsql_config.get("password", None)
-            self._pgsql_args: tuple[str, str, str, str, str] = (
-                self._pgsql_host,
-                self._pgsql_port,
-                self._pgsql_database,
-                self._pgsql_user,
-                self._pgsql_password,
-            )
+            self._pgsql_dsn = _pgsql_config["data_source_name"]
 
         # attributes for the experiment information
+        self._exp_created_time = datetime.now(timezone.utc)
+        self._exp_updated_time = datetime.now(timezone.utc)
         self._exp_info = {
             "id": self.exp_id,
             "name": exp_name,
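Configuration-wise, the five separate connection fields are replaced by a single libpq-style DSN. A sketch of the relevant config fragment; only `enabled` and `data_source_name` are confirmed by this diff, the surrounding keys and the DSN value are hypothetical:

```python
config = {
    "storage": {  # hypothetical outer key
        "pgsql": {
            "enabled": True,
            "data_source_name": "postgresql://user:password@localhost:5432/socialcity",
        },
    },
}
```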
@@ -114,7 +107,8 @@ class AgentSimulation:
             "cur_t": 0.0,
             "config": json.dumps(config),
             "error": "",
-            "created_at":
+            "created_at": self._exp_created_time.isoformat(),
+            "updated_at": self._exp_updated_time.isoformat(),
         }
 
         # create an async task to save the experiment information
@@ -168,7 +162,7 @@ class AgentSimulation:
         enable_avro: bool,
         avro_path: Path,
         enable_pgsql: bool,
-
+        pgsql_writer: ray.ObjectRef,
         mlflow_run_id: str = None,  # type: ignore
         logging_level: int = logging.WARNING,
     ):
@@ -181,7 +175,7 @@ class AgentSimulation:
             enable_avro,
             avro_path,
             enable_pgsql,
-
+            pgsql_writer,
             mlflow_run_id,
             logging_level,
         )
@@ -191,6 +185,7 @@ class AgentSimulation:
         self,
         agent_count: Union[int, list[int]],
         group_size: int = 1000,
+        pg_sql_writers: int = 32,
         memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
     ) -> None:
         """Initialize the agents
@@ -251,8 +246,8 @@ class AgentSimulation:
                     memory=memory,
                 )
 
-                self._agents[agent._uuid] = agent
-                self._agent_uuids.append(agent._uuid)
+                self._agents[agent._uuid] = agent  # type:ignore
+                self._agent_uuids.append(agent._uuid)  # type:ignore
 
             # number of groups needed, rounded up to handle a final partial group
             num_group = (agent_count_i + group_size - 1) // group_size
@@ -282,9 +277,23 @@ class AgentSimulation:
             )
         else:
             mlflow_run_id = None
+        # create the tables
+        if self.enable_pgsql:
+            _num_workers = min(1, pg_sql_writers)
+            create_pg_tables(
+                exp_id=self.exp_id,
+                dsn=self._pgsql_dsn,
+            )
+            self._pgsql_writers = _workers = [
+                PgWriter.remote(self.exp_id, self._pgsql_dsn)
+                for _ in range(_num_workers)
+            ]
+        else:
+            _num_workers = 1
+            self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
         # collect the parameters for creating all groups
         creation_tasks = []
-        for group_name, agents in group_creation_params:
+        for i, (group_name, agents) in enumerate(group_creation_params):
             # directly create the async task
             group = AgentGroup.remote(
                 agents,
@@ -294,10 +303,8 @@ class AgentSimulation:
             self.enable_avro,
             self.avro_path,
             self.enable_pgsql,
-            #
-
-            None,
-            mlflow_run_id,
+            _workers[i % _num_workers],  # type:ignore
+            mlflow_run_id,  # type:ignore
             self.logging_level,
         )
         creation_tasks.append((group_name, group, agents))
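Groups are assigned writers round-robin via `_workers[i % _num_workers]`. Note that `_num_workers = min(1, pg_sql_writers)` caps the pool at a single writer actor, so the `pg_sql_writers: int = 32` default still yields one `PgWriter`. A toy sketch of the assignment with stand-in values:

```python
_num_workers = 2
_workers = [f"writer-{i}" for i in range(_num_workers)]  # stand-ins for PgWriter actors
group_creation_params = [("group-a", []), ("group-b", []), ("group-c", [])]
for i, (group_name, agents) in enumerate(group_creation_params):
    writer = _workers[i % _num_workers]  # group i gets writer i mod N
    print(group_name, "->", writer)      # group-a -> writer-0, group-b -> writer-1, ...
```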
@@ -469,11 +476,13 @@ class AgentSimulation:
         survey_dict = survey.to_dict()
         if agent_uuids is None:
             agent_uuids = self._agent_uuids
+        _date_time = datetime.now(timezone.utc)
         payload = {
             "from": "none",
             "survey_id": survey_dict["id"],
-            "timestamp": int(
+            "timestamp": int(_date_time.timestamp() * 1000),
             "data": survey_dict,
+            "_date_time": _date_time,
         }
         for uuid in agent_uuids:
             topic = self._user_survey_topics[uuid]
@@ -483,10 +492,12 @@ class AgentSimulation:
         self, content: str, agent_uuids: Union[uuid.UUID, list[uuid.UUID]]
     ):
         """Send an interview message"""
+        _date_time = datetime.now(timezone.utc)
         payload = {
             "from": "none",
             "content": content,
-            "timestamp": int(
+            "timestamp": int(_date_time.timestamp() * 1000),
+            "_date_time": _date_time,
         }
         if not isinstance(agent_uuids, Sequence):
             agent_uuids = [agent_uuids]
@@ -515,15 +526,29 @@ class AgentSimulation:
             logger.error(f"Failed to save experiment info to Avro: {str(e)}")
         try:
             if self.enable_pgsql:
-                #
-
+                worker: ray.ObjectRef = self._pgsql_writers[0]  # type:ignore
+                # if self._last_asyncio_pg_task is not None:
+                #     await self._last_asyncio_pg_task
+                # self._last_asyncio_pg_task = (
+                #     worker.async_update_exp_info.remote(  # type:ignore
+                #         pg_exp_info
+                #     )
+                # )
+                pg_exp_info = {k: v for k, v in self._exp_info.items()}
+                pg_exp_info["created_at"] = self._exp_created_time
+                pg_exp_info["updated_at"] = self._exp_updated_time
+                await worker.async_update_exp_info.remote(  # type:ignore
+                    pg_exp_info
+                )
         except Exception as e:
             logger.error(f"Failed to save experiment info to PostgreSQL: {str(e)}")
 
     async def _update_exp_status(self, status: int, error: str = "") -> None:
+        self._exp_updated_time = datetime.now(timezone.utc)
         """Update the experiment status and save it"""
         self._exp_info["status"] = status
         self._exp_info["error"] = error
+        self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
         await self._save_exp_info()
 
     async def _monitor_exp_status(self, stop_event: asyncio.Event):
pycityagent/simulation/storage/pg.py
ADDED
@@ -0,0 +1,139 @@
+import asyncio
+from collections import defaultdict
+from typing import Any
+
+import psycopg
+import psycopg.sql
+import ray
+from psycopg.rows import dict_row
+
+from ...utils.decorators import lock_decorator
+from ...utils.pg_query import PGSQL_DICT
+
+
+def create_pg_tables(exp_id: str, dsn: str):
+    for table_type, exec_strs in PGSQL_DICT.items():
+        table_name = f"socialcity_{exp_id.replace('-', '_')}_{table_type}"
+        # # debug str
+        # for _str in [f"DROP TABLE IF EXISTS {table_name}"] + [
+        #     _exec_str.format(table_name=table_name) for _exec_str in exec_strs
+        # ]:
+        #     print(_str)
+        with psycopg.connect(dsn) as conn:
+            with conn.cursor() as cur:
+                # delete the table
+                cur.execute(f"DROP TABLE IF EXISTS {table_name}")  # type:ignore
+                conn.commit()
+                # create the table
+                for _exec_str in exec_strs:
+                    cur.execute(_exec_str.format(table_name=table_name))
+                conn.commit()
+
+
+@ray.remote
+class PgWriter:
+    def __init__(self, exp_id: str, dsn: str):
+        self.exp_id = exp_id
+        self._dsn = dsn
+        # self._lock = asyncio.Lock()
+
+    # @lock_decorator
+    async def async_write_dialog(self, rows: list[tuple]):
+        _tuple_types = [str, int, float, int, str, str, str, None]
+        table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_dialog"
+        # insert the data into the database
+        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
+            copy_sql = psycopg.sql.SQL(
+                "COPY {} (id, day, t, type, speaker, content, created_at) FROM STDIN"
+            ).format(psycopg.sql.Identifier(table_name))
+            async with aconn.cursor() as cur:
+                async with cur.copy(copy_sql) as copy:
+                    for row in rows:
+                        _row = [
+                            _type(r) if _type is not None else r
+                            for (_type, r) in zip(_tuple_types, row)
+                        ]
+                        await copy.write_row(_row)
+
+    # @lock_decorator
+    async def async_write_status(self, rows: list[tuple]):
+        _tuple_types = [str, int, float, float, float, int, str, str, None]
+        table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_status"
+        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
+            copy_sql = psycopg.sql.SQL(
+                "COPY {} (id, day, t, lng, lat, parent_id, action, status, created_at) FROM STDIN"
+            ).format(psycopg.sql.Identifier(table_name))
+            async with aconn.cursor() as cur:
+                async with cur.copy(copy_sql) as copy:
+                    for row in rows:
+                        _row = [
+                            _type(r) if _type is not None else r
+                            for (_type, r) in zip(_tuple_types, row)
+                        ]
+                        await copy.write_row(_row)
+
+    # @lock_decorator
+    async def async_write_profile(self, rows: list[tuple]):
+        _tuple_types = [str, str, str]
+        table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_profile"
+        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
+            copy_sql = psycopg.sql.SQL("COPY {} (id, name, profile) FROM STDIN").format(
+                psycopg.sql.Identifier(table_name)
+            )
+            async with aconn.cursor() as cur:
+                async with cur.copy(copy_sql) as copy:
+                    for row in rows:
+                        _row = [
+                            _type(r) if _type is not None else r
+                            for (_type, r) in zip(_tuple_types, row)
+                        ]
+                        await copy.write_row(_row)
+
+    # @lock_decorator
+    async def async_write_survey(self, rows: list[tuple]):
+        _tuple_types = [str, int, float, str, str, None]
+        table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_survey"
+        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
+            copy_sql = psycopg.sql.SQL(
+                "COPY {} (id, day, t, survey_id, result, created_at) FROM STDIN"
+            ).format(psycopg.sql.Identifier(table_name))
+            async with aconn.cursor() as cur:
+                async with cur.copy(copy_sql) as copy:
+                    for row in rows:
+                        _row = [
+                            _type(r) if _type is not None else r
+                            for (_type, r) in zip(_tuple_types, row)
+                        ]
+                        await copy.write_row(_row)
+
+    # @lock_decorator
+    async def async_update_exp_info(self, exp_info: dict[str, Any]):
+        # timestamps are passed through without type conversion
+        TO_UPDATE_EXP_INFO_KEYS_AND_TYPES = [
+            ("id", str),
+            ("name", str),
+            ("num_day", int),
+            ("status", int),
+            ("cur_day", int),
+            ("cur_t", float),
+            ("config", str),
+            ("error", str),
+            ("created_at", None),
+            ("updated_at", None),
+        ]
+        table_name = f"socialcity_{self.exp_id.replace('-', '_')}_experiment"
+        async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
+            async with aconn.cursor(row_factory=dict_row) as cur:
+                # UPDATE
+                columns = ", ".join(
+                    f"{key} = %s" for key, _ in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
+                )
+                update_sql = psycopg.sql.SQL(
+                    f"UPDATE {{}} SET {columns} WHERE id = %s"  # type:ignore
+                ).format(psycopg.sql.Identifier(table_name))
+                params = [
+                    _type(exp_info[key]) if _type is not None else exp_info[key]
+                    for key, _type in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
+                ] + [self.exp_id]
+                await cur.execute(update_sql, params)
+                await aconn.commit()
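A hedged usage sketch for the new storage module: create the per-experiment tables, start a writer actor, and push one batch. Note that `create_pg_tables` drops any existing tables for the experiment id before recreating them. The exp_id, DSN, and row values below are hypothetical; awaiting a Ray `ObjectRef` is valid inside a coroutine:

```python
import asyncio

import ray

from pycityagent.simulation.storage.pg import PgWriter, create_pg_tables


async def main() -> None:
    ray.init()
    exp_id = "0b8d9f2e-0000-0000-0000-000000000000"               # hypothetical
    dsn = "postgresql://user:password@localhost:5432/socialcity"  # hypothetical
    create_pg_tables(exp_id=exp_id, dsn=dsn)  # drops and recreates the experiment tables
    writer = PgWriter.remote(exp_id, dsn)
    rows = [("0b8d...", "Alice", '{"name": "Alice"}')]  # (id, name, profile) per agent
    await writer.async_write_profile.remote(rows)


asyncio.run(main())
```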
pycityagent/utils/pg_query.py
ADDED
@@ -0,0 +1,80 @@
+from typing import Any
+
+PGSQL_DICT: dict[str, list[Any]] = {
+    # Experiment
+    "experiment": [
+        """
+    CREATE TABLE IF NOT EXISTS {table_name} (
+        id UUID PRIMARY KEY,
+        name TEXT,
+        num_day INT4,
+        status INT4,
+        cur_day INT4,
+        cur_t FLOAT,
+        config TEXT,
+        error TEXT,
+        created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+        updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+    )
+""",
+    ],
+    # Agent Profile
+    "agent_profile": [
+        """
+    CREATE TABLE IF NOT EXISTS {table_name} (
+        id UUID PRIMARY KEY,
+        name TEXT,
+        profile JSONB
+    )
+""",
+    ],
+    # Agent Dialog
+    "agent_dialog": [
+        """
+    CREATE TABLE IF NOT EXISTS {table_name} (
+        id UUID,
+        day INT4,
+        t FLOAT,
+        type INT4,
+        speaker TEXT,
+        content TEXT,
+        created_at TIMESTAMPTZ
+    )
+""",
+        "CREATE INDEX {table_name}_id_idx ON {table_name} (id)",
+        "CREATE INDEX {table_name}_day_t_idx ON {table_name} (day,t)",
+    ],
+    # Agent Status
+    "agent_status": [
+        """
+    CREATE TABLE IF NOT EXISTS {table_name} (
+        id UUID,
+        day INT4,
+        t FLOAT,
+        lng DOUBLE PRECISION,
+        lat DOUBLE PRECISION,
+        parent_id INT4,
+        action TEXT,
+        status JSONB,
+        created_at TIMESTAMPTZ
+    )
+""",
+        "CREATE INDEX {table_name}_id_idx ON {table_name} (id)",
+        "CREATE INDEX {table_name}_day_t_idx ON {table_name} (day,t)",
+    ],
+    # Agent Survey
+    "agent_survey": [
+        """
+    CREATE TABLE IF NOT EXISTS {table_name} (
+        id UUID,
+        day INT4,
+        t FLOAT,
+        survey_id UUID,
+        result JSONB,
+        created_at TIMESTAMPTZ
+    )
+""",
+        "CREATE INDEX {table_name}_id_idx ON {table_name} (id)",
+        "CREATE INDEX {table_name}_day_t_idx ON {table_name} (day,t)",
+    ],
+}
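Every DDL template leaves `{table_name}` to be filled in by `create_pg_tables`, which namespaces tables per experiment. A small sketch of the naming scheme with a hypothetical experiment id:

```python
exp_id = "6f9d4a2e-1c3b-4f5a-9e7d-2b8c0a1d3e5f"  # hypothetical
table_type = "agent_status"
table_name = f"socialcity_{exp_id.replace('-', '_')}_{table_type}"
# -> socialcity_6f9d4a2e_1c3b_4f5a_9e7d_2b8c0a1d3e5f_agent_status
```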
pycityagent/workflow/tool.py
CHANGED
@@ -1,6 +1,8 @@
 import time
+from collections import defaultdict
+from collections.abc import Callable, Sequence
 from typing import Any, Optional, Union
-
+
 from mlflow.entities import Metric
 
 from ..agent import Agent
@@ -190,33 +192,38 @@ class ExportMlflowMetrics(Tool):
 class ExportMlflowMetrics(Tool):
     def __init__(self, log_batch_size: int = 100) -> None:
         self._log_batch_size = log_batch_size
-        # TODO:support other log types
-        self.metric_log_cache: list[Metric] =
+        # TODO: support other log types
+        self.metric_log_cache: dict[str, list[Metric]] = defaultdict(list)
 
     async def __call__(
         self,
-        metric: Union[Metric, dict],
+        metric: Union[Sequence[Union[Metric, dict]], Union[Metric, dict]],
         clear_cache: bool = False,
     ):
         agent = self.agent
         batch_size = self._log_batch_size
-        if
-
-
-
-
-
-        else:
-            if isinstance(metric, Metric):
-                self.metric_log_cache.append(metric)
+        if not isinstance(metric, Sequence):
+            metric = [metric]
+        for _metric in metric:
+            if isinstance(_metric, Metric):
+                item = _metric
+                metric_key = item.key
             else:
-
-                key=
-                value=
-                timestamp=
-                step=
+                item = Metric(
+                    key=_metric["key"],
+                    value=_metric["value"],
+                    timestamp=_metric.get("timestamp", int(1000 * time.time())),
+                    step=_metric["step"],
                 )
-
+                metric_key = _metric["key"]
+            self.metric_log_cache[metric_key].append(item)
+        for metric_key, _cache in self.metric_log_cache.items():
+            if len(_cache) > batch_size:
+                client = agent.mlflow_client
+                await client.log_batch(
+                    metrics=_cache[:batch_size],
+                )
+                _cache = _cache[batch_size:]
         if clear_cache:
             await self._clear_cache()
 
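The metric cache is now keyed by metric name, and each call may pass a single `Metric`/dict or a sequence of them; dict entries get a default timestamp of now. One caveat visible above: `_cache = _cache[batch_size:]` rebinds a local name and does not shrink the list stored in the dict. A hedged usage sketch, assuming the tool is already bound to an agent whose `mlflow_client` is configured:

```python
import time

from mlflow.entities import Metric

from pycityagent.workflow.tool import ExportMlflowMetrics


async def log_some_metrics(export: ExportMlflowMetrics) -> None:
    await export(
        [
            Metric(key="reward", value=1.0, timestamp=int(1000 * time.time()), step=1),
            {"key": "reward", "value": 0.5, "step": 2},  # timestamp defaults to now
        ]
    )
    await export([], clear_cache=True)  # flush whatever is still cached
```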
@@ -225,8 +232,9 @@ class ExportMlflowMetrics(Tool):
     ):
         agent = self.agent
         client = agent.mlflow_client
-
-
-
-
-
+        for metric_key, _cache in self.metric_log_cache.items():
+            if len(_cache) > 0:
+                await client.log_batch(
+                    metrics=_cache,
+                )
+                _cache = []
{pycityagent-2.0.0a20.dist-info → pycityagent-2.0.0a21.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 pycityagent/__init__.py,sha256=EDxt3Su3lH1IMh9suNw7GeGL7UrXeWiZTw5KWNznDzc,637
-pycityagent/agent.py,sha256=
+pycityagent/agent.py,sha256=FMplKAgcz2Exkl8EiE2RwQ0Hd5U08krRZ3CFFLoF_4g,28450
 pycityagent/economy/__init__.py,sha256=aonY4WHnx-6EGJ4WKrx4S-2jAkYNLtqUA04jp6q8B7w,75
 pycityagent/economy/econ_client.py,sha256=GuHK9ZBnhqW3Z7F8ViDJn_iN73yOBbbwFyJv1wLEBDk,12211
 pycityagent/environment/__init__.py,sha256=awHxlOud-btWbk0FCS4RmGJ13W84oVCkbGfcrhKqihA,240
@@ -48,9 +48,10 @@ pycityagent/message/messager.py,sha256=W_OVlNGcreHSBf6v-DrEnfNCXExB78ySr0w26MSnc
 pycityagent/metrics/__init__.py,sha256=X08PaBbGVAd7_PRGLREXWxaqm7nS82WBQpD1zvQzcqc,128
 pycityagent/metrics/mlflow_client.py,sha256=g_tHxWkWTDijtbGL74-HmiYzWVKb1y8-w12QrY9jL30,4449
 pycityagent/metrics/utils/const.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pycityagent/simulation/__init__.py,sha256=
-pycityagent/simulation/agentgroup.py,sha256=
-pycityagent/simulation/simulation.py,sha256=
+pycityagent/simulation/__init__.py,sha256=P5czbcg2d8S0nbbnsQXFIhwzO4CennAhZM8OmKvAeYw,194
+pycityagent/simulation/agentgroup.py,sha256=5p68wNoEaog4nDym3xsCTporBWmxNiQ1crN3mbOHFsE,19788
+pycityagent/simulation/simulation.py,sha256=7Go_RkpkC_DuBWW21JPqlV2yXY754RqSkqzM0vTdteU,23008
+pycityagent/simulation/storage/pg.py,sha256=Ws04mUgRcbbvWi_eQm3PXYa6w7AQUbDPWhSU7HFtsD8,6026
 pycityagent/survey/__init__.py,sha256=rxwou8U9KeFSP7rMzXtmtp2fVFZxK4Trzi-psx9LPIs,153
 pycityagent/survey/manager.py,sha256=S5IkwTdelsdtZETChRcfCEczzwSrry_Fly9MY4s3rbk,1681
 pycityagent/survey/models.py,sha256=YE50UUt5qJ0O_lIUsSY6XFCGUTkJVNu_L1gAhaCJ2fs,3546
@@ -61,12 +62,13 @@ pycityagent/utils/parsers/__init__.py,sha256=AN2xgiPxszWK4rpX7zrqRsqNwfGF3WnCA5-
 pycityagent/utils/parsers/code_block_parser.py,sha256=Cs2Z_hm9VfNCpPPll1TwteaJF-HAQPs-3RApsOekFm4,1173
 pycityagent/utils/parsers/json_parser.py,sha256=FZ3XN1g8z4Dr2TFraUOoah1oQcze4fPd2m01hHoX0Mo,2917
 pycityagent/utils/parsers/parser_base.py,sha256=KBKO4zLZPNdGjPAGqIus8LseZ8W3Tlt2y0QxqeCd25Q,1713
+pycityagent/utils/pg_query.py,sha256=h5158xcrxjUTR0nKwAaG1neFfTHPbN5guLmaXpC8yvs,1918
 pycityagent/utils/survey_util.py,sha256=Be9nptmu2JtesFNemPgORh_2GsN7rcDYGQS9Zfvc5OI,2169
 pycityagent/workflow/__init__.py,sha256=QNkUV-9mACMrR8c0cSKna2gC1mMZdxXbxWzjE-Uods0,621
 pycityagent/workflow/block.py,sha256=WkE2On97DCZS_9n8aIgT8wxv9Oaff4Fdf2tLqbKfMtE,6010
 pycityagent/workflow/prompt.py,sha256=6jI0Rq54JLv3-IXqZLYug62vse10wTI83xvf4ZX42nk,2929
-pycityagent/workflow/tool.py,sha256=
+pycityagent/workflow/tool.py,sha256=xADxhNgVsjNiMxlhdwn3xGUstFOkLEG8P67ez8VmwSI,8555
 pycityagent/workflow/trigger.py,sha256=Df-MOBEDWBbM-v0dFLQLXteLsipymT4n8vqexmK2GiQ,5643
-pycityagent-2.0.
-pycityagent-2.0.
-pycityagent-2.0.
+pycityagent-2.0.0a21.dist-info/METADATA,sha256=sowWsIPV6PFjNPeQI30Pn0J1Fqz5KfZ7sMydvfaOAX0,7848
+pycityagent-2.0.0a21.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+pycityagent-2.0.0a21.dist-info/RECORD,,

{pycityagent-2.0.0a20.dist-info → pycityagent-2.0.0a21.dist-info}/WHEEL
File without changes