pycityagent 2.0.0a72__cp310-cp310-macosx_11_0_arm64.whl → 2.0.0a74__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +3 -2
- pycityagent/agent/agent_base.py +43 -3
- pycityagent/cityagent/blocks/cognition_block.py +2 -1
- pycityagent/cityagent/blocks/economy_block.py +3 -3
- pycityagent/cityagent/blocks/mobility_block.py +17 -20
- pycityagent/cityagent/blocks/needs_block.py +2 -0
- pycityagent/cityagent/blocks/plan_block.py +2 -4
- pycityagent/cityagent/blocks/utils.py +0 -1
- pycityagent/cityagent/initial.py +0 -2
- pycityagent/cityagent/memory_config.py +40 -41
- pycityagent/cityagent/societyagent.py +9 -10
- pycityagent/environment/sim/sim_env.py +9 -2
- pycityagent/environment/simulator.py +75 -23
- pycityagent/llm/llm.py +105 -104
- pycityagent/memory/memory.py +1 -1
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +9 -16
- pycityagent/simulation/simulation.py +23 -12
- pycityagent/tools/tool.py +1 -0
- {pycityagent-2.0.0a72.dist-info → pycityagent-2.0.0a74.dist-info}/METADATA +1 -1
- {pycityagent-2.0.0a72.dist-info → pycityagent-2.0.0a74.dist-info}/RECORD +25 -25
- {pycityagent-2.0.0a72.dist-info → pycityagent-2.0.0a74.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a72.dist-info → pycityagent-2.0.0a74.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a72.dist-info → pycityagent-2.0.0a74.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a72.dist-info → pycityagent-2.0.0a74.dist-info}/top_level.txt +0 -0
@@ -5,16 +5,17 @@ import logging
|
|
5
5
|
import os
|
6
6
|
from datetime import datetime, timedelta
|
7
7
|
import time
|
8
|
-
from typing import
|
8
|
+
from typing import Optional, Union, cast
|
9
9
|
|
10
|
-
from mosstool.
|
11
|
-
from mosstool.util.format_converter import coll2pb
|
10
|
+
from mosstool.util.format_converter import coll2pb, dict2pb
|
12
11
|
from pycitydata.map import Map as SimMap
|
13
12
|
from pycityproto.city.map.v2 import map_pb2 as map_pb2
|
14
13
|
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
15
14
|
from pycityproto.city.person.v2 import person_service_pb2 as person_service
|
16
15
|
from pymongo import MongoClient
|
16
|
+
import ray
|
17
17
|
from shapely.geometry import Point
|
18
|
+
from mosstool.type import TripMode
|
18
19
|
|
19
20
|
from .sim import CityClient, ControlSimEnv
|
20
21
|
from .utils.const import *
|
@@ -25,6 +26,42 @@ __all__ = [
|
|
25
26
|
"Simulator",
|
26
27
|
]
|
27
28
|
|
29
|
+
@ray.remote
|
30
|
+
class CityMap:
|
31
|
+
def __init__(self, mongo_uri: str, mongo_db: str, mongo_coll: str, cache_dir: str):
|
32
|
+
self.map = SimMap(
|
33
|
+
mongo_uri=mongo_uri,
|
34
|
+
mongo_db=mongo_db,
|
35
|
+
mongo_coll=mongo_coll,
|
36
|
+
cache_dir=cache_dir,
|
37
|
+
)
|
38
|
+
|
39
|
+
def get_aoi(self, aoi_id: Optional[int] = None):
|
40
|
+
if aoi_id is None:
|
41
|
+
return list(self.map.aois.values())
|
42
|
+
else:
|
43
|
+
return self.map.aois[aoi_id]
|
44
|
+
|
45
|
+
def get_poi(self, poi_id: Optional[int] = None):
|
46
|
+
if poi_id is None:
|
47
|
+
return list(self.map.pois.values())
|
48
|
+
else:
|
49
|
+
return self.map.pois[poi_id]
|
50
|
+
|
51
|
+
def query_pois(self, **kwargs):
|
52
|
+
return self.map.query_pois(**kwargs)
|
53
|
+
|
54
|
+
def get_poi_cate(self):
|
55
|
+
return self.poi_cate
|
56
|
+
|
57
|
+
def get_map(self):
|
58
|
+
return self.map
|
59
|
+
|
60
|
+
def get_map_header(self):
|
61
|
+
return self.map.header
|
62
|
+
|
63
|
+
def get_projector(self):
|
64
|
+
return self.map.header["projection"]
|
28
65
|
|
29
66
|
class Simulator:
|
30
67
|
"""
|
@@ -35,7 +72,7 @@ class Simulator:
|
|
35
72
|
- It reads parameters from a configuration dictionary, initializes map data, and starts or connects to a simulation server as needed.
|
36
73
|
"""
|
37
74
|
|
38
|
-
def __init__(self, config: dict,
|
75
|
+
def __init__(self, config: dict, create_map: bool = False) -> None:
|
39
76
|
self.config = config
|
40
77
|
"""
|
41
78
|
- 模拟器配置
|
@@ -66,7 +103,8 @@ class Simulator:
|
|
66
103
|
self._sim_env = sim_env = ControlSimEnv(
|
67
104
|
task_name=config["simulator"].get("task", "citysim"),
|
68
105
|
map_file=_map_pb_path,
|
69
|
-
|
106
|
+
max_day=config["simulator"].get("max_day", 1000),
|
107
|
+
start_step=config["simulator"].get("start_step", 28800),
|
70
108
|
total_step=2147000000,
|
71
109
|
log_dir=config["simulator"].get("log_dir", "./log"),
|
72
110
|
min_step_time=config["simulator"].get("min_step_time", 1000),
|
@@ -88,16 +126,14 @@ class Simulator:
|
|
88
126
|
logger.warning(
|
89
127
|
"No simulator config found, no simulator client will be used"
|
90
128
|
)
|
91
|
-
self.
|
92
|
-
mongo_uri=_mongo_uri,
|
93
|
-
mongo_db=_mongo_db,
|
94
|
-
mongo_coll=_mongo_coll,
|
95
|
-
cache_dir=_map_cache_dir,
|
96
|
-
)
|
129
|
+
self._map = None
|
97
130
|
"""
|
98
131
|
- 模拟器地图对象
|
99
132
|
- Simulator map object
|
100
133
|
"""
|
134
|
+
if create_map:
|
135
|
+
self._map = CityMap.remote(_mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir)
|
136
|
+
self._create_poi_id_2_aoi_id()
|
101
137
|
|
102
138
|
self.time: int = 0
|
103
139
|
"""
|
@@ -109,25 +145,41 @@ class Simulator:
|
|
109
145
|
self.map_y_gap = None
|
110
146
|
self._bbox: tuple[float, float, float, float] = (-1, -1, -1, -1)
|
111
147
|
self._lock = asyncio.Lock()
|
112
|
-
# poi id dict
|
113
|
-
self.poi_id_2_aoi_id: dict[int, int] = {
|
114
|
-
poi["id"]: poi["aoi_id"] for _, poi in self.map.pois.items()
|
115
|
-
}
|
116
148
|
self._environment_prompt:dict[str, str] = {}
|
117
149
|
self._log_list = []
|
118
150
|
|
151
|
+
def set_map(self, map: CityMap):
|
152
|
+
self._map = map
|
153
|
+
self._create_poi_id_2_aoi_id()
|
154
|
+
|
155
|
+
def _create_poi_id_2_aoi_id(self):
|
156
|
+
pois = ray.get(self._map.get_poi.remote())
|
157
|
+
self.poi_id_2_aoi_id: dict[int, int] = {
|
158
|
+
poi["id"]: poi["aoi_id"] for poi in pois
|
159
|
+
}
|
160
|
+
|
161
|
+
@property
|
162
|
+
def map(self):
|
163
|
+
return self._map
|
164
|
+
|
119
165
|
def get_log_list(self):
|
120
166
|
return self._log_list
|
121
167
|
|
122
168
|
def clear_log_list(self):
|
123
169
|
self._log_list = []
|
124
170
|
|
171
|
+
def get_poi_cate(self):
|
172
|
+
return self.poi_cate
|
173
|
+
|
125
174
|
@property
|
126
175
|
def environment(self) -> dict[str, str]:
|
127
176
|
"""
|
128
177
|
Get the current state of environment variables.
|
129
178
|
"""
|
130
179
|
return self._environment_prompt
|
180
|
+
|
181
|
+
def get_server_addr(self):
|
182
|
+
return self.server_addr
|
131
183
|
|
132
184
|
def set_environment(self, environment: dict[str, str]):
|
133
185
|
"""
|
@@ -224,11 +276,11 @@ class Simulator:
|
|
224
276
|
categories: list[str] = []
|
225
277
|
if center is None:
|
226
278
|
center = (0, 0)
|
227
|
-
_pois: list[dict] = self.map.query_pois( # type:ignore
|
279
|
+
_pois: list[dict] = ray.get(self.map.query_pois.remote( # type:ignore
|
228
280
|
center=center,
|
229
281
|
radius=radius,
|
230
282
|
return_distance=False,
|
231
|
-
)
|
283
|
+
))
|
232
284
|
for poi in _pois:
|
233
285
|
catg = poi["category"]
|
234
286
|
categories.append(catg.split("|")[-1])
|
@@ -367,13 +419,12 @@ class Simulator:
|
|
367
419
|
self._log_list.append(log)
|
368
420
|
return person
|
369
421
|
|
370
|
-
async def add_person(self,
|
422
|
+
async def add_person(self, dict_person: dict) -> dict:
|
371
423
|
"""
|
372
424
|
Add a new person to the simulation.
|
373
425
|
|
374
426
|
- **Args**:
|
375
|
-
- `
|
376
|
-
it will be wrapped in an `AddPersonRequest`. Otherwise, `person` is expected to already be a valid request object.
|
427
|
+
- `dict_person` (`dict`): The person object to add.
|
377
428
|
|
378
429
|
- **Returns**:
|
379
430
|
- `Dict`: Response from adding the person.
|
@@ -384,6 +435,7 @@ class Simulator:
|
|
384
435
|
"start_time": start_time,
|
385
436
|
"consumption": 0
|
386
437
|
}
|
438
|
+
person = dict2pb(dict_person, person_pb2.Person())
|
387
439
|
if isinstance(person, person_pb2.Person):
|
388
440
|
req = person_service.AddPersonRequest(person=person)
|
389
441
|
else:
|
@@ -411,7 +463,7 @@ class Simulator:
|
|
411
463
|
A list of AOI or POI IDs or tuples of (AOI ID, POI ID) that the person will visit.
|
412
464
|
- `departure_times` (`Optional[List[float]]`): Departure times for each trip in the schedule.
|
413
465
|
If not provided, current time will be used for all trips.
|
414
|
-
- `modes` (`Optional[List[
|
466
|
+
- `modes` (`Optional[List[int]]`): Travel modes for each trip.
|
415
467
|
Defaults to `TRIP_MODE_DRIVE_ONLY` if not specified.
|
416
468
|
"""
|
417
469
|
start_time = time.time()
|
@@ -560,11 +612,11 @@ class Simulator:
|
|
560
612
|
transformed_poi_type += self.poi_cate[t]
|
561
613
|
poi_type_set = set(transformed_poi_type)
|
562
614
|
# 获取半径内的poi
|
563
|
-
_pois: list[dict] = self.map.query_pois( # type:ignore
|
615
|
+
_pois: list[dict] = ray.get(self.map.query_pois.remote( # type:ignore
|
564
616
|
center=center,
|
565
617
|
radius=radius,
|
566
618
|
return_distance=False,
|
567
|
-
)
|
619
|
+
))
|
568
620
|
# 过滤掉不满足类别前缀的poi
|
569
621
|
pois = []
|
570
622
|
for poi in _pois:
|
pycityagent/llm/llm.py
CHANGED
@@ -232,111 +232,112 @@ class LLM:
|
|
232
232
|
"""
|
233
233
|
start_time = time.time()
|
234
234
|
log = {"request_time": start_time}
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
task_id = response.id
|
301
|
-
task_status = ""
|
302
|
-
get_cnt = 0
|
303
|
-
cnt_threshold = int(timeout / 0.5)
|
304
|
-
while (
|
305
|
-
task_status != "SUCCESS"
|
306
|
-
and task_status != "FAILED"
|
307
|
-
and get_cnt <= cnt_threshold
|
308
|
-
):
|
309
|
-
result_response = client.chat.asyncCompletions.retrieve_completion_result(id=task_id) # type: ignore
|
310
|
-
task_status = result_response.task_status
|
311
|
-
await asyncio.sleep(0.5)
|
312
|
-
get_cnt += 1
|
313
|
-
if task_status != "SUCCESS":
|
314
|
-
raise Exception(f"Task failed with status: {task_status}")
|
315
|
-
|
316
|
-
self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens # type: ignore
|
317
|
-
self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens # type: ignore
|
318
|
-
self._client_usage[self._current_client_index]["request_number"] += 1
|
319
|
-
end_time = time.time()
|
320
|
-
log["used_time"] = end_time - start_time
|
321
|
-
log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
|
322
|
-
self._log_list.append(log)
|
323
|
-
if tools and result_response.choices[0].message.tool_calls: # type: ignore
|
324
|
-
return json.loads(
|
325
|
-
result_response.choices[0] # type: ignore
|
326
|
-
.message.tool_calls[0]
|
327
|
-
.function.arguments
|
235
|
+
async with self.semaphore:
|
236
|
+
if (
|
237
|
+
self.config.text["request_type"] == "openai"
|
238
|
+
or self.config.text["request_type"] == "deepseek"
|
239
|
+
or self.config.text["request_type"] == "qwen"
|
240
|
+
):
|
241
|
+
for attempt in range(retries):
|
242
|
+
try:
|
243
|
+
client = self._get_next_client()
|
244
|
+
response = await client.chat.completions.create(
|
245
|
+
model=self.config.text["model"],
|
246
|
+
messages=dialog,
|
247
|
+
temperature=temperature,
|
248
|
+
max_tokens=max_tokens,
|
249
|
+
top_p=top_p,
|
250
|
+
frequency_penalty=frequency_penalty, # type: ignore
|
251
|
+
presence_penalty=presence_penalty, # type: ignore
|
252
|
+
stream=False,
|
253
|
+
timeout=timeout,
|
254
|
+
tools=tools,
|
255
|
+
tool_choice=tool_choice,
|
256
|
+
) # type: ignore
|
257
|
+
self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens # type: ignore
|
258
|
+
self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens # type: ignore
|
259
|
+
self._client_usage[self._current_client_index]["request_number"] += 1
|
260
|
+
end_time = time.time()
|
261
|
+
log["consumption"] = end_time - start_time
|
262
|
+
log["input_tokens"] = response.usage.prompt_tokens
|
263
|
+
log["output_tokens"] = response.usage.completion_tokens
|
264
|
+
self._log_list.append(log)
|
265
|
+
if tools and response.choices[0].message.tool_calls:
|
266
|
+
return json.loads(
|
267
|
+
response.choices[0]
|
268
|
+
.message.tool_calls[0]
|
269
|
+
.function.arguments
|
270
|
+
)
|
271
|
+
else:
|
272
|
+
return response.choices[0].message.content
|
273
|
+
except APIConnectionError as e:
|
274
|
+
print("API connection error:", e)
|
275
|
+
if attempt < retries - 1:
|
276
|
+
await asyncio.sleep(2**attempt)
|
277
|
+
else:
|
278
|
+
raise e
|
279
|
+
except OpenAIError as e:
|
280
|
+
if hasattr(e, "http_status"):
|
281
|
+
print(f"HTTP status code: {e.http_status}") # type: ignore
|
282
|
+
else:
|
283
|
+
print("An error occurred:", e)
|
284
|
+
if attempt < retries - 1:
|
285
|
+
await asyncio.sleep(2**attempt)
|
286
|
+
else:
|
287
|
+
raise e
|
288
|
+
elif self.config.text["request_type"] == "zhipuai":
|
289
|
+
for attempt in range(retries):
|
290
|
+
try:
|
291
|
+
client = self._get_next_client()
|
292
|
+
response = client.chat.asyncCompletions.create( # type: ignore
|
293
|
+
model=self.config.text["model"],
|
294
|
+
messages=dialog,
|
295
|
+
temperature=temperature,
|
296
|
+
top_p=top_p,
|
297
|
+
timeout=timeout,
|
298
|
+
tools=tools,
|
299
|
+
tool_choice=tool_choice,
|
328
300
|
)
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
301
|
+
task_id = response.id
|
302
|
+
task_status = ""
|
303
|
+
get_cnt = 0
|
304
|
+
cnt_threshold = int(timeout / 0.5)
|
305
|
+
while (
|
306
|
+
task_status != "SUCCESS"
|
307
|
+
and task_status != "FAILED"
|
308
|
+
and get_cnt <= cnt_threshold
|
309
|
+
):
|
310
|
+
result_response = client.chat.asyncCompletions.retrieve_completion_result(id=task_id) # type: ignore
|
311
|
+
task_status = result_response.task_status
|
312
|
+
await asyncio.sleep(0.5)
|
313
|
+
get_cnt += 1
|
314
|
+
if task_status != "SUCCESS":
|
315
|
+
raise Exception(f"Task failed with status: {task_status}")
|
316
|
+
|
317
|
+
self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens # type: ignore
|
318
|
+
self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens # type: ignore
|
319
|
+
self._client_usage[self._current_client_index]["request_number"] += 1
|
320
|
+
end_time = time.time()
|
321
|
+
log["used_time"] = end_time - start_time
|
322
|
+
log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
|
323
|
+
self._log_list.append(log)
|
324
|
+
if tools and result_response.choices[0].message.tool_calls: # type: ignore
|
325
|
+
return json.loads(
|
326
|
+
result_response.choices[0] # type: ignore
|
327
|
+
.message.tool_calls[0]
|
328
|
+
.function.arguments
|
329
|
+
)
|
330
|
+
else:
|
331
|
+
return result_response.choices[0].message.content # type: ignore
|
332
|
+
except APIConnectionError as e:
|
333
|
+
print("API connection error:", e)
|
334
|
+
if attempt < retries - 1:
|
335
|
+
await asyncio.sleep(2**attempt)
|
336
|
+
else:
|
337
|
+
raise e
|
338
|
+
else:
|
339
|
+
print("ERROR: Wrong Config")
|
340
|
+
return "wrong config"
|
340
341
|
|
341
342
|
async def img_understand(
|
342
343
|
self, img_path: Union[str, list[str]], prompt: Optional[str] = None
|
pycityagent/memory/memory.py
CHANGED
@@ -341,7 +341,7 @@ class StreamMemory:
|
|
341
341
|
"""获取指定ID的记忆"""
|
342
342
|
memories = [memory for memory in self._memories if memory.id in memory_ids]
|
343
343
|
sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
|
344
|
-
return self.format_memory(sorted_results)
|
344
|
+
return await self.format_memory(sorted_results)
|
345
345
|
|
346
346
|
async def search(
|
347
347
|
self,
|
pycityagent/pycityagent-sim
CHANGED
Binary file
|
@@ -36,6 +36,7 @@ class AgentGroup:
|
|
36
36
|
number_of_agents: Union[int, list[int]],
|
37
37
|
memory_config_function_group: dict[type[Agent], Callable],
|
38
38
|
config: dict,
|
39
|
+
map_ref: ray.ObjectRef,
|
39
40
|
exp_name: str,
|
40
41
|
exp_id: Union[str, UUID],
|
41
42
|
enable_avro: bool,
|
@@ -47,8 +48,8 @@ class AgentGroup:
|
|
47
48
|
embedding_model: Embeddings,
|
48
49
|
logging_level: int,
|
49
50
|
agent_config_file: Optional[dict[type[Agent], str]] = None,
|
50
|
-
environment: Optional[dict[str, str]] = None,
|
51
51
|
llm_semaphore: int = 200,
|
52
|
+
environment: Optional[dict] = None,
|
52
53
|
):
|
53
54
|
"""
|
54
55
|
Represents a group of agents that can be deployed in a Ray distributed environment.
|
@@ -64,6 +65,7 @@ class AgentGroup:
|
|
64
65
|
- `number_of_agents` (Union[int, List[int]]): Number of instances to create for each agent class.
|
65
66
|
- `memory_config_function_group` (Dict[Type[Agent], Callable]): Functions to configure memory for each agent type.
|
66
67
|
- `config` (dict): Configuration settings for the agent group.
|
68
|
+
- `map_ref` (ray.ObjectRef): Reference to the map object.
|
67
69
|
- `exp_name` (str): Name of the experiment.
|
68
70
|
- `exp_id` (str | UUID): Identifier for the experiment.
|
69
71
|
- `enable_avro` (bool): Flag to enable AVRO file support.
|
@@ -143,10 +145,10 @@ class AgentGroup:
|
|
143
145
|
self.llm.set_semaphore(llm_semaphore)
|
144
146
|
|
145
147
|
# prepare Simulator
|
146
|
-
logger.info(f"-----
|
148
|
+
logger.info(f"-----Initializing Simulator in AgentGroup {self._uuid} ...")
|
147
149
|
self.simulator = Simulator(config["simulator_request"])
|
148
|
-
self.
|
149
|
-
self.simulator.
|
150
|
+
self.simulator.set_map(map_ref)
|
151
|
+
self.projector = pyproj.Proj(ray.get(self.simulator.map.get_projector.remote()))
|
150
152
|
# prepare Economy client
|
151
153
|
logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
|
152
154
|
self.economy_client = EconomyClient(
|
@@ -430,16 +432,6 @@ class AgentGroup:
|
|
430
432
|
agent = self.id2agent[target_agent_uuid]
|
431
433
|
await agent.status.update(target_key, content)
|
432
434
|
|
433
|
-
async def update_environment(self, key: str, value: str):
|
434
|
-
"""
|
435
|
-
Updates the environment with a given key-value pair.
|
436
|
-
|
437
|
-
- **Args**:
|
438
|
-
- `key` (str): The key to update in the environment.
|
439
|
-
- `value` (str): The value to set for the specified key.
|
440
|
-
"""
|
441
|
-
self.simulator.update_environment(key, value)
|
442
|
-
|
443
435
|
async def message_dispatch(self):
|
444
436
|
"""
|
445
437
|
Dispatches messages received via MQTT to the appropriate agents.
|
@@ -788,12 +780,13 @@ class AgentGroup:
|
|
788
780
|
"""
|
789
781
|
try:
|
790
782
|
tasks = [agent.run() for agent in self.agents]
|
791
|
-
await asyncio.gather(*tasks)
|
783
|
+
agent_time_log =await asyncio.gather(*tasks)
|
792
784
|
simulator_log = self.simulator.get_log_list() + self.economy_client.get_log_list()
|
793
785
|
group_logs = {
|
794
786
|
"llm_log": self.llm.get_log_list(),
|
795
787
|
"mqtt_log": ray.get(self.messager.get_log_list.remote()),
|
796
|
-
"simulator_log": simulator_log
|
788
|
+
"simulator_log": simulator_log,
|
789
|
+
"agent_time_log": agent_time_log
|
797
790
|
}
|
798
791
|
self.llm.clear_log_list()
|
799
792
|
self.messager.clear_log_list.remote()
|
@@ -130,12 +130,16 @@ class AgentSimulation:
|
|
130
130
|
_simulator_config = config["simulator_request"].get("simulator", {})
|
131
131
|
if "server" in _simulator_config:
|
132
132
|
raise ValueError(f"Passing Traffic Simulation address is not supported!")
|
133
|
-
|
133
|
+
simulator = Simulator(config["simulator_request"], create_map=True)
|
134
|
+
self._simulator = simulator
|
135
|
+
self._map_ref = self._simulator.map
|
136
|
+
server_addr = self._simulator.get_server_addr()
|
137
|
+
config["simulator_request"]["simulator"]["server"] = server_addr
|
134
138
|
self._economy_client = EconomyClient(
|
135
139
|
config["simulator_request"]["simulator"]["server"]
|
136
140
|
)
|
137
141
|
if enable_institution:
|
138
|
-
self._economy_addr = economy_addr =
|
142
|
+
self._economy_addr = economy_addr = server_addr
|
139
143
|
if economy_addr is None:
|
140
144
|
raise ValueError(
|
141
145
|
f"`simulator` not provided in `simulator_request`, thus unable to activate economy!"
|
@@ -382,6 +386,7 @@ class AgentSimulation:
|
|
382
386
|
llm_log_lists = []
|
383
387
|
mqtt_log_lists = []
|
384
388
|
simulator_log_lists = []
|
389
|
+
agent_time_log_lists = []
|
385
390
|
for step in config["workflow"]:
|
386
391
|
logger.info(
|
387
392
|
f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}"
|
@@ -389,17 +394,19 @@ class AgentSimulation:
|
|
389
394
|
if step["type"] not in ["run", "step", "interview", "survey", "intervene", "pause", "resume", "function"]:
|
390
395
|
raise ValueError(f"Invalid step type: {step['type']}")
|
391
396
|
if step["type"] == "run":
|
392
|
-
llm_log_list, mqtt_log_list, simulator_log_list = await simulation.run(step.get("days", 1))
|
397
|
+
llm_log_list, mqtt_log_list, simulator_log_list, agent_time_log_list = await simulation.run(step.get("days", 1))
|
393
398
|
llm_log_lists.extend(llm_log_list)
|
394
399
|
mqtt_log_lists.extend(mqtt_log_list)
|
395
400
|
simulator_log_lists.extend(simulator_log_list)
|
401
|
+
agent_time_log_lists.extend(agent_time_log_list)
|
396
402
|
elif step["type"] == "step":
|
397
403
|
times = step.get("times", 1)
|
398
404
|
for _ in range(times):
|
399
|
-
llm_log_list, mqtt_log_list, simulator_log_list = await simulation.step()
|
405
|
+
llm_log_list, mqtt_log_list, simulator_log_list, agent_time_log_list = await simulation.step()
|
400
406
|
llm_log_lists.extend(llm_log_list)
|
401
407
|
mqtt_log_lists.extend(mqtt_log_list)
|
402
408
|
simulator_log_lists.extend(simulator_log_list)
|
409
|
+
agent_time_log_lists.extend(agent_time_log_list)
|
403
410
|
elif step["type"] == "pause":
|
404
411
|
await simulation.pause_simulator()
|
405
412
|
elif step["type"] == "resume":
|
@@ -407,7 +414,7 @@ class AgentSimulation:
|
|
407
414
|
else:
|
408
415
|
await step["func"](simulation)
|
409
416
|
logger.info("Simulation finished")
|
410
|
-
return llm_log_lists, mqtt_log_lists, simulator_log_lists
|
417
|
+
return llm_log_lists, mqtt_log_lists, simulator_log_lists, agent_time_log_lists
|
411
418
|
|
412
419
|
@property
|
413
420
|
def enable_avro(
|
@@ -742,6 +749,7 @@ class AgentSimulation:
|
|
742
749
|
number_of_agents,
|
743
750
|
memory_config_function_group,
|
744
751
|
self.config,
|
752
|
+
self._map_ref,
|
745
753
|
self.exp_name,
|
746
754
|
self.exp_id,
|
747
755
|
self.enable_avro,
|
@@ -753,8 +761,8 @@ class AgentSimulation:
|
|
753
761
|
embedding_model,
|
754
762
|
self.logging_level,
|
755
763
|
config_file,
|
756
|
-
environment,
|
757
764
|
llm_semaphore,
|
765
|
+
environment,
|
758
766
|
)
|
759
767
|
creation_tasks.append((group_name, group))
|
760
768
|
|
@@ -1018,7 +1026,7 @@ class AgentSimulation:
|
|
1018
1026
|
|
1019
1027
|
# step
|
1020
1028
|
simulator_day = await self._simulator.get_simulator_day()
|
1021
|
-
simulator_time = int(await self._simulator.
|
1029
|
+
simulator_time = int(await self._simulator.get_simulator_second_from_start_of_day())
|
1022
1030
|
logger.info(
|
1023
1031
|
f"Start simulation day {simulator_day} at {simulator_time}, step {self._total_steps}"
|
1024
1032
|
)
|
@@ -1029,14 +1037,15 @@ class AgentSimulation:
|
|
1029
1037
|
llm_log_list = []
|
1030
1038
|
mqtt_log_list = []
|
1031
1039
|
simulator_log_list = []
|
1040
|
+
agent_time_log_list = []
|
1032
1041
|
for log_messages_group in log_messages_groups:
|
1033
1042
|
llm_log_list.extend(log_messages_group['llm_log'])
|
1034
1043
|
mqtt_log_list.extend(log_messages_group['mqtt_log'])
|
1035
1044
|
simulator_log_list.extend(log_messages_group['simulator_log'])
|
1036
|
-
|
1045
|
+
agent_time_log_list.extend(log_messages_group['agent_time_log'])
|
1037
1046
|
# save
|
1038
1047
|
simulator_day = await self._simulator.get_simulator_day()
|
1039
|
-
simulator_time = int(await self._simulator.
|
1048
|
+
simulator_time = int(await self._simulator.get_simulator_second_from_start_of_day())
|
1040
1049
|
save_tasks = []
|
1041
1050
|
for group in self._groups.values():
|
1042
1051
|
save_tasks.append(group.save.remote(simulator_day, simulator_time))
|
@@ -1050,7 +1059,7 @@ class AgentSimulation:
|
|
1050
1059
|
]
|
1051
1060
|
await self.extract_metric(to_excute_metric)
|
1052
1061
|
|
1053
|
-
return llm_log_list, mqtt_log_list, simulator_log_list
|
1062
|
+
return llm_log_list, mqtt_log_list, simulator_log_list, agent_time_log_list
|
1054
1063
|
except Exception as e:
|
1055
1064
|
import traceback
|
1056
1065
|
|
@@ -1081,6 +1090,7 @@ class AgentSimulation:
|
|
1081
1090
|
llm_log_lists = []
|
1082
1091
|
mqtt_log_lists = []
|
1083
1092
|
simulator_log_lists = []
|
1093
|
+
agent_time_log_lists = []
|
1084
1094
|
try:
|
1085
1095
|
self._exp_info["num_day"] += day
|
1086
1096
|
await self._update_exp_status(1) # 更新状态为运行中
|
@@ -1098,10 +1108,11 @@ class AgentSimulation:
|
|
1098
1108
|
current_time = await self._simulator.get_time()
|
1099
1109
|
if current_time >= end_time: # type:ignore
|
1100
1110
|
break
|
1101
|
-
llm_log_list, mqtt_log_list, simulator_log_list = await self.step()
|
1111
|
+
llm_log_list, mqtt_log_list, simulator_log_list, agent_time_log_list = await self.step()
|
1102
1112
|
llm_log_lists.extend(llm_log_list)
|
1103
1113
|
mqtt_log_lists.extend(mqtt_log_list)
|
1104
1114
|
simulator_log_lists.extend(simulator_log_list)
|
1115
|
+
agent_time_log_lists.extend(agent_time_log_list)
|
1105
1116
|
finally:
|
1106
1117
|
# 设置停止事件
|
1107
1118
|
stop_event.set()
|
@@ -1110,7 +1121,7 @@ class AgentSimulation:
|
|
1110
1121
|
|
1111
1122
|
# 运行成功后更新状态
|
1112
1123
|
await self._update_exp_status(2)
|
1113
|
-
return llm_log_lists, mqtt_log_lists, simulator_log_lists
|
1124
|
+
return llm_log_lists, mqtt_log_lists, simulator_log_lists, agent_time_log_lists
|
1114
1125
|
except Exception as e:
|
1115
1126
|
error_msg = f"模拟器运行错误: {str(e)}"
|
1116
1127
|
logger.error(error_msg)
|