pycityagent 2.0.0a19__py3-none-any.whl → 2.0.0a21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pycityagent/agent.py CHANGED
@@ -1,17 +1,19 @@
1
1
  """智能体模板类及其定义"""
2
2
 
3
3
  import asyncio
4
+ import json
4
5
  import logging
5
6
  import random
6
7
  import uuid
7
8
  from abc import ABC, abstractmethod
8
9
  from copy import deepcopy
9
- from datetime import datetime
10
+ from datetime import datetime, timezone
10
11
  from enum import Enum
11
- from typing import Any, Dict, List, Optional
12
+ from typing import Any, Optional
12
13
  from uuid import UUID
13
14
 
14
15
  import fastavro
16
+ import ray
15
17
  from mosstool.util.format_converter import dict2pb
16
18
  from pycityproto.city.person.v2 import person_pb2 as person_pb2
17
19
 
@@ -55,7 +57,8 @@ class Agent(ABC):
55
57
  simulator: Optional[Simulator] = None,
56
58
  mlflow_client: Optional[MlflowClient] = None,
57
59
  memory: Optional[Memory] = None,
58
- avro_file: Optional[Dict[str, str]] = None,
60
+ avro_file: Optional[dict[str, str]] = None,
61
+ copy_writer: Optional[ray.ObjectRef] = None,
59
62
  ) -> None:
60
63
  """
61
64
  Initialize the Agent.
@@ -69,7 +72,8 @@ class Agent(ABC):
69
72
  simulator (Simulator, optional): The simulator object. Defaults to None.
70
73
  mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
71
74
  memory (Memory, optional): The memory of the agent. Defaults to None.
72
- avro_file (Dict[str, str], optional): The avro file of the agent. Defaults to None.
75
+ avro_file (dict[str, str], optional): The avro file of the agent. Defaults to None.
76
+ copy_writer (ray.ObjectRef): The copy_writer of the agent. Defaults to None.
73
77
  """
74
78
  self._name = name
75
79
  self._type = type
@@ -85,9 +89,11 @@ class Agent(ABC):
85
89
  self._has_bound_to_simulator = False
86
90
  self._has_bound_to_economy = False
87
91
  self._blocked = False
88
- self._interview_history: List[Dict] = [] # 存储采访历史
92
+ self._interview_history: list[dict] = [] # 存储采访历史
89
93
  self._person_template = PersonService.default_dict_person()
90
94
  self._avro_file = avro_file
95
+ self._pgsql_writer = copy_writer
96
+ self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
91
97
 
92
98
  def __getstate__(self):
93
99
  state = self.__dict__.copy()
@@ -137,12 +143,18 @@ class Agent(ABC):
137
143
  """
138
144
  self._exp_id = exp_id
139
145
 
140
- def set_avro_file(self, avro_file: Dict[str, str]):
146
+ def set_avro_file(self, avro_file: dict[str, str]):
141
147
  """
142
148
  Set the avro file of the agent.
143
149
  """
144
150
  self._avro_file = avro_file
145
151
 
152
+ def set_pgsql_writer(self, pgsql_writer: ray.ObjectRef):
153
+ """
154
+ Set the PostgreSQL copy writer of the agent.
155
+ """
156
+ self._pgsql_writer = pgsql_writer
157
+
146
158
  @property
147
159
  def uuid(self):
148
160
  """The Agent's UUID"""
@@ -198,6 +210,15 @@ class Agent(ABC):
198
210
  )
199
211
  return self._simulator
200
212
 
213
+ @property
214
+ def copy_writer(self):
215
+ """Pg Copy Writer"""
216
+ if self._pgsql_writer is None:
217
+ raise RuntimeError(
218
+ f"Copy Writer access before assignment, please `set_pgsql_writer` first!"
219
+ )
220
+ return self._pgsql_writer
221
+
201
222
  async def generate_user_survey_response(self, survey: dict) -> str:
202
223
  """生成回答 —— 可重写
203
224
  基于智能体的记忆和当前状态,生成对问卷调查的回答。
@@ -237,8 +258,8 @@ class Agent(ABC):
237
258
 
238
259
  async def _process_survey(self, survey: dict):
239
260
  survey_response = await self.generate_user_survey_response(survey)
240
- if self._avro_file is None:
241
- return
261
+ _date_time = datetime.now(timezone.utc)
262
+ # Avro
242
263
  response_to_avro = [
243
264
  {
244
265
  "id": self._uuid,
@@ -246,11 +267,41 @@ class Agent(ABC):
246
267
  "t": await self.simulator.get_simulator_second_from_start_of_day(),
247
268
  "survey_id": survey["id"],
248
269
  "result": survey_response,
249
- "created_at": int(datetime.now().timestamp() * 1000),
270
+ "created_at": int(_date_time.timestamp() * 1000),
250
271
  }
251
272
  ]
252
- with open(self._avro_file["survey"], "a+b") as f:
253
- fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
273
+ if self._avro_file is not None:
274
+ with open(self._avro_file["survey"], "a+b") as f:
275
+ fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
276
+ # Pg
277
+ if self._pgsql_writer is not None:
278
+ if self._last_asyncio_pg_task is not None:
279
+ await self._last_asyncio_pg_task
280
+ _keys = [
281
+ "id",
282
+ "day",
283
+ "t",
284
+ "survey_id",
285
+ "result",
286
+ ]
287
+ _data_tuples: list[tuple] = []
288
+ # str to json
289
+ for _dict in response_to_avro:
290
+ res = _dict["result"]
291
+ _dict["result"] = json.dumps(
292
+ {
293
+ "result": res,
294
+ }
295
+ )
296
+ _data_list = [_dict[k] for k in _keys]
297
+ # created_at
298
+ _data_list.append(_date_time)
299
+ _data_tuples.append(tuple(_data_list))
300
+ self._last_asyncio_pg_task = (
301
+ self._pgsql_writer.async_write_survey.remote( # type:ignore
302
+ _data_tuples
303
+ )
304
+ )
254
305
 
255
306
  async def generate_user_chat_response(self, question: str) -> str:
256
307
  """生成回答 —— 可重写
@@ -290,34 +341,52 @@ class Agent(ABC):
290
341
  return response # type:ignore
291
342
 
292
343
  async def _process_interview(self, payload: dict):
293
- auros = [
294
- {
295
- "id": self._uuid,
296
- "day": await self.simulator.get_simulator_day(),
297
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
298
- "type": 2,
299
- "speaker": "user",
300
- "content": payload["content"],
301
- "created_at": int(datetime.now().timestamp() * 1000),
302
- }
303
- ]
344
+ pg_list: list[tuple[dict, datetime]] = []
345
+ auros: list[dict] = []
346
+ _date_time = datetime.now(timezone.utc)
347
+ _interview_dict = {
348
+ "id": self._uuid,
349
+ "day": await self.simulator.get_simulator_day(),
350
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
351
+ "type": 2,
352
+ "speaker": "user",
353
+ "content": payload["content"],
354
+ "created_at": int(_date_time.timestamp() * 1000),
355
+ }
356
+ auros.append(_interview_dict)
357
+ pg_list.append((_interview_dict, _date_time))
304
358
  question = payload["content"]
305
359
  response = await self.generate_user_chat_response(question)
306
- auros.append(
307
- {
308
- "id": self._uuid,
309
- "day": await self.simulator.get_simulator_day(),
310
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
311
- "type": 2,
312
- "speaker": "",
313
- "content": response,
314
- "created_at": int(datetime.now().timestamp() * 1000),
315
- }
316
- )
317
- if self._avro_file is None:
318
- return
319
- with open(self._avro_file["dialog"], "a+b") as f:
320
- fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
360
+ _date_time = datetime.now(timezone.utc)
361
+ _interview_dict = {
362
+ "id": self._uuid,
363
+ "day": await self.simulator.get_simulator_day(),
364
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
365
+ "type": 2,
366
+ "speaker": "",
367
+ "content": response,
368
+ "created_at": int(_date_time.timestamp() * 1000),
369
+ }
370
+ auros.append(_interview_dict)
371
+ pg_list.append((_interview_dict, _date_time))
372
+ # Avro
373
+ if self._avro_file is not None:
374
+ with open(self._avro_file["dialog"], "a+b") as f:
375
+ fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
376
+ # Pg
377
+ if self._pgsql_writer is not None:
378
+ if self._last_asyncio_pg_task is not None:
379
+ await self._last_asyncio_pg_task
380
+ _keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
381
+ _data = [
382
+ tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
383
+ for _dict, _date_time in pg_list
384
+ ]
385
+ self._last_asyncio_pg_task = (
386
+ self._pgsql_writer.async_write_dialog.remote( # type:ignore
387
+ _data
388
+ )
389
+ )
321
390
 
322
391
  async def process_agent_chat_response(self, payload: dict) -> str:
323
392
  resp = f"Agent {self._uuid} received agent chat response: {payload}"
@@ -325,22 +394,39 @@ class Agent(ABC):
325
394
  return resp
326
395
 
327
396
  async def _process_agent_chat(self, payload: dict):
328
- auros = [
329
- {
330
- "id": self._uuid,
331
- "day": payload["day"],
332
- "t": payload["t"],
333
- "type": 1,
334
- "speaker": payload["from"],
335
- "content": payload["content"],
336
- "created_at": int(datetime.now().timestamp() * 1000),
337
- }
338
- ]
397
+ pg_list: list[tuple[dict, datetime]] = []
398
+ auros: list[dict] = []
399
+ _date_time = datetime.now(timezone.utc)
400
+ _chat_dict = {
401
+ "id": self._uuid,
402
+ "day": payload["day"],
403
+ "t": payload["t"],
404
+ "type": 1,
405
+ "speaker": payload["from"],
406
+ "content": payload["content"],
407
+ "created_at": int(_date_time.timestamp() * 1000),
408
+ }
409
+ auros.append(_chat_dict)
410
+ pg_list.append((_chat_dict, _date_time))
339
411
  asyncio.create_task(self.process_agent_chat_response(payload))
340
- if self._avro_file is None:
341
- return
342
- with open(self._avro_file["dialog"], "a+b") as f:
343
- fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
412
+ # Avro
413
+ if self._avro_file is not None:
414
+ with open(self._avro_file["dialog"], "a+b") as f:
415
+ fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
416
+ # Pg
417
+ if self._pgsql_writer is not None:
418
+ if self._last_asyncio_pg_task is not None:
419
+ await self._last_asyncio_pg_task
420
+ _keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
421
+ _data = [
422
+ tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
423
+ for _dict, _date_time in pg_list
424
+ ]
425
+ self._last_asyncio_pg_task = (
426
+ self._pgsql_writer.async_write_dialog.remote( # type:ignore
427
+ _data
428
+ )
429
+ )
344
430
 
345
431
  # Callback functions for MQTT message
346
432
  async def handle_agent_chat_message(self, payload: dict):
@@ -384,21 +470,38 @@ class Agent(ABC):
384
470
  "t": await self.simulator.get_simulator_second_from_start_of_day(),
385
471
  }
386
472
  await self._send_message(to_agent_uuid, payload, "agent-chat")
387
- auros = [
388
- {
389
- "id": self._uuid,
390
- "day": await self.simulator.get_simulator_day(),
391
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
392
- "type": 1,
393
- "speaker": self._uuid,
394
- "content": content,
395
- "created_at": int(datetime.now().timestamp() * 1000),
396
- }
397
- ]
398
- if self._avro_file is None:
399
- return
400
- with open(self._avro_file["dialog"], "a+b") as f:
401
- fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
473
+ pg_list: list[tuple[dict, datetime]] = []
474
+ auros: list[dict] = []
475
+ _date_time = datetime.now(timezone.utc)
476
+ _message_dict = {
477
+ "id": self._uuid,
478
+ "day": await self.simulator.get_simulator_day(),
479
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
480
+ "type": 1,
481
+ "speaker": self._uuid,
482
+ "content": content,
483
+ "created_at": int(datetime.now().timestamp() * 1000),
484
+ }
485
+ auros.append(_message_dict)
486
+ pg_list.append((_message_dict, _date_time))
487
+ # Avro
488
+ if self._avro_file is not None:
489
+ with open(self._avro_file["dialog"], "a+b") as f:
490
+ fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
491
+ # Pg
492
+ if self._pgsql_writer is not None:
493
+ if self._last_asyncio_pg_task is not None:
494
+ await self._last_asyncio_pg_task
495
+ _keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
496
+ _data = [
497
+ tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
498
+ for _dict, _date_time in pg_list
499
+ ]
500
+ self._last_asyncio_pg_task = (
501
+ self._pgsql_writer.async_write_dialog.remote( # type:ignore
502
+ _data
503
+ )
504
+ )
402
505
 
403
506
  # Agent logic
404
507
  @abstractmethod
@@ -557,7 +660,7 @@ class InstitutionAgent(Agent):
557
660
  avro_file=avro_file,
558
661
  )
559
662
  # 添加响应收集器
560
- self._gather_responses: Dict[str, asyncio.Future] = {}
663
+ self._gather_responses: dict[str, asyncio.Future] = {}
561
664
 
562
665
  async def bind_to_simulator(self):
563
666
  await self._bind_to_economy()
@@ -659,7 +762,7 @@ class InstitutionAgent(Agent):
659
762
  }
660
763
  )
661
764
 
662
- async def gather_messages(self, agent_uuids: list[str], target: str) -> List[dict]:
765
+ async def gather_messages(self, agent_uuids: list[str], target: str) -> list[dict]:
663
766
  """从多个智能体收集消息
664
767
 
665
768
  Args:
@@ -667,7 +770,7 @@ class InstitutionAgent(Agent):
667
770
  target: 要收集的信息类型
668
771
 
669
772
  Returns:
670
- List[dict]: 收集到的所有响应
773
+ list[dict]: 收集到的所有响应
671
774
  """
672
775
  # 为每个agent创建Future
673
776
  futures = {}
@@ -316,3 +316,40 @@ class EconomyClient:
316
316
  await self._aio_stub.GetOrgEntityIds(request)
317
317
  )
318
318
  return list(response.org_ids)
319
+
320
+ async def add_delta_value(
321
+ self,
322
+ id: int,
323
+ key: str,
324
+ value: Any,
325
+ ) -> Any:
326
+ """
327
+ Add key-value pair
328
+
329
+ Args:
330
+ - id (int): the id of `Org` or `Agent`.
331
+ - key (str): the attribute to update. Can only be `inventory`, `price`, `interest_rate` and `currency`
332
+
333
+
334
+ Returns:
335
+ - Any
336
+ """
337
+ pascal_key = _snake_to_pascal(key)
338
+ _request_type = getattr(org_service, f"Add{pascal_key}Request")
339
+ _request_func = getattr(self._aio_stub, f"Add{pascal_key}")
340
+ _available_keys = {
341
+ "inventory",
342
+ "price",
343
+ "interest_rate",
344
+ "currency",
345
+ }
346
+ if key not in _available_keys:
347
+ raise ValueError(f"Invalid key `{key}`, can only be {_available_keys}!")
348
+ return await _request_func(
349
+ _request_type(
350
+ **{
351
+ "org_id": id,
352
+ f"delta_{key}": value,
353
+ }
354
+ )
355
+ )
@@ -1,9 +1,7 @@
1
- from typing import List
2
-
3
1
  __all__ = ["wrap_feature_collection"]
4
2
 
5
3
 
6
- def wrap_feature_collection(features: List[dict], name: str):
4
+ def wrap_feature_collection(features: list[dict], name: str):
7
5
  """
8
6
  将 GeoJSON Feature 集合包装为 FeatureCollection
9
7
  Wrap GeoJSON Feature collection as FeatureCollection
@@ -1,5 +1,5 @@
1
1
  import math
2
- from typing import Dict, List, Literal, Optional, Tuple, Union
2
+ from typing import Literal, Union
3
3
 
4
4
  import numpy as np
5
5
 
@@ -14,12 +14,12 @@ def point_on_line_given_distance(start_node, end_node, distance):
14
14
  return the coordinates of the point reached after traveling s units along the line, starting from start_point.
15
15
 
16
16
  Args:
17
- start_point (tuple): Tuple of (x, y) representing the starting point on the line.
18
- end_point (tuple): Tuple of (x, y) representing the ending point on the line.
17
+ start_point (tuple): tuple of (x, y) representing the starting point on the line.
18
+ end_point (tuple): tuple of (x, y) representing the ending point on the line.
19
19
  distance (float): Distance to travel along the line, starting from start_point.
20
20
 
21
21
  Returns:
22
- tuple: Tuple of (x, y) representing the new point reached after traveling s units along the line.
22
+ tuple: tuple of (x, y) representing the new point reached after traveling s units along the line.
23
23
  """
24
24
 
25
25
  x1, y1 = start_node["x"], start_node["y"]
@@ -49,7 +49,7 @@ def point_on_line_given_distance(start_node, end_node, distance):
49
49
 
50
50
 
51
51
  def get_key_index_in_lane(
52
- nodes: List[Dict[str, float]],
52
+ nodes: list[dict[str, float]],
53
53
  distance: float,
54
54
  direction: Union[Literal["front"], Literal["back"]],
55
55
  ) -> int:
@@ -61,10 +61,10 @@ def get_key_index_in_lane(
61
61
  _index_offset, _index_factor = len(_nodes) - 1, -1
62
62
  else:
63
63
  raise ValueError(f"Invalid direction type {direction}!")
64
- _lane_points: List[Tuple[float, float, float]] = [
64
+ _lane_points: list[tuple[float, float, float]] = [
65
65
  (n["x"], n["y"], n.get("z", 0)) for n in _nodes
66
66
  ]
67
- _line_lengths: List[float] = [0.0 for _ in range(len(_nodes))]
67
+ _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
68
68
  _s = 0.0
69
69
  for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
70
70
  _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
@@ -82,20 +82,20 @@ def get_key_index_in_lane(
82
82
 
83
83
 
84
84
  def get_xy_in_lane(
85
- nodes: List[Dict[str, float]],
85
+ nodes: list[dict[str, float]],
86
86
  distance: float,
87
87
  direction: Union[Literal["front"], Literal["back"]],
88
- ) -> Tuple[float, float]:
88
+ ) -> tuple[float, float]:
89
89
  if direction == "front":
90
90
  _nodes = [n for n in nodes]
91
91
  elif direction == "back":
92
92
  _nodes = [n for n in nodes[::-1]]
93
93
  else:
94
94
  raise ValueError(f"Invalid direction type {direction}!")
95
- _lane_points: List[Tuple[float, float, float]] = [
95
+ _lane_points: list[tuple[float, float, float]] = [
96
96
  (n["x"], n["y"], n.get("z", 0)) for n in _nodes
97
97
  ]
98
- _line_lengths: List[float] = [0.0 for _ in range(len(_nodes))]
98
+ _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
99
99
  _s = 0.0
100
100
  for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
101
101
  _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
@@ -122,7 +122,7 @@ def get_xy_in_lane(
122
122
 
123
123
 
124
124
  def get_direction_by_s(
125
- nodes: List[Dict[str, float]],
125
+ nodes: list[dict[str, float]],
126
126
  distance: float,
127
127
  direction: Union[Literal["front"], Literal["back"]],
128
128
  ) -> float:
@@ -132,11 +132,11 @@ def get_direction_by_s(
132
132
  _nodes = [n for n in nodes[::-1]]
133
133
  else:
134
134
  raise ValueError(f"Invalid direction type {direction}!")
135
- _lane_points: List[Tuple[float, float, float]] = [
135
+ _lane_points: list[tuple[float, float, float]] = [
136
136
  (n["x"], n["y"], n.get("z", 0)) for n in _nodes
137
137
  ]
138
- _line_lengths: List[float] = [0.0 for _ in range(len(_nodes))]
139
- _line_directions: List[Tuple[float, float]] = []
138
+ _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
139
+ _line_directions: list[tuple[float, float]] = []
140
140
  _s = 0.0
141
141
  for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
142
142
  _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
@@ -1,7 +1,6 @@
1
1
  """简单的基于内存的embedding实现"""
2
2
 
3
3
  import numpy as np
4
- from typing import List, Dict, Optional
5
4
  import hashlib
6
5
  import json
7
6
 
@@ -22,34 +21,34 @@ class SimpleEmbedding:
22
21
  """
23
22
  self.vector_dim = vector_dim
24
23
  self.cache_size = cache_size
25
- self._cache: Dict[str, np.ndarray] = {}
26
- self._vocab: Dict[str, int] = {} # 词汇表
27
- self._idf: Dict[str, float] = {} # 逆文档频率
24
+ self._cache: dict[str, np.ndarray] = {}
25
+ self._vocab: dict[str, int] = {} # 词汇表
26
+ self._idf: dict[str, float] = {} # 逆文档频率
28
27
  self._doc_count = 0 # 文档总数
29
28
 
30
29
  def _text_to_hash(self, text: str) -> str:
31
30
  """将文本转换为hash值"""
32
31
  return hashlib.md5(text.encode()).hexdigest()
33
32
 
34
- def _tokenize(self, text: str) -> List[str]:
33
+ def _tokenize(self, text: str) -> list[str]:
35
34
  """简单的分词"""
36
35
  # 这里使用简单的空格分词,实际应用中可以使用更复杂的分词方法
37
36
  return text.lower().split()
38
37
 
39
- def _update_vocab(self, tokens: List[str]):
38
+ def _update_vocab(self, tokens: list[str]):
40
39
  """更新词汇表"""
41
40
  for token in set(tokens): # 使用set去重
42
41
  if token not in self._vocab:
43
42
  self._vocab[token] = len(self._vocab)
44
43
 
45
- def _update_idf(self, tokens: List[str]):
44
+ def _update_idf(self, tokens: list[str]):
46
45
  """更新IDF值"""
47
46
  self._doc_count += 1
48
47
  unique_tokens = set(tokens)
49
48
  for token in unique_tokens:
50
49
  self._idf[token] = self._idf.get(token, 0) + 1
51
50
 
52
- def _calculate_tf(self, tokens: List[str]) -> Dict[str, float]:
51
+ def _calculate_tf(self, tokens: list[str]) -> dict[str, float]:
53
52
  """计算词频(TF)"""
54
53
  tf = {}
55
54
  total_tokens = len(tokens)
@@ -60,7 +59,7 @@ class SimpleEmbedding:
60
59
  tf[token] /= total_tokens
61
60
  return tf
62
61
 
63
- def _calculate_tfidf(self, tokens: List[str]) -> np.ndarray:
62
+ def _calculate_tfidf(self, tokens: list[str]) -> np.ndarray:
64
63
  """计算TF-IDF向量"""
65
64
  vector = np.zeros(self.vector_dim)
66
65
  tf = self._calculate_tf(tokens)
pycityagent/llm/llm.py CHANGED
@@ -14,7 +14,7 @@ import requests
14
14
  from dashscope import ImageSynthesis
15
15
  from PIL import Image
16
16
  from io import BytesIO
17
- from typing import Any, Optional, Union, List, Dict
17
+ from typing import Any, Optional, Union
18
18
  from .llmconfig import *
19
19
  from .utils import *
20
20
 
@@ -117,8 +117,8 @@ Token Usage:
117
117
  presence_penalty: Optional[float] = None,
118
118
  timeout: int = 300,
119
119
  retries=3,
120
- tools: Optional[List[Dict[str, Any]]] = None,
121
- tool_choice: Optional[Dict[str, Any]] = None,
120
+ tools: Optional[list[dict[str, Any]]] = None,
121
+ tool_choice: Optional[dict[str, Any]] = None,
122
122
  ):
123
123
  """
124
124
  异步版文本请求
@@ -227,9 +227,9 @@ Token Usage:
227
227
  self.prompt_tokens_used += result_response.usage.prompt_tokens # type: ignore
228
228
  self.completion_tokens_used += result_response.usage.completion_tokens # type: ignore
229
229
  self.request_number += 1
230
- if tools and result_response.choices[0].message.tool_calls:
230
+ if tools and result_response.choices[0].message.tool_calls: # type: ignore
231
231
  return json.loads(
232
- result_response.choices[0]
232
+ result_response.choices[0] # type: ignore
233
233
  .message.tool_calls[0]
234
234
  .function.arguments
235
235
  )