pycityagent 2.0.0a73__cp312-cp312-macosx_11_0_arm64.whl → 2.0.0a75__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +3 -2
- pycityagent/agent/agent_base.py +43 -3
- pycityagent/cityagent/blocks/cognition_block.py +2 -1
- pycityagent/cityagent/blocks/economy_block.py +3 -3
- pycityagent/cityagent/blocks/mobility_block.py +17 -20
- pycityagent/cityagent/blocks/needs_block.py +2 -0
- pycityagent/cityagent/blocks/plan_block.py +2 -4
- pycityagent/cityagent/blocks/utils.py +0 -1
- pycityagent/cityagent/initial.py +0 -2
- pycityagent/cityagent/memory_config.py +40 -41
- pycityagent/cityagent/societyagent.py +9 -10
- pycityagent/environment/sim/sim_env.py +6 -2
- pycityagent/environment/simulator.py +113 -84
- pycityagent/llm/llm.py +105 -104
- pycityagent/memory/memory.py +1 -1
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +25 -22
- pycityagent/simulation/simulation.py +71 -38
- pycityagent/tools/tool.py +1 -0
- {pycityagent-2.0.0a73.dist-info → pycityagent-2.0.0a75.dist-info}/METADATA +1 -1
- {pycityagent-2.0.0a73.dist-info → pycityagent-2.0.0a75.dist-info}/RECORD +25 -25
- {pycityagent-2.0.0a73.dist-info → pycityagent-2.0.0a75.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a73.dist-info → pycityagent-2.0.0a75.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a73.dist-info → pycityagent-2.0.0a75.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a73.dist-info → pycityagent-2.0.0a75.dist-info}/top_level.txt +0 -0
@@ -3,12 +3,13 @@
|
|
3
3
|
import asyncio
|
4
4
|
import logging
|
5
5
|
import os
|
6
|
-
from datetime import datetime, timedelta
|
7
6
|
import time
|
8
|
-
from
|
7
|
+
from datetime import datetime, timedelta
|
8
|
+
from typing import Optional, Union, cast
|
9
9
|
|
10
|
+
import ray
|
10
11
|
from mosstool.type import TripMode
|
11
|
-
from mosstool.util.format_converter import coll2pb
|
12
|
+
from mosstool.util.format_converter import coll2pb, dict2pb
|
12
13
|
from pycitydata.map import Map as SimMap
|
13
14
|
from pycityproto.city.map.v2 import map_pb2 as map_pb2
|
14
15
|
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
@@ -26,6 +27,51 @@ __all__ = [
|
|
26
27
|
]
|
27
28
|
|
28
29
|
|
30
|
+
@ray.remote
|
31
|
+
class CityMap:
|
32
|
+
def __init__(self, mongo_input: tuple[str, str, str, str], map_cache_path: str):
|
33
|
+
if map_cache_path:
|
34
|
+
self.map = SimMap(
|
35
|
+
pb_path=map_cache_path,
|
36
|
+
)
|
37
|
+
else:
|
38
|
+
mongo_uri, mongo_db, mongo_coll, cache_dir = mongo_input
|
39
|
+
self.map = SimMap(
|
40
|
+
mongo_uri=mongo_uri,
|
41
|
+
mongo_db=mongo_db,
|
42
|
+
mongo_coll=mongo_coll,
|
43
|
+
cache_dir=cache_dir,
|
44
|
+
)
|
45
|
+
self.poi_cate = POI_CATG_DICT
|
46
|
+
|
47
|
+
def get_aoi(self, aoi_id: Optional[int] = None):
|
48
|
+
if aoi_id is None:
|
49
|
+
return list(self.map.aois.values())
|
50
|
+
else:
|
51
|
+
return self.map.aois[aoi_id]
|
52
|
+
|
53
|
+
def get_poi(self, poi_id: Optional[int] = None):
|
54
|
+
if poi_id is None:
|
55
|
+
return list(self.map.pois.values())
|
56
|
+
else:
|
57
|
+
return self.map.pois[poi_id]
|
58
|
+
|
59
|
+
def query_pois(self, **kwargs):
|
60
|
+
return self.map.query_pois(**kwargs)
|
61
|
+
|
62
|
+
def get_poi_cate(self):
|
63
|
+
return self.poi_cate
|
64
|
+
|
65
|
+
def get_map(self):
|
66
|
+
return self.map
|
67
|
+
|
68
|
+
def get_map_header(self):
|
69
|
+
return self.map.header
|
70
|
+
|
71
|
+
def get_projector(self):
|
72
|
+
return self.map.header["projection"]
|
73
|
+
|
74
|
+
|
29
75
|
class Simulator:
|
30
76
|
"""
|
31
77
|
Main class of the simulator.
|
@@ -35,7 +81,7 @@ class Simulator:
|
|
35
81
|
- It reads parameters from a configuration dictionary, initializes map data, and starts or connects to a simulation server as needed.
|
36
82
|
"""
|
37
83
|
|
38
|
-
def __init__(self, config: dict,
|
84
|
+
def __init__(self, config: dict, create_map: bool = False) -> None:
|
39
85
|
self.config = config
|
40
86
|
"""
|
41
87
|
- 模拟器配置
|
@@ -66,7 +112,8 @@ class Simulator:
|
|
66
112
|
self._sim_env = sim_env = ControlSimEnv(
|
67
113
|
task_name=config["simulator"].get("task", "citysim"),
|
68
114
|
map_file=_map_pb_path,
|
69
|
-
|
115
|
+
max_day=config["simulator"].get("max_day", 1000),
|
116
|
+
start_step=config["simulator"].get("start_step", 28800),
|
70
117
|
total_step=2147000000,
|
71
118
|
log_dir=config["simulator"].get("log_dir", "./log"),
|
72
119
|
min_step_time=config["simulator"].get("min_step_time", 1000),
|
@@ -88,16 +135,18 @@ class Simulator:
|
|
88
135
|
logger.warning(
|
89
136
|
"No simulator config found, no simulator client will be used"
|
90
137
|
)
|
91
|
-
self.
|
92
|
-
mongo_uri=_mongo_uri,
|
93
|
-
mongo_db=_mongo_db,
|
94
|
-
mongo_coll=_mongo_coll,
|
95
|
-
cache_dir=_map_cache_dir,
|
96
|
-
)
|
138
|
+
self._map = None
|
97
139
|
"""
|
98
140
|
- 模拟器地图对象
|
99
141
|
- Simulator map object
|
100
142
|
"""
|
143
|
+
if create_map:
|
144
|
+
_map_cache_path = "" # 地图pb文件路径
|
145
|
+
self._map = CityMap.remote(
|
146
|
+
(_mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir),
|
147
|
+
_map_cache_path,
|
148
|
+
)
|
149
|
+
self._create_poi_id_2_aoi_id()
|
101
150
|
|
102
151
|
self.time: int = 0
|
103
152
|
"""
|
@@ -109,19 +158,32 @@ class Simulator:
|
|
109
158
|
self.map_y_gap = None
|
110
159
|
self._bbox: tuple[float, float, float, float] = (-1, -1, -1, -1)
|
111
160
|
self._lock = asyncio.Lock()
|
112
|
-
|
161
|
+
self._environment_prompt: dict[str, str] = {}
|
162
|
+
self._log_list = []
|
163
|
+
|
164
|
+
def set_map(self, map: ray.ObjectRef):
|
165
|
+
self._map = map
|
166
|
+
self._create_poi_id_2_aoi_id()
|
167
|
+
|
168
|
+
def _create_poi_id_2_aoi_id(self):
|
169
|
+
pois = ray.get(self._map.get_poi.remote()) # type:ignore
|
113
170
|
self.poi_id_2_aoi_id: dict[int, int] = {
|
114
|
-
poi["id"]: poi["aoi_id"] for
|
171
|
+
poi["id"]: poi["aoi_id"] for poi in pois
|
115
172
|
}
|
116
|
-
|
117
|
-
|
173
|
+
|
174
|
+
@property
|
175
|
+
def map(self):
|
176
|
+
return self._map
|
118
177
|
|
119
178
|
def get_log_list(self):
|
120
179
|
return self._log_list
|
121
|
-
|
180
|
+
|
122
181
|
def clear_log_list(self):
|
123
182
|
self._log_list = []
|
124
183
|
|
184
|
+
def get_poi_cate(self):
|
185
|
+
return self.poi_cate
|
186
|
+
|
125
187
|
@property
|
126
188
|
def environment(self) -> dict[str, str]:
|
127
189
|
"""
|
@@ -129,6 +191,9 @@ class Simulator:
|
|
129
191
|
"""
|
130
192
|
return self._environment_prompt
|
131
193
|
|
194
|
+
def get_server_addr(self):
|
195
|
+
return self.server_addr
|
196
|
+
|
132
197
|
def set_environment(self, environment: dict[str, str]):
|
133
198
|
"""
|
134
199
|
Set the entire dictionary of environment variables.
|
@@ -177,11 +242,7 @@ class Simulator:
|
|
177
242
|
Refer to https://cityproto.sim.fiblab.net/#city.person.1.GetPersonByLongLatBBoxResponse.
|
178
243
|
"""
|
179
244
|
start_time = time.time()
|
180
|
-
log = {
|
181
|
-
"req": "find_agents_by_area",
|
182
|
-
"start_time": start_time,
|
183
|
-
"consumption": 0
|
184
|
-
}
|
245
|
+
log = {"req": "find_agents_by_area", "start_time": start_time, "consumption": 0}
|
185
246
|
loop = asyncio.get_event_loop()
|
186
247
|
resp = loop.run_until_complete(
|
187
248
|
self._client.person_service.GetPersonByLongLatBBox(req=req)
|
@@ -216,18 +277,16 @@ class Simulator:
|
|
216
277
|
- `List[str]`: A list of unique POI category names.
|
217
278
|
"""
|
218
279
|
start_time = time.time()
|
219
|
-
log = {
|
220
|
-
"req": "get_poi_categories",
|
221
|
-
"start_time": start_time,
|
222
|
-
"consumption": 0
|
223
|
-
}
|
280
|
+
log = {"req": "get_poi_categories", "start_time": start_time, "consumption": 0}
|
224
281
|
categories: list[str] = []
|
225
282
|
if center is None:
|
226
283
|
center = (0, 0)
|
227
|
-
_pois: list[dict] =
|
228
|
-
|
229
|
-
|
230
|
-
|
284
|
+
_pois: list[dict] = ray.get(
|
285
|
+
self.map.query_pois.remote( # type:ignore
|
286
|
+
center=center,
|
287
|
+
radius=radius,
|
288
|
+
return_distance=False,
|
289
|
+
)
|
231
290
|
)
|
232
291
|
for poi in _pois:
|
233
292
|
catg = poi["category"]
|
@@ -252,11 +311,7 @@ class Simulator:
|
|
252
311
|
- `Union[int, str]`: The current simulation time either as an integer representing seconds since midnight or as a formatted string.
|
253
312
|
"""
|
254
313
|
start_time = time.time()
|
255
|
-
log = {
|
256
|
-
"req": "get_time",
|
257
|
-
"start_time": start_time,
|
258
|
-
"consumption": 0
|
259
|
-
}
|
314
|
+
log = {"req": "get_time", "start_time": start_time, "consumption": 0}
|
260
315
|
now = await self._client.clock_service.Now({})
|
261
316
|
now = cast(dict[str, int], now)
|
262
317
|
self.time = now["t"]
|
@@ -280,11 +335,7 @@ class Simulator:
|
|
280
335
|
This method sends a request to the simulator's pause service to pause the simulation.
|
281
336
|
"""
|
282
337
|
start_time = time.time()
|
283
|
-
log = {
|
284
|
-
"req": "pause",
|
285
|
-
"start_time": start_time,
|
286
|
-
"consumption": 0
|
287
|
-
}
|
338
|
+
log = {"req": "pause", "start_time": start_time, "consumption": 0}
|
288
339
|
await self._client.pause_service.pause()
|
289
340
|
log["consumption"] = time.time() - start_time
|
290
341
|
self._log_list.append(log)
|
@@ -296,11 +347,7 @@ class Simulator:
|
|
296
347
|
This method sends a request to the simulator's pause service to resume the simulation.
|
297
348
|
"""
|
298
349
|
start_time = time.time()
|
299
|
-
log = {
|
300
|
-
"req": "resume",
|
301
|
-
"start_time": start_time,
|
302
|
-
"consumption": 0
|
303
|
-
}
|
350
|
+
log = {"req": "resume", "start_time": start_time, "consumption": 0}
|
304
351
|
await self._client.pause_service.resume()
|
305
352
|
log["consumption"] = time.time() - start_time
|
306
353
|
self._log_list.append(log)
|
@@ -313,11 +360,7 @@ class Simulator:
|
|
313
360
|
- `int`: The day number since the start of the simulation.
|
314
361
|
"""
|
315
362
|
start_time = time.time()
|
316
|
-
log = {
|
317
|
-
"req": "get_simulator_day",
|
318
|
-
"start_time": start_time,
|
319
|
-
"consumption": 0
|
320
|
-
}
|
363
|
+
log = {"req": "get_simulator_day", "start_time": start_time, "consumption": 0}
|
321
364
|
now = await self._client.clock_service.Now({})
|
322
365
|
now = cast(dict[str, int], now)
|
323
366
|
day = now["day"]
|
@@ -336,7 +379,7 @@ class Simulator:
|
|
336
379
|
log = {
|
337
380
|
"req": "get_simulator_second_from_start_of_day",
|
338
381
|
"start_time": start_time,
|
339
|
-
"consumption": 0
|
382
|
+
"consumption": 0,
|
340
383
|
}
|
341
384
|
now = await self._client.clock_service.Now({})
|
342
385
|
now = cast(dict[str, int], now)
|
@@ -355,43 +398,35 @@ class Simulator:
|
|
355
398
|
- `Dict`: Information about the specified person.
|
356
399
|
"""
|
357
400
|
start_time = time.time()
|
358
|
-
log = {
|
359
|
-
|
360
|
-
"start_time": start_time,
|
361
|
-
"consumption": 0
|
362
|
-
}
|
363
|
-
person = await self._client.person_service.GetPerson(
|
401
|
+
log = {"req": "get_person", "start_time": start_time, "consumption": 0}
|
402
|
+
person: dict = await self._client.person_service.GetPerson(
|
364
403
|
req={"person_id": person_id}
|
365
404
|
) # type:ignore
|
366
405
|
log["consumption"] = time.time() - start_time
|
367
406
|
self._log_list.append(log)
|
368
407
|
return person
|
369
408
|
|
370
|
-
async def add_person(self,
|
409
|
+
async def add_person(self, dict_person: dict) -> dict:
|
371
410
|
"""
|
372
411
|
Add a new person to the simulation.
|
373
412
|
|
374
413
|
- **Args**:
|
375
|
-
- `
|
376
|
-
it will be wrapped in an `AddPersonRequest`. Otherwise, `person` is expected to already be a valid request object.
|
414
|
+
- `dict_person` (`dict`): The person object to add.
|
377
415
|
|
378
416
|
- **Returns**:
|
379
417
|
- `Dict`: Response from adding the person.
|
380
418
|
"""
|
381
419
|
start_time = time.time()
|
382
|
-
log = {
|
383
|
-
|
384
|
-
"start_time": start_time,
|
385
|
-
"consumption": 0
|
386
|
-
}
|
420
|
+
log = {"req": "add_person", "start_time": start_time, "consumption": 0}
|
421
|
+
person = dict2pb(dict_person, person_pb2.Person())
|
387
422
|
if isinstance(person, person_pb2.Person):
|
388
423
|
req = person_service.AddPersonRequest(person=person)
|
389
424
|
else:
|
390
425
|
req = person
|
391
|
-
|
426
|
+
resp: dict = await self._client.person_service.AddPerson(req) # type:ignore
|
392
427
|
log["consumption"] = time.time() - start_time
|
393
428
|
self._log_list.append(log)
|
394
|
-
return
|
429
|
+
return resp
|
395
430
|
|
396
431
|
async def set_aoi_schedules(
|
397
432
|
self,
|
@@ -411,15 +446,11 @@ class Simulator:
|
|
411
446
|
A list of AOI or POI IDs or tuples of (AOI ID, POI ID) that the person will visit.
|
412
447
|
- `departure_times` (`Optional[List[float]]`): Departure times for each trip in the schedule.
|
413
448
|
If not provided, current time will be used for all trips.
|
414
|
-
- `modes` (`Optional[List[
|
449
|
+
- `modes` (`Optional[List[int]]`): Travel modes for each trip.
|
415
450
|
Defaults to `TRIP_MODE_DRIVE_ONLY` if not specified.
|
416
451
|
"""
|
417
452
|
start_time = time.time()
|
418
|
-
log = {
|
419
|
-
"req": "set_aoi_schedules",
|
420
|
-
"start_time": start_time,
|
421
|
-
"consumption": 0
|
422
|
-
}
|
453
|
+
log = {"req": "set_aoi_schedules", "start_time": start_time, "consumption": 0}
|
423
454
|
cur_time = float(await self.get_time())
|
424
455
|
if not isinstance(target_positions, list):
|
425
456
|
target_positions = [target_positions]
|
@@ -494,7 +525,7 @@ class Simulator:
|
|
494
525
|
log = {
|
495
526
|
"req": "reset_person_position",
|
496
527
|
"start_time": start_time,
|
497
|
-
"consumption": 0
|
528
|
+
"consumption": 0,
|
498
529
|
}
|
499
530
|
reset_position = {}
|
500
531
|
if aoi_id is not None:
|
@@ -545,11 +576,7 @@ class Simulator:
|
|
545
576
|
- `List[Dict]`: A list of dictionaries containing information about the POIs found.
|
546
577
|
"""
|
547
578
|
start_time = time.time()
|
548
|
-
log = {
|
549
|
-
"req": "get_around_poi",
|
550
|
-
"start_time": start_time,
|
551
|
-
"consumption": 0
|
552
|
-
}
|
579
|
+
log = {"req": "get_around_poi", "start_time": start_time, "consumption": 0}
|
553
580
|
if isinstance(poi_type, str):
|
554
581
|
poi_type = [poi_type]
|
555
582
|
transformed_poi_type: list[str] = []
|
@@ -560,10 +587,12 @@ class Simulator:
|
|
560
587
|
transformed_poi_type += self.poi_cate[t]
|
561
588
|
poi_type_set = set(transformed_poi_type)
|
562
589
|
# 获取半径内的poi
|
563
|
-
_pois: list[dict] =
|
564
|
-
|
565
|
-
|
566
|
-
|
590
|
+
_pois: list[dict] = ray.get(
|
591
|
+
self.map.query_pois.remote( # type:ignore
|
592
|
+
center=center,
|
593
|
+
radius=radius,
|
594
|
+
return_distance=False,
|
595
|
+
)
|
567
596
|
)
|
568
597
|
# 过滤掉不满足类别前缀的poi
|
569
598
|
pois = []
|
@@ -574,4 +603,4 @@ class Simulator:
|
|
574
603
|
pois.append(poi)
|
575
604
|
log["consumption"] = time.time() - start_time
|
576
605
|
self._log_list.append(log)
|
577
|
-
return pois
|
606
|
+
return pois
|
pycityagent/llm/llm.py
CHANGED
@@ -232,111 +232,112 @@ class LLM:
|
|
232
232
|
"""
|
233
233
|
start_time = time.time()
|
234
234
|
log = {"request_time": start_time}
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
task_id = response.id
|
301
|
-
task_status = ""
|
302
|
-
get_cnt = 0
|
303
|
-
cnt_threshold = int(timeout / 0.5)
|
304
|
-
while (
|
305
|
-
task_status != "SUCCESS"
|
306
|
-
and task_status != "FAILED"
|
307
|
-
and get_cnt <= cnt_threshold
|
308
|
-
):
|
309
|
-
result_response = client.chat.asyncCompletions.retrieve_completion_result(id=task_id) # type: ignore
|
310
|
-
task_status = result_response.task_status
|
311
|
-
await asyncio.sleep(0.5)
|
312
|
-
get_cnt += 1
|
313
|
-
if task_status != "SUCCESS":
|
314
|
-
raise Exception(f"Task failed with status: {task_status}")
|
315
|
-
|
316
|
-
self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens # type: ignore
|
317
|
-
self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens # type: ignore
|
318
|
-
self._client_usage[self._current_client_index]["request_number"] += 1
|
319
|
-
end_time = time.time()
|
320
|
-
log["used_time"] = end_time - start_time
|
321
|
-
log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
|
322
|
-
self._log_list.append(log)
|
323
|
-
if tools and result_response.choices[0].message.tool_calls: # type: ignore
|
324
|
-
return json.loads(
|
325
|
-
result_response.choices[0] # type: ignore
|
326
|
-
.message.tool_calls[0]
|
327
|
-
.function.arguments
|
235
|
+
async with self.semaphore:
|
236
|
+
if (
|
237
|
+
self.config.text["request_type"] == "openai"
|
238
|
+
or self.config.text["request_type"] == "deepseek"
|
239
|
+
or self.config.text["request_type"] == "qwen"
|
240
|
+
):
|
241
|
+
for attempt in range(retries):
|
242
|
+
try:
|
243
|
+
client = self._get_next_client()
|
244
|
+
response = await client.chat.completions.create(
|
245
|
+
model=self.config.text["model"],
|
246
|
+
messages=dialog,
|
247
|
+
temperature=temperature,
|
248
|
+
max_tokens=max_tokens,
|
249
|
+
top_p=top_p,
|
250
|
+
frequency_penalty=frequency_penalty, # type: ignore
|
251
|
+
presence_penalty=presence_penalty, # type: ignore
|
252
|
+
stream=False,
|
253
|
+
timeout=timeout,
|
254
|
+
tools=tools,
|
255
|
+
tool_choice=tool_choice,
|
256
|
+
) # type: ignore
|
257
|
+
self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens # type: ignore
|
258
|
+
self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens # type: ignore
|
259
|
+
self._client_usage[self._current_client_index]["request_number"] += 1
|
260
|
+
end_time = time.time()
|
261
|
+
log["consumption"] = end_time - start_time
|
262
|
+
log["input_tokens"] = response.usage.prompt_tokens
|
263
|
+
log["output_tokens"] = response.usage.completion_tokens
|
264
|
+
self._log_list.append(log)
|
265
|
+
if tools and response.choices[0].message.tool_calls:
|
266
|
+
return json.loads(
|
267
|
+
response.choices[0]
|
268
|
+
.message.tool_calls[0]
|
269
|
+
.function.arguments
|
270
|
+
)
|
271
|
+
else:
|
272
|
+
return response.choices[0].message.content
|
273
|
+
except APIConnectionError as e:
|
274
|
+
print("API connection error:", e)
|
275
|
+
if attempt < retries - 1:
|
276
|
+
await asyncio.sleep(2**attempt)
|
277
|
+
else:
|
278
|
+
raise e
|
279
|
+
except OpenAIError as e:
|
280
|
+
if hasattr(e, "http_status"):
|
281
|
+
print(f"HTTP status code: {e.http_status}") # type: ignore
|
282
|
+
else:
|
283
|
+
print("An error occurred:", e)
|
284
|
+
if attempt < retries - 1:
|
285
|
+
await asyncio.sleep(2**attempt)
|
286
|
+
else:
|
287
|
+
raise e
|
288
|
+
elif self.config.text["request_type"] == "zhipuai":
|
289
|
+
for attempt in range(retries):
|
290
|
+
try:
|
291
|
+
client = self._get_next_client()
|
292
|
+
response = client.chat.asyncCompletions.create( # type: ignore
|
293
|
+
model=self.config.text["model"],
|
294
|
+
messages=dialog,
|
295
|
+
temperature=temperature,
|
296
|
+
top_p=top_p,
|
297
|
+
timeout=timeout,
|
298
|
+
tools=tools,
|
299
|
+
tool_choice=tool_choice,
|
328
300
|
)
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
301
|
+
task_id = response.id
|
302
|
+
task_status = ""
|
303
|
+
get_cnt = 0
|
304
|
+
cnt_threshold = int(timeout / 0.5)
|
305
|
+
while (
|
306
|
+
task_status != "SUCCESS"
|
307
|
+
and task_status != "FAILED"
|
308
|
+
and get_cnt <= cnt_threshold
|
309
|
+
):
|
310
|
+
result_response = client.chat.asyncCompletions.retrieve_completion_result(id=task_id) # type: ignore
|
311
|
+
task_status = result_response.task_status
|
312
|
+
await asyncio.sleep(0.5)
|
313
|
+
get_cnt += 1
|
314
|
+
if task_status != "SUCCESS":
|
315
|
+
raise Exception(f"Task failed with status: {task_status}")
|
316
|
+
|
317
|
+
self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens # type: ignore
|
318
|
+
self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens # type: ignore
|
319
|
+
self._client_usage[self._current_client_index]["request_number"] += 1
|
320
|
+
end_time = time.time()
|
321
|
+
log["used_time"] = end_time - start_time
|
322
|
+
log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
|
323
|
+
self._log_list.append(log)
|
324
|
+
if tools and result_response.choices[0].message.tool_calls: # type: ignore
|
325
|
+
return json.loads(
|
326
|
+
result_response.choices[0] # type: ignore
|
327
|
+
.message.tool_calls[0]
|
328
|
+
.function.arguments
|
329
|
+
)
|
330
|
+
else:
|
331
|
+
return result_response.choices[0].message.content # type: ignore
|
332
|
+
except APIConnectionError as e:
|
333
|
+
print("API connection error:", e)
|
334
|
+
if attempt < retries - 1:
|
335
|
+
await asyncio.sleep(2**attempt)
|
336
|
+
else:
|
337
|
+
raise e
|
338
|
+
else:
|
339
|
+
print("ERROR: Wrong Config")
|
340
|
+
return "wrong config"
|
340
341
|
|
341
342
|
async def img_understand(
|
342
343
|
self, img_path: Union[str, list[str]], prompt: Optional[str] = None
|
pycityagent/memory/memory.py
CHANGED
@@ -341,7 +341,7 @@ class StreamMemory:
|
|
341
341
|
"""获取指定ID的记忆"""
|
342
342
|
memories = [memory for memory in self._memories if memory.id in memory_ids]
|
343
343
|
sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
|
344
|
-
return self.format_memory(sorted_results)
|
344
|
+
return await self.format_memory(sorted_results)
|
345
345
|
|
346
346
|
async def search(
|
347
347
|
self,
|
pycityagent/pycityagent-sim
CHANGED
Binary file
|