pycityagent 2.0.0a73__cp310-cp310-macosx_11_0_arm64.whl → 2.0.0a75__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,12 +3,13 @@
3
3
  import asyncio
4
4
  import logging
5
5
  import os
6
- from datetime import datetime, timedelta
7
6
  import time
8
- from typing import Any, Optional, Union, cast
7
+ from datetime import datetime, timedelta
8
+ from typing import Optional, Union, cast
9
9
 
10
+ import ray
10
11
  from mosstool.type import TripMode
11
- from mosstool.util.format_converter import coll2pb
12
+ from mosstool.util.format_converter import coll2pb, dict2pb
12
13
  from pycitydata.map import Map as SimMap
13
14
  from pycityproto.city.map.v2 import map_pb2 as map_pb2
14
15
  from pycityproto.city.person.v2 import person_pb2 as person_pb2
@@ -26,6 +27,51 @@ __all__ = [
26
27
  ]
27
28
 
28
29
 
30
+ @ray.remote
31
+ class CityMap:
32
+ def __init__(self, mongo_input: tuple[str, str, str, str], map_cache_path: str):
33
+ if map_cache_path:
34
+ self.map = SimMap(
35
+ pb_path=map_cache_path,
36
+ )
37
+ else:
38
+ mongo_uri, mongo_db, mongo_coll, cache_dir = mongo_input
39
+ self.map = SimMap(
40
+ mongo_uri=mongo_uri,
41
+ mongo_db=mongo_db,
42
+ mongo_coll=mongo_coll,
43
+ cache_dir=cache_dir,
44
+ )
45
+ self.poi_cate = POI_CATG_DICT
46
+
47
+ def get_aoi(self, aoi_id: Optional[int] = None):
48
+ if aoi_id is None:
49
+ return list(self.map.aois.values())
50
+ else:
51
+ return self.map.aois[aoi_id]
52
+
53
+ def get_poi(self, poi_id: Optional[int] = None):
54
+ if poi_id is None:
55
+ return list(self.map.pois.values())
56
+ else:
57
+ return self.map.pois[poi_id]
58
+
59
+ def query_pois(self, **kwargs):
60
+ return self.map.query_pois(**kwargs)
61
+
62
+ def get_poi_cate(self):
63
+ return self.poi_cate
64
+
65
+ def get_map(self):
66
+ return self.map
67
+
68
+ def get_map_header(self):
69
+ return self.map.header
70
+
71
+ def get_projector(self):
72
+ return self.map.header["projection"]
73
+
74
+
29
75
  class Simulator:
30
76
  """
31
77
  Main class of the simulator.
@@ -35,7 +81,7 @@ class Simulator:
35
81
  - It reads parameters from a configuration dictionary, initializes map data, and starts or connects to a simulation server as needed.
36
82
  """
37
83
 
38
- def __init__(self, config: dict, secure: bool = False) -> None:
84
+ def __init__(self, config: dict, create_map: bool = False) -> None:
39
85
  self.config = config
40
86
  """
41
87
  - 模拟器配置
@@ -66,7 +112,8 @@ class Simulator:
66
112
  self._sim_env = sim_env = ControlSimEnv(
67
113
  task_name=config["simulator"].get("task", "citysim"),
68
114
  map_file=_map_pb_path,
69
- start_step=config["simulator"].get("start_step", 0),
115
+ max_day=config["simulator"].get("max_day", 1000),
116
+ start_step=config["simulator"].get("start_step", 28800),
70
117
  total_step=2147000000,
71
118
  log_dir=config["simulator"].get("log_dir", "./log"),
72
119
  min_step_time=config["simulator"].get("min_step_time", 1000),
@@ -88,16 +135,18 @@ class Simulator:
88
135
  logger.warning(
89
136
  "No simulator config found, no simulator client will be used"
90
137
  )
91
- self.map = SimMap(
92
- mongo_uri=_mongo_uri,
93
- mongo_db=_mongo_db,
94
- mongo_coll=_mongo_coll,
95
- cache_dir=_map_cache_dir,
96
- )
138
+ self._map = None
97
139
  """
98
140
  - 模拟器地图对象
99
141
  - Simulator map object
100
142
  """
143
+ if create_map:
144
+ _map_cache_path = "" # 地图pb文件路径
145
+ self._map = CityMap.remote(
146
+ (_mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir),
147
+ _map_cache_path,
148
+ )
149
+ self._create_poi_id_2_aoi_id()
101
150
 
102
151
  self.time: int = 0
103
152
  """
@@ -109,19 +158,32 @@ class Simulator:
109
158
  self.map_y_gap = None
110
159
  self._bbox: tuple[float, float, float, float] = (-1, -1, -1, -1)
111
160
  self._lock = asyncio.Lock()
112
- # poi id dict
161
+ self._environment_prompt: dict[str, str] = {}
162
+ self._log_list = []
163
+
164
+ def set_map(self, map: ray.ObjectRef):
165
+ self._map = map
166
+ self._create_poi_id_2_aoi_id()
167
+
168
+ def _create_poi_id_2_aoi_id(self):
169
+ pois = ray.get(self._map.get_poi.remote()) # type:ignore
113
170
  self.poi_id_2_aoi_id: dict[int, int] = {
114
- poi["id"]: poi["aoi_id"] for _, poi in self.map.pois.items()
171
+ poi["id"]: poi["aoi_id"] for poi in pois
115
172
  }
116
- self._environment_prompt:dict[str, str] = {}
117
- self._log_list = []
173
+
174
+ @property
175
+ def map(self):
176
+ return self._map
118
177
 
119
178
  def get_log_list(self):
120
179
  return self._log_list
121
-
180
+
122
181
  def clear_log_list(self):
123
182
  self._log_list = []
124
183
 
184
+ def get_poi_cate(self):
185
+ return self.poi_cate
186
+
125
187
  @property
126
188
  def environment(self) -> dict[str, str]:
127
189
  """
@@ -129,6 +191,9 @@ class Simulator:
129
191
  """
130
192
  return self._environment_prompt
131
193
 
194
+ def get_server_addr(self):
195
+ return self.server_addr
196
+
132
197
  def set_environment(self, environment: dict[str, str]):
133
198
  """
134
199
  Set the entire dictionary of environment variables.
@@ -177,11 +242,7 @@ class Simulator:
177
242
  Refer to https://cityproto.sim.fiblab.net/#city.person.1.GetPersonByLongLatBBoxResponse.
178
243
  """
179
244
  start_time = time.time()
180
- log = {
181
- "req": "find_agents_by_area",
182
- "start_time": start_time,
183
- "consumption": 0
184
- }
245
+ log = {"req": "find_agents_by_area", "start_time": start_time, "consumption": 0}
185
246
  loop = asyncio.get_event_loop()
186
247
  resp = loop.run_until_complete(
187
248
  self._client.person_service.GetPersonByLongLatBBox(req=req)
@@ -216,18 +277,16 @@ class Simulator:
216
277
  - `List[str]`: A list of unique POI category names.
217
278
  """
218
279
  start_time = time.time()
219
- log = {
220
- "req": "get_poi_categories",
221
- "start_time": start_time,
222
- "consumption": 0
223
- }
280
+ log = {"req": "get_poi_categories", "start_time": start_time, "consumption": 0}
224
281
  categories: list[str] = []
225
282
  if center is None:
226
283
  center = (0, 0)
227
- _pois: list[dict] = self.map.query_pois( # type:ignore
228
- center=center,
229
- radius=radius,
230
- return_distance=False,
284
+ _pois: list[dict] = ray.get(
285
+ self.map.query_pois.remote( # type:ignore
286
+ center=center,
287
+ radius=radius,
288
+ return_distance=False,
289
+ )
231
290
  )
232
291
  for poi in _pois:
233
292
  catg = poi["category"]
@@ -252,11 +311,7 @@ class Simulator:
252
311
  - `Union[int, str]`: The current simulation time either as an integer representing seconds since midnight or as a formatted string.
253
312
  """
254
313
  start_time = time.time()
255
- log = {
256
- "req": "get_time",
257
- "start_time": start_time,
258
- "consumption": 0
259
- }
314
+ log = {"req": "get_time", "start_time": start_time, "consumption": 0}
260
315
  now = await self._client.clock_service.Now({})
261
316
  now = cast(dict[str, int], now)
262
317
  self.time = now["t"]
@@ -280,11 +335,7 @@ class Simulator:
280
335
  This method sends a request to the simulator's pause service to pause the simulation.
281
336
  """
282
337
  start_time = time.time()
283
- log = {
284
- "req": "pause",
285
- "start_time": start_time,
286
- "consumption": 0
287
- }
338
+ log = {"req": "pause", "start_time": start_time, "consumption": 0}
288
339
  await self._client.pause_service.pause()
289
340
  log["consumption"] = time.time() - start_time
290
341
  self._log_list.append(log)
@@ -296,11 +347,7 @@ class Simulator:
296
347
  This method sends a request to the simulator's pause service to resume the simulation.
297
348
  """
298
349
  start_time = time.time()
299
- log = {
300
- "req": "resume",
301
- "start_time": start_time,
302
- "consumption": 0
303
- }
350
+ log = {"req": "resume", "start_time": start_time, "consumption": 0}
304
351
  await self._client.pause_service.resume()
305
352
  log["consumption"] = time.time() - start_time
306
353
  self._log_list.append(log)
@@ -313,11 +360,7 @@ class Simulator:
313
360
  - `int`: The day number since the start of the simulation.
314
361
  """
315
362
  start_time = time.time()
316
- log = {
317
- "req": "get_simulator_day",
318
- "start_time": start_time,
319
- "consumption": 0
320
- }
363
+ log = {"req": "get_simulator_day", "start_time": start_time, "consumption": 0}
321
364
  now = await self._client.clock_service.Now({})
322
365
  now = cast(dict[str, int], now)
323
366
  day = now["day"]
@@ -336,7 +379,7 @@ class Simulator:
336
379
  log = {
337
380
  "req": "get_simulator_second_from_start_of_day",
338
381
  "start_time": start_time,
339
- "consumption": 0
382
+ "consumption": 0,
340
383
  }
341
384
  now = await self._client.clock_service.Now({})
342
385
  now = cast(dict[str, int], now)
@@ -355,43 +398,35 @@ class Simulator:
355
398
  - `Dict`: Information about the specified person.
356
399
  """
357
400
  start_time = time.time()
358
- log = {
359
- "req": "get_person",
360
- "start_time": start_time,
361
- "consumption": 0
362
- }
363
- person = await self._client.person_service.GetPerson(
401
+ log = {"req": "get_person", "start_time": start_time, "consumption": 0}
402
+ person: dict = await self._client.person_service.GetPerson(
364
403
  req={"person_id": person_id}
365
404
  ) # type:ignore
366
405
  log["consumption"] = time.time() - start_time
367
406
  self._log_list.append(log)
368
407
  return person
369
408
 
370
- async def add_person(self, person: Any) -> dict:
409
+ async def add_person(self, dict_person: dict) -> dict:
371
410
  """
372
411
  Add a new person to the simulation.
373
412
 
374
413
  - **Args**:
375
- - `person` (`Any`): The person object to add. If it's an instance of `person_pb2.Person`,
376
- it will be wrapped in an `AddPersonRequest`. Otherwise, `person` is expected to already be a valid request object.
414
+ - `dict_person` (`dict`): The person object to add.
377
415
 
378
416
  - **Returns**:
379
417
  - `Dict`: Response from adding the person.
380
418
  """
381
419
  start_time = time.time()
382
- log = {
383
- "req": "add_person",
384
- "start_time": start_time,
385
- "consumption": 0
386
- }
420
+ log = {"req": "add_person", "start_time": start_time, "consumption": 0}
421
+ person = dict2pb(dict_person, person_pb2.Person())
387
422
  if isinstance(person, person_pb2.Person):
388
423
  req = person_service.AddPersonRequest(person=person)
389
424
  else:
390
425
  req = person
391
- person_id = await self._client.person_service.AddPerson(req) # type:ignore
426
+ resp: dict = await self._client.person_service.AddPerson(req) # type:ignore
392
427
  log["consumption"] = time.time() - start_time
393
428
  self._log_list.append(log)
394
- return person_id
429
+ return resp
395
430
 
396
431
  async def set_aoi_schedules(
397
432
  self,
@@ -411,15 +446,11 @@ class Simulator:
411
446
  A list of AOI or POI IDs or tuples of (AOI ID, POI ID) that the person will visit.
412
447
  - `departure_times` (`Optional[List[float]]`): Departure times for each trip in the schedule.
413
448
  If not provided, current time will be used for all trips.
414
- - `modes` (`Optional[List[TripMode]]`): Travel modes for each trip.
449
+ - `modes` (`Optional[List[int]]`): Travel modes for each trip.
415
450
  Defaults to `TRIP_MODE_DRIVE_ONLY` if not specified.
416
451
  """
417
452
  start_time = time.time()
418
- log = {
419
- "req": "set_aoi_schedules",
420
- "start_time": start_time,
421
- "consumption": 0
422
- }
453
+ log = {"req": "set_aoi_schedules", "start_time": start_time, "consumption": 0}
423
454
  cur_time = float(await self.get_time())
424
455
  if not isinstance(target_positions, list):
425
456
  target_positions = [target_positions]
@@ -494,7 +525,7 @@ class Simulator:
494
525
  log = {
495
526
  "req": "reset_person_position",
496
527
  "start_time": start_time,
497
- "consumption": 0
528
+ "consumption": 0,
498
529
  }
499
530
  reset_position = {}
500
531
  if aoi_id is not None:
@@ -545,11 +576,7 @@ class Simulator:
545
576
  - `List[Dict]`: A list of dictionaries containing information about the POIs found.
546
577
  """
547
578
  start_time = time.time()
548
- log = {
549
- "req": "get_around_poi",
550
- "start_time": start_time,
551
- "consumption": 0
552
- }
579
+ log = {"req": "get_around_poi", "start_time": start_time, "consumption": 0}
553
580
  if isinstance(poi_type, str):
554
581
  poi_type = [poi_type]
555
582
  transformed_poi_type: list[str] = []
@@ -560,10 +587,12 @@ class Simulator:
560
587
  transformed_poi_type += self.poi_cate[t]
561
588
  poi_type_set = set(transformed_poi_type)
562
589
  # 获取半径内的poi
563
- _pois: list[dict] = self.map.query_pois( # type:ignore
564
- center=center,
565
- radius=radius,
566
- return_distance=False,
590
+ _pois: list[dict] = ray.get(
591
+ self.map.query_pois.remote( # type:ignore
592
+ center=center,
593
+ radius=radius,
594
+ return_distance=False,
595
+ )
567
596
  )
568
597
  # 过滤掉不满足类别前缀的poi
569
598
  pois = []
@@ -574,4 +603,4 @@ class Simulator:
574
603
  pois.append(poi)
575
604
  log["consumption"] = time.time() - start_time
576
605
  self._log_list.append(log)
577
- return pois
606
+ return pois
pycityagent/llm/llm.py CHANGED
@@ -232,111 +232,112 @@ class LLM:
232
232
  """
233
233
  start_time = time.time()
234
234
  log = {"request_time": start_time}
235
- if (
236
- self.config.text["request_type"] == "openai"
237
- or self.config.text["request_type"] == "deepseek"
238
- or self.config.text["request_type"] == "qwen"
239
- ):
240
- for attempt in range(retries):
241
- try:
242
- client = self._get_next_client()
243
- response = await client.chat.completions.create(
244
- model=self.config.text["model"],
245
- messages=dialog,
246
- temperature=temperature,
247
- max_tokens=max_tokens,
248
- top_p=top_p,
249
- frequency_penalty=frequency_penalty, # type: ignore
250
- presence_penalty=presence_penalty, # type: ignore
251
- stream=False,
252
- timeout=timeout,
253
- tools=tools,
254
- tool_choice=tool_choice,
255
- ) # type: ignore
256
- self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens # type: ignore
257
- self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens # type: ignore
258
- self._client_usage[self._current_client_index]["request_number"] += 1
259
- end_time = time.time()
260
- log["consumption"] = end_time - start_time
261
- log["input_tokens"] = response.usage.prompt_tokens
262
- log["output_tokens"] = response.usage.completion_tokens
263
- self._log_list.append(log)
264
- if tools and response.choices[0].message.tool_calls:
265
- return json.loads(
266
- response.choices[0]
267
- .message.tool_calls[0]
268
- .function.arguments
269
- )
270
- else:
271
- return response.choices[0].message.content
272
- except APIConnectionError as e:
273
- print("API connection error:", e)
274
- if attempt < retries - 1:
275
- await asyncio.sleep(2**attempt)
276
- else:
277
- raise e
278
- except OpenAIError as e:
279
- if hasattr(e, "http_status"):
280
- print(f"HTTP status code: {e.http_status}") # type: ignore
281
- else:
282
- print("An error occurred:", e)
283
- if attempt < retries - 1:
284
- await asyncio.sleep(2**attempt)
285
- else:
286
- raise e
287
- elif self.config.text["request_type"] == "zhipuai":
288
- for attempt in range(retries):
289
- try:
290
- client = self._get_next_client()
291
- response = client.chat.asyncCompletions.create( # type: ignore
292
- model=self.config.text["model"],
293
- messages=dialog,
294
- temperature=temperature,
295
- top_p=top_p,
296
- timeout=timeout,
297
- tools=tools,
298
- tool_choice=tool_choice,
299
- )
300
- task_id = response.id
301
- task_status = ""
302
- get_cnt = 0
303
- cnt_threshold = int(timeout / 0.5)
304
- while (
305
- task_status != "SUCCESS"
306
- and task_status != "FAILED"
307
- and get_cnt <= cnt_threshold
308
- ):
309
- result_response = client.chat.asyncCompletions.retrieve_completion_result(id=task_id) # type: ignore
310
- task_status = result_response.task_status
311
- await asyncio.sleep(0.5)
312
- get_cnt += 1
313
- if task_status != "SUCCESS":
314
- raise Exception(f"Task failed with status: {task_status}")
315
-
316
- self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens # type: ignore
317
- self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens # type: ignore
318
- self._client_usage[self._current_client_index]["request_number"] += 1
319
- end_time = time.time()
320
- log["used_time"] = end_time - start_time
321
- log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
322
- self._log_list.append(log)
323
- if tools and result_response.choices[0].message.tool_calls: # type: ignore
324
- return json.loads(
325
- result_response.choices[0] # type: ignore
326
- .message.tool_calls[0]
327
- .function.arguments
235
+ async with self.semaphore:
236
+ if (
237
+ self.config.text["request_type"] == "openai"
238
+ or self.config.text["request_type"] == "deepseek"
239
+ or self.config.text["request_type"] == "qwen"
240
+ ):
241
+ for attempt in range(retries):
242
+ try:
243
+ client = self._get_next_client()
244
+ response = await client.chat.completions.create(
245
+ model=self.config.text["model"],
246
+ messages=dialog,
247
+ temperature=temperature,
248
+ max_tokens=max_tokens,
249
+ top_p=top_p,
250
+ frequency_penalty=frequency_penalty, # type: ignore
251
+ presence_penalty=presence_penalty, # type: ignore
252
+ stream=False,
253
+ timeout=timeout,
254
+ tools=tools,
255
+ tool_choice=tool_choice,
256
+ ) # type: ignore
257
+ self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens # type: ignore
258
+ self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens # type: ignore
259
+ self._client_usage[self._current_client_index]["request_number"] += 1
260
+ end_time = time.time()
261
+ log["consumption"] = end_time - start_time
262
+ log["input_tokens"] = response.usage.prompt_tokens
263
+ log["output_tokens"] = response.usage.completion_tokens
264
+ self._log_list.append(log)
265
+ if tools and response.choices[0].message.tool_calls:
266
+ return json.loads(
267
+ response.choices[0]
268
+ .message.tool_calls[0]
269
+ .function.arguments
270
+ )
271
+ else:
272
+ return response.choices[0].message.content
273
+ except APIConnectionError as e:
274
+ print("API connection error:", e)
275
+ if attempt < retries - 1:
276
+ await asyncio.sleep(2**attempt)
277
+ else:
278
+ raise e
279
+ except OpenAIError as e:
280
+ if hasattr(e, "http_status"):
281
+ print(f"HTTP status code: {e.http_status}") # type: ignore
282
+ else:
283
+ print("An error occurred:", e)
284
+ if attempt < retries - 1:
285
+ await asyncio.sleep(2**attempt)
286
+ else:
287
+ raise e
288
+ elif self.config.text["request_type"] == "zhipuai":
289
+ for attempt in range(retries):
290
+ try:
291
+ client = self._get_next_client()
292
+ response = client.chat.asyncCompletions.create( # type: ignore
293
+ model=self.config.text["model"],
294
+ messages=dialog,
295
+ temperature=temperature,
296
+ top_p=top_p,
297
+ timeout=timeout,
298
+ tools=tools,
299
+ tool_choice=tool_choice,
328
300
  )
329
- else:
330
- return result_response.choices[0].message.content # type: ignore
331
- except APIConnectionError as e:
332
- print("API connection error:", e)
333
- if attempt < retries - 1:
334
- await asyncio.sleep(2**attempt)
335
- else:
336
- raise e
337
- else:
338
- print("ERROR: Wrong Config")
339
- return "wrong config"
301
+ task_id = response.id
302
+ task_status = ""
303
+ get_cnt = 0
304
+ cnt_threshold = int(timeout / 0.5)
305
+ while (
306
+ task_status != "SUCCESS"
307
+ and task_status != "FAILED"
308
+ and get_cnt <= cnt_threshold
309
+ ):
310
+ result_response = client.chat.asyncCompletions.retrieve_completion_result(id=task_id) # type: ignore
311
+ task_status = result_response.task_status
312
+ await asyncio.sleep(0.5)
313
+ get_cnt += 1
314
+ if task_status != "SUCCESS":
315
+ raise Exception(f"Task failed with status: {task_status}")
316
+
317
+ self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens # type: ignore
318
+ self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens # type: ignore
319
+ self._client_usage[self._current_client_index]["request_number"] += 1
320
+ end_time = time.time()
321
+ log["used_time"] = end_time - start_time
322
+ log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
323
+ self._log_list.append(log)
324
+ if tools and result_response.choices[0].message.tool_calls: # type: ignore
325
+ return json.loads(
326
+ result_response.choices[0] # type: ignore
327
+ .message.tool_calls[0]
328
+ .function.arguments
329
+ )
330
+ else:
331
+ return result_response.choices[0].message.content # type: ignore
332
+ except APIConnectionError as e:
333
+ print("API connection error:", e)
334
+ if attempt < retries - 1:
335
+ await asyncio.sleep(2**attempt)
336
+ else:
337
+ raise e
338
+ else:
339
+ print("ERROR: Wrong Config")
340
+ return "wrong config"
340
341
 
341
342
  async def img_understand(
342
343
  self, img_path: Union[str, list[str]], prompt: Optional[str] = None
@@ -341,7 +341,7 @@ class StreamMemory:
341
341
  """获取指定ID的记忆"""
342
342
  memories = [memory for memory in self._memories if memory.id in memory_ids]
343
343
  sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
344
- return self.format_memory(sorted_results)
344
+ return await self.format_memory(sorted_results)
345
345
 
346
346
  async def search(
347
347
  self,
Binary file