pycityagent 2.0.0a73__cp312-cp312-macosx_11_0_arm64.whl → 2.0.0a74__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,16 +5,17 @@ import logging
5
5
  import os
6
6
  from datetime import datetime, timedelta
7
7
  import time
8
- from typing import Any, Optional, Union, cast
8
+ from typing import Optional, Union, cast
9
9
 
10
- from mosstool.type import TripMode
11
- from mosstool.util.format_converter import coll2pb
10
+ from mosstool.util.format_converter import coll2pb, dict2pb
12
11
  from pycitydata.map import Map as SimMap
13
12
  from pycityproto.city.map.v2 import map_pb2 as map_pb2
14
13
  from pycityproto.city.person.v2 import person_pb2 as person_pb2
15
14
  from pycityproto.city.person.v2 import person_service_pb2 as person_service
16
15
  from pymongo import MongoClient
16
+ import ray
17
17
  from shapely.geometry import Point
18
+ from mosstool.type import TripMode
18
19
 
19
20
  from .sim import CityClient, ControlSimEnv
20
21
  from .utils.const import *
@@ -25,6 +26,42 @@ __all__ = [
25
26
  "Simulator",
26
27
  ]
27
28
 
29
+ @ray.remote
30
+ class CityMap:
31
+ def __init__(self, mongo_uri: str, mongo_db: str, mongo_coll: str, cache_dir: str):
32
+ self.map = SimMap(
33
+ mongo_uri=mongo_uri,
34
+ mongo_db=mongo_db,
35
+ mongo_coll=mongo_coll,
36
+ cache_dir=cache_dir,
37
+ )
38
+
39
+ def get_aoi(self, aoi_id: Optional[int] = None):
40
+ if aoi_id is None:
41
+ return list(self.map.aois.values())
42
+ else:
43
+ return self.map.aois[aoi_id]
44
+
45
+ def get_poi(self, poi_id: Optional[int] = None):
46
+ if poi_id is None:
47
+ return list(self.map.pois.values())
48
+ else:
49
+ return self.map.pois[poi_id]
50
+
51
+ def query_pois(self, **kwargs):
52
+ return self.map.query_pois(**kwargs)
53
+
54
+ def get_poi_cate(self):
55
+ return self.poi_cate
56
+
57
+ def get_map(self):
58
+ return self.map
59
+
60
+ def get_map_header(self):
61
+ return self.map.header
62
+
63
+ def get_projector(self):
64
+ return self.map.header["projection"]
28
65
 
29
66
  class Simulator:
30
67
  """
@@ -35,7 +72,7 @@ class Simulator:
35
72
  - It reads parameters from a configuration dictionary, initializes map data, and starts or connects to a simulation server as needed.
36
73
  """
37
74
 
38
- def __init__(self, config: dict, secure: bool = False) -> None:
75
+ def __init__(self, config: dict, create_map: bool = False) -> None:
39
76
  self.config = config
40
77
  """
41
78
  - 模拟器配置
@@ -66,7 +103,8 @@ class Simulator:
66
103
  self._sim_env = sim_env = ControlSimEnv(
67
104
  task_name=config["simulator"].get("task", "citysim"),
68
105
  map_file=_map_pb_path,
69
- start_step=config["simulator"].get("start_step", 0),
106
+ max_day=config["simulator"].get("max_day", 1000),
107
+ start_step=config["simulator"].get("start_step", 28800),
70
108
  total_step=2147000000,
71
109
  log_dir=config["simulator"].get("log_dir", "./log"),
72
110
  min_step_time=config["simulator"].get("min_step_time", 1000),
@@ -88,16 +126,14 @@ class Simulator:
88
126
  logger.warning(
89
127
  "No simulator config found, no simulator client will be used"
90
128
  )
91
- self.map = SimMap(
92
- mongo_uri=_mongo_uri,
93
- mongo_db=_mongo_db,
94
- mongo_coll=_mongo_coll,
95
- cache_dir=_map_cache_dir,
96
- )
129
+ self._map = None
97
130
  """
98
131
  - 模拟器地图对象
99
132
  - Simulator map object
100
133
  """
134
+ if create_map:
135
+ self._map = CityMap.remote(_mongo_uri, _mongo_db, _mongo_coll, _map_cache_dir)
136
+ self._create_poi_id_2_aoi_id()
101
137
 
102
138
  self.time: int = 0
103
139
  """
@@ -109,25 +145,41 @@ class Simulator:
109
145
  self.map_y_gap = None
110
146
  self._bbox: tuple[float, float, float, float] = (-1, -1, -1, -1)
111
147
  self._lock = asyncio.Lock()
112
- # poi id dict
113
- self.poi_id_2_aoi_id: dict[int, int] = {
114
- poi["id"]: poi["aoi_id"] for _, poi in self.map.pois.items()
115
- }
116
148
  self._environment_prompt:dict[str, str] = {}
117
149
  self._log_list = []
118
150
 
151
+ def set_map(self, map: CityMap):
152
+ self._map = map
153
+ self._create_poi_id_2_aoi_id()
154
+
155
+ def _create_poi_id_2_aoi_id(self):
156
+ pois = ray.get(self._map.get_poi.remote())
157
+ self.poi_id_2_aoi_id: dict[int, int] = {
158
+ poi["id"]: poi["aoi_id"] for poi in pois
159
+ }
160
+
161
+ @property
162
+ def map(self):
163
+ return self._map
164
+
119
165
  def get_log_list(self):
120
166
  return self._log_list
121
167
 
122
168
  def clear_log_list(self):
123
169
  self._log_list = []
124
170
 
171
+ def get_poi_cate(self):
172
+ return self.poi_cate
173
+
125
174
  @property
126
175
  def environment(self) -> dict[str, str]:
127
176
  """
128
177
  Get the current state of environment variables.
129
178
  """
130
179
  return self._environment_prompt
180
+
181
+ def get_server_addr(self):
182
+ return self.server_addr
131
183
 
132
184
  def set_environment(self, environment: dict[str, str]):
133
185
  """
@@ -224,11 +276,11 @@ class Simulator:
224
276
  categories: list[str] = []
225
277
  if center is None:
226
278
  center = (0, 0)
227
- _pois: list[dict] = self.map.query_pois( # type:ignore
279
+ _pois: list[dict] = ray.get(self.map.query_pois.remote( # type:ignore
228
280
  center=center,
229
281
  radius=radius,
230
282
  return_distance=False,
231
- )
283
+ ))
232
284
  for poi in _pois:
233
285
  catg = poi["category"]
234
286
  categories.append(catg.split("|")[-1])
@@ -367,13 +419,12 @@ class Simulator:
367
419
  self._log_list.append(log)
368
420
  return person
369
421
 
370
- async def add_person(self, person: Any) -> dict:
422
+ async def add_person(self, dict_person: dict) -> dict:
371
423
  """
372
424
  Add a new person to the simulation.
373
425
 
374
426
  - **Args**:
375
- - `person` (`Any`): The person object to add. If it's an instance of `person_pb2.Person`,
376
- it will be wrapped in an `AddPersonRequest`. Otherwise, `person` is expected to already be a valid request object.
427
+ - `dict_person` (`dict`): The person object to add.
377
428
 
378
429
  - **Returns**:
379
430
  - `Dict`: Response from adding the person.
@@ -384,6 +435,7 @@ class Simulator:
384
435
  "start_time": start_time,
385
436
  "consumption": 0
386
437
  }
438
+ person = dict2pb(dict_person, person_pb2.Person())
387
439
  if isinstance(person, person_pb2.Person):
388
440
  req = person_service.AddPersonRequest(person=person)
389
441
  else:
@@ -411,7 +463,7 @@ class Simulator:
411
463
  A list of AOI or POI IDs or tuples of (AOI ID, POI ID) that the person will visit.
412
464
  - `departure_times` (`Optional[List[float]]`): Departure times for each trip in the schedule.
413
465
  If not provided, current time will be used for all trips.
414
- - `modes` (`Optional[List[TripMode]]`): Travel modes for each trip.
466
+ - `modes` (`Optional[List[int]]`): Travel modes for each trip.
415
467
  Defaults to `TRIP_MODE_DRIVE_ONLY` if not specified.
416
468
  """
417
469
  start_time = time.time()
@@ -560,11 +612,11 @@ class Simulator:
560
612
  transformed_poi_type += self.poi_cate[t]
561
613
  poi_type_set = set(transformed_poi_type)
562
614
  # 获取半径内的poi
563
- _pois: list[dict] = self.map.query_pois( # type:ignore
615
+ _pois: list[dict] = ray.get(self.map.query_pois.remote( # type:ignore
564
616
  center=center,
565
617
  radius=radius,
566
618
  return_distance=False,
567
- )
619
+ ))
568
620
  # 过滤掉不满足类别前缀的poi
569
621
  pois = []
570
622
  for poi in _pois:
pycityagent/llm/llm.py CHANGED
@@ -232,111 +232,112 @@ class LLM:
232
232
  """
233
233
  start_time = time.time()
234
234
  log = {"request_time": start_time}
235
- if (
236
- self.config.text["request_type"] == "openai"
237
- or self.config.text["request_type"] == "deepseek"
238
- or self.config.text["request_type"] == "qwen"
239
- ):
240
- for attempt in range(retries):
241
- try:
242
- client = self._get_next_client()
243
- response = await client.chat.completions.create(
244
- model=self.config.text["model"],
245
- messages=dialog,
246
- temperature=temperature,
247
- max_tokens=max_tokens,
248
- top_p=top_p,
249
- frequency_penalty=frequency_penalty, # type: ignore
250
- presence_penalty=presence_penalty, # type: ignore
251
- stream=False,
252
- timeout=timeout,
253
- tools=tools,
254
- tool_choice=tool_choice,
255
- ) # type: ignore
256
- self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens # type: ignore
257
- self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens # type: ignore
258
- self._client_usage[self._current_client_index]["request_number"] += 1
259
- end_time = time.time()
260
- log["consumption"] = end_time - start_time
261
- log["input_tokens"] = response.usage.prompt_tokens
262
- log["output_tokens"] = response.usage.completion_tokens
263
- self._log_list.append(log)
264
- if tools and response.choices[0].message.tool_calls:
265
- return json.loads(
266
- response.choices[0]
267
- .message.tool_calls[0]
268
- .function.arguments
269
- )
270
- else:
271
- return response.choices[0].message.content
272
- except APIConnectionError as e:
273
- print("API connection error:", e)
274
- if attempt < retries - 1:
275
- await asyncio.sleep(2**attempt)
276
- else:
277
- raise e
278
- except OpenAIError as e:
279
- if hasattr(e, "http_status"):
280
- print(f"HTTP status code: {e.http_status}") # type: ignore
281
- else:
282
- print("An error occurred:", e)
283
- if attempt < retries - 1:
284
- await asyncio.sleep(2**attempt)
285
- else:
286
- raise e
287
- elif self.config.text["request_type"] == "zhipuai":
288
- for attempt in range(retries):
289
- try:
290
- client = self._get_next_client()
291
- response = client.chat.asyncCompletions.create( # type: ignore
292
- model=self.config.text["model"],
293
- messages=dialog,
294
- temperature=temperature,
295
- top_p=top_p,
296
- timeout=timeout,
297
- tools=tools,
298
- tool_choice=tool_choice,
299
- )
300
- task_id = response.id
301
- task_status = ""
302
- get_cnt = 0
303
- cnt_threshold = int(timeout / 0.5)
304
- while (
305
- task_status != "SUCCESS"
306
- and task_status != "FAILED"
307
- and get_cnt <= cnt_threshold
308
- ):
309
- result_response = client.chat.asyncCompletions.retrieve_completion_result(id=task_id) # type: ignore
310
- task_status = result_response.task_status
311
- await asyncio.sleep(0.5)
312
- get_cnt += 1
313
- if task_status != "SUCCESS":
314
- raise Exception(f"Task failed with status: {task_status}")
315
-
316
- self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens # type: ignore
317
- self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens # type: ignore
318
- self._client_usage[self._current_client_index]["request_number"] += 1
319
- end_time = time.time()
320
- log["used_time"] = end_time - start_time
321
- log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
322
- self._log_list.append(log)
323
- if tools and result_response.choices[0].message.tool_calls: # type: ignore
324
- return json.loads(
325
- result_response.choices[0] # type: ignore
326
- .message.tool_calls[0]
327
- .function.arguments
235
+ async with self.semaphore:
236
+ if (
237
+ self.config.text["request_type"] == "openai"
238
+ or self.config.text["request_type"] == "deepseek"
239
+ or self.config.text["request_type"] == "qwen"
240
+ ):
241
+ for attempt in range(retries):
242
+ try:
243
+ client = self._get_next_client()
244
+ response = await client.chat.completions.create(
245
+ model=self.config.text["model"],
246
+ messages=dialog,
247
+ temperature=temperature,
248
+ max_tokens=max_tokens,
249
+ top_p=top_p,
250
+ frequency_penalty=frequency_penalty, # type: ignore
251
+ presence_penalty=presence_penalty, # type: ignore
252
+ stream=False,
253
+ timeout=timeout,
254
+ tools=tools,
255
+ tool_choice=tool_choice,
256
+ ) # type: ignore
257
+ self._client_usage[self._current_client_index]["prompt_tokens"] += response.usage.prompt_tokens # type: ignore
258
+ self._client_usage[self._current_client_index]["completion_tokens"] += response.usage.completion_tokens # type: ignore
259
+ self._client_usage[self._current_client_index]["request_number"] += 1
260
+ end_time = time.time()
261
+ log["consumption"] = end_time - start_time
262
+ log["input_tokens"] = response.usage.prompt_tokens
263
+ log["output_tokens"] = response.usage.completion_tokens
264
+ self._log_list.append(log)
265
+ if tools and response.choices[0].message.tool_calls:
266
+ return json.loads(
267
+ response.choices[0]
268
+ .message.tool_calls[0]
269
+ .function.arguments
270
+ )
271
+ else:
272
+ return response.choices[0].message.content
273
+ except APIConnectionError as e:
274
+ print("API connection error:", e)
275
+ if attempt < retries - 1:
276
+ await asyncio.sleep(2**attempt)
277
+ else:
278
+ raise e
279
+ except OpenAIError as e:
280
+ if hasattr(e, "http_status"):
281
+ print(f"HTTP status code: {e.http_status}") # type: ignore
282
+ else:
283
+ print("An error occurred:", e)
284
+ if attempt < retries - 1:
285
+ await asyncio.sleep(2**attempt)
286
+ else:
287
+ raise e
288
+ elif self.config.text["request_type"] == "zhipuai":
289
+ for attempt in range(retries):
290
+ try:
291
+ client = self._get_next_client()
292
+ response = client.chat.asyncCompletions.create( # type: ignore
293
+ model=self.config.text["model"],
294
+ messages=dialog,
295
+ temperature=temperature,
296
+ top_p=top_p,
297
+ timeout=timeout,
298
+ tools=tools,
299
+ tool_choice=tool_choice,
328
300
  )
329
- else:
330
- return result_response.choices[0].message.content # type: ignore
331
- except APIConnectionError as e:
332
- print("API connection error:", e)
333
- if attempt < retries - 1:
334
- await asyncio.sleep(2**attempt)
335
- else:
336
- raise e
337
- else:
338
- print("ERROR: Wrong Config")
339
- return "wrong config"
301
+ task_id = response.id
302
+ task_status = ""
303
+ get_cnt = 0
304
+ cnt_threshold = int(timeout / 0.5)
305
+ while (
306
+ task_status != "SUCCESS"
307
+ and task_status != "FAILED"
308
+ and get_cnt <= cnt_threshold
309
+ ):
310
+ result_response = client.chat.asyncCompletions.retrieve_completion_result(id=task_id) # type: ignore
311
+ task_status = result_response.task_status
312
+ await asyncio.sleep(0.5)
313
+ get_cnt += 1
314
+ if task_status != "SUCCESS":
315
+ raise Exception(f"Task failed with status: {task_status}")
316
+
317
+ self._client_usage[self._current_client_index]["prompt_tokens"] += result_response.usage.prompt_tokens # type: ignore
318
+ self._client_usage[self._current_client_index]["completion_tokens"] += result_response.usage.completion_tokens # type: ignore
319
+ self._client_usage[self._current_client_index]["request_number"] += 1
320
+ end_time = time.time()
321
+ log["used_time"] = end_time - start_time
322
+ log["token_consumption"] = result_response.usage.prompt_tokens + result_response.usage.completion_tokens
323
+ self._log_list.append(log)
324
+ if tools and result_response.choices[0].message.tool_calls: # type: ignore
325
+ return json.loads(
326
+ result_response.choices[0] # type: ignore
327
+ .message.tool_calls[0]
328
+ .function.arguments
329
+ )
330
+ else:
331
+ return result_response.choices[0].message.content # type: ignore
332
+ except APIConnectionError as e:
333
+ print("API connection error:", e)
334
+ if attempt < retries - 1:
335
+ await asyncio.sleep(2**attempt)
336
+ else:
337
+ raise e
338
+ else:
339
+ print("ERROR: Wrong Config")
340
+ return "wrong config"
340
341
 
341
342
  async def img_understand(
342
343
  self, img_path: Union[str, list[str]], prompt: Optional[str] = None
@@ -341,7 +341,7 @@ class StreamMemory:
341
341
  """获取指定ID的记忆"""
342
342
  memories = [memory for memory in self._memories if memory.id in memory_ids]
343
343
  sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
344
- return self.format_memory(sorted_results)
344
+ return await self.format_memory(sorted_results)
345
345
 
346
346
  async def search(
347
347
  self,
@@ -36,6 +36,7 @@ class AgentGroup:
36
36
  number_of_agents: Union[int, list[int]],
37
37
  memory_config_function_group: dict[type[Agent], Callable],
38
38
  config: dict,
39
+ map_ref: ray.ObjectRef,
39
40
  exp_name: str,
40
41
  exp_id: Union[str, UUID],
41
42
  enable_avro: bool,
@@ -47,8 +48,8 @@ class AgentGroup:
47
48
  embedding_model: Embeddings,
48
49
  logging_level: int,
49
50
  agent_config_file: Optional[dict[type[Agent], str]] = None,
50
- environment: Optional[dict[str, str]] = None,
51
51
  llm_semaphore: int = 200,
52
+ environment: Optional[dict] = None,
52
53
  ):
53
54
  """
54
55
  Represents a group of agents that can be deployed in a Ray distributed environment.
@@ -64,6 +65,7 @@ class AgentGroup:
64
65
  - `number_of_agents` (Union[int, List[int]]): Number of instances to create for each agent class.
65
66
  - `memory_config_function_group` (Dict[Type[Agent], Callable]): Functions to configure memory for each agent type.
66
67
  - `config` (dict): Configuration settings for the agent group.
68
+ - `map_ref` (ray.ObjectRef): Reference to the map object.
67
69
  - `exp_name` (str): Name of the experiment.
68
70
  - `exp_id` (str | UUID): Identifier for the experiment.
69
71
  - `enable_avro` (bool): Flag to enable AVRO file support.
@@ -143,10 +145,10 @@ class AgentGroup:
143
145
  self.llm.set_semaphore(llm_semaphore)
144
146
 
145
147
  # prepare Simulator
146
- logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
148
+ logger.info(f"-----Initializing Simulator in AgentGroup {self._uuid} ...")
147
149
  self.simulator = Simulator(config["simulator_request"])
148
- self.projector = pyproj.Proj(self.simulator.map.header["projection"])
149
- self.simulator.set_environment(environment) # type:ignore
150
+ self.simulator.set_map(map_ref)
151
+ self.projector = pyproj.Proj(ray.get(self.simulator.map.get_projector.remote()))
150
152
  # prepare Economy client
151
153
  logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
152
154
  self.economy_client = EconomyClient(
@@ -430,16 +432,6 @@ class AgentGroup:
430
432
  agent = self.id2agent[target_agent_uuid]
431
433
  await agent.status.update(target_key, content)
432
434
 
433
- async def update_environment(self, key: str, value: str):
434
- """
435
- Updates the environment with a given key-value pair.
436
-
437
- - **Args**:
438
- - `key` (str): The key to update in the environment.
439
- - `value` (str): The value to set for the specified key.
440
- """
441
- self.simulator.update_environment(key, value)
442
-
443
435
  async def message_dispatch(self):
444
436
  """
445
437
  Dispatches messages received via MQTT to the appropriate agents.
@@ -788,12 +780,13 @@ class AgentGroup:
788
780
  """
789
781
  try:
790
782
  tasks = [agent.run() for agent in self.agents]
791
- await asyncio.gather(*tasks)
783
+ agent_time_log =await asyncio.gather(*tasks)
792
784
  simulator_log = self.simulator.get_log_list() + self.economy_client.get_log_list()
793
785
  group_logs = {
794
786
  "llm_log": self.llm.get_log_list(),
795
787
  "mqtt_log": ray.get(self.messager.get_log_list.remote()),
796
- "simulator_log": simulator_log
788
+ "simulator_log": simulator_log,
789
+ "agent_time_log": agent_time_log
797
790
  }
798
791
  self.llm.clear_log_list()
799
792
  self.messager.clear_log_list.remote()
@@ -130,12 +130,16 @@ class AgentSimulation:
130
130
  _simulator_config = config["simulator_request"].get("simulator", {})
131
131
  if "server" in _simulator_config:
132
132
  raise ValueError(f"Passing Traffic Simulation address is not supported!")
133
- self._simulator = Simulator(config["simulator_request"])
133
+ simulator = Simulator(config["simulator_request"], create_map=True)
134
+ self._simulator = simulator
135
+ self._map_ref = self._simulator.map
136
+ server_addr = self._simulator.get_server_addr()
137
+ config["simulator_request"]["simulator"]["server"] = server_addr
134
138
  self._economy_client = EconomyClient(
135
139
  config["simulator_request"]["simulator"]["server"]
136
140
  )
137
141
  if enable_institution:
138
- self._economy_addr = economy_addr = self._simulator.server_addr
142
+ self._economy_addr = economy_addr = server_addr
139
143
  if economy_addr is None:
140
144
  raise ValueError(
141
145
  f"`simulator` not provided in `simulator_request`, thus unable to activate economy!"
@@ -382,6 +386,7 @@ class AgentSimulation:
382
386
  llm_log_lists = []
383
387
  mqtt_log_lists = []
384
388
  simulator_log_lists = []
389
+ agent_time_log_lists = []
385
390
  for step in config["workflow"]:
386
391
  logger.info(
387
392
  f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}"
@@ -389,17 +394,19 @@ class AgentSimulation:
389
394
  if step["type"] not in ["run", "step", "interview", "survey", "intervene", "pause", "resume", "function"]:
390
395
  raise ValueError(f"Invalid step type: {step['type']}")
391
396
  if step["type"] == "run":
392
- llm_log_list, mqtt_log_list, simulator_log_list = await simulation.run(step.get("days", 1))
397
+ llm_log_list, mqtt_log_list, simulator_log_list, agent_time_log_list = await simulation.run(step.get("days", 1))
393
398
  llm_log_lists.extend(llm_log_list)
394
399
  mqtt_log_lists.extend(mqtt_log_list)
395
400
  simulator_log_lists.extend(simulator_log_list)
401
+ agent_time_log_lists.extend(agent_time_log_list)
396
402
  elif step["type"] == "step":
397
403
  times = step.get("times", 1)
398
404
  for _ in range(times):
399
- llm_log_list, mqtt_log_list, simulator_log_list = await simulation.step()
405
+ llm_log_list, mqtt_log_list, simulator_log_list, agent_time_log_list = await simulation.step()
400
406
  llm_log_lists.extend(llm_log_list)
401
407
  mqtt_log_lists.extend(mqtt_log_list)
402
408
  simulator_log_lists.extend(simulator_log_list)
409
+ agent_time_log_lists.extend(agent_time_log_list)
403
410
  elif step["type"] == "pause":
404
411
  await simulation.pause_simulator()
405
412
  elif step["type"] == "resume":
@@ -407,7 +414,7 @@ class AgentSimulation:
407
414
  else:
408
415
  await step["func"](simulation)
409
416
  logger.info("Simulation finished")
410
- return llm_log_lists, mqtt_log_lists, simulator_log_lists
417
+ return llm_log_lists, mqtt_log_lists, simulator_log_lists, agent_time_log_lists
411
418
 
412
419
  @property
413
420
  def enable_avro(
@@ -742,6 +749,7 @@ class AgentSimulation:
742
749
  number_of_agents,
743
750
  memory_config_function_group,
744
751
  self.config,
752
+ self._map_ref,
745
753
  self.exp_name,
746
754
  self.exp_id,
747
755
  self.enable_avro,
@@ -753,8 +761,8 @@ class AgentSimulation:
753
761
  embedding_model,
754
762
  self.logging_level,
755
763
  config_file,
756
- environment,
757
764
  llm_semaphore,
765
+ environment,
758
766
  )
759
767
  creation_tasks.append((group_name, group))
760
768
 
@@ -1018,7 +1026,7 @@ class AgentSimulation:
1018
1026
 
1019
1027
  # step
1020
1028
  simulator_day = await self._simulator.get_simulator_day()
1021
- simulator_time = int(await self._simulator.get_time())
1029
+ simulator_time = int(await self._simulator.get_simulator_second_from_start_of_day())
1022
1030
  logger.info(
1023
1031
  f"Start simulation day {simulator_day} at {simulator_time}, step {self._total_steps}"
1024
1032
  )
@@ -1029,14 +1037,15 @@ class AgentSimulation:
1029
1037
  llm_log_list = []
1030
1038
  mqtt_log_list = []
1031
1039
  simulator_log_list = []
1040
+ agent_time_log_list = []
1032
1041
  for log_messages_group in log_messages_groups:
1033
1042
  llm_log_list.extend(log_messages_group['llm_log'])
1034
1043
  mqtt_log_list.extend(log_messages_group['mqtt_log'])
1035
1044
  simulator_log_list.extend(log_messages_group['simulator_log'])
1036
-
1045
+ agent_time_log_list.extend(log_messages_group['agent_time_log'])
1037
1046
  # save
1038
1047
  simulator_day = await self._simulator.get_simulator_day()
1039
- simulator_time = int(await self._simulator.get_time())
1048
+ simulator_time = int(await self._simulator.get_simulator_second_from_start_of_day())
1040
1049
  save_tasks = []
1041
1050
  for group in self._groups.values():
1042
1051
  save_tasks.append(group.save.remote(simulator_day, simulator_time))
@@ -1050,7 +1059,7 @@ class AgentSimulation:
1050
1059
  ]
1051
1060
  await self.extract_metric(to_excute_metric)
1052
1061
 
1053
- return llm_log_list, mqtt_log_list, simulator_log_list
1062
+ return llm_log_list, mqtt_log_list, simulator_log_list, agent_time_log_list
1054
1063
  except Exception as e:
1055
1064
  import traceback
1056
1065
 
@@ -1081,6 +1090,7 @@ class AgentSimulation:
1081
1090
  llm_log_lists = []
1082
1091
  mqtt_log_lists = []
1083
1092
  simulator_log_lists = []
1093
+ agent_time_log_lists = []
1084
1094
  try:
1085
1095
  self._exp_info["num_day"] += day
1086
1096
  await self._update_exp_status(1) # 更新状态为运行中
@@ -1098,10 +1108,11 @@ class AgentSimulation:
1098
1108
  current_time = await self._simulator.get_time()
1099
1109
  if current_time >= end_time: # type:ignore
1100
1110
  break
1101
- llm_log_list, mqtt_log_list, simulator_log_list = await self.step()
1111
+ llm_log_list, mqtt_log_list, simulator_log_list, agent_time_log_list = await self.step()
1102
1112
  llm_log_lists.extend(llm_log_list)
1103
1113
  mqtt_log_lists.extend(mqtt_log_list)
1104
1114
  simulator_log_lists.extend(simulator_log_list)
1115
+ agent_time_log_lists.extend(agent_time_log_list)
1105
1116
  finally:
1106
1117
  # 设置停止事件
1107
1118
  stop_event.set()
@@ -1110,7 +1121,7 @@ class AgentSimulation:
1110
1121
 
1111
1122
  # 运行成功后更新状态
1112
1123
  await self._update_exp_status(2)
1113
- return llm_log_lists, mqtt_log_lists, simulator_log_lists
1124
+ return llm_log_lists, mqtt_log_lists, simulator_log_lists, agent_time_log_lists
1114
1125
  except Exception as e:
1115
1126
  error_msg = f"模拟器运行错误: {str(e)}"
1116
1127
  logger.error(error_msg)
pycityagent/tools/tool.py CHANGED
@@ -5,6 +5,7 @@ from collections.abc import Callable, Sequence
5
5
  from typing import Any, Optional, Union
6
6
 
7
7
  from mlflow.entities import Metric
8
+ import ray
8
9
 
9
10
  from ..agent import Agent
10
11
  from ..environment import AoiService, PersonService