pycityagent 2.0.0a18__py3-none-any.whl → 2.0.0a20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pycityagent/agent.py CHANGED
@@ -1,30 +1,28 @@
1
1
  """智能体模板类及其定义"""
2
2
 
3
- from abc import ABC, abstractmethod
4
3
  import asyncio
5
- from uuid import UUID
6
- from copy import deepcopy
7
- from datetime import datetime
8
- from enum import Enum
9
4
  import logging
10
5
  import random
11
6
  import uuid
12
- from typing import Dict, List, Optional,Any
7
+ from abc import ABC, abstractmethod
8
+ from copy import deepcopy
9
+ from datetime import datetime
10
+ from enum import Enum
11
+ from typing import Any, Optional
12
+ from uuid import UUID
13
13
 
14
14
  import fastavro
15
-
16
- from pycityagent.environment.sim.person_service import PersonService
17
15
  from mosstool.util.format_converter import dict2pb
18
16
  from pycityproto.city.person.v2 import person_pb2 as person_pb2
19
- from pycityagent.utils import process_survey_for_llm
20
-
21
- from pycityagent.message.messager import Messager
22
- from pycityagent.utils import SURVEY_SCHEMA, DIALOG_SCHEMA
23
17
 
24
18
  from .economy import EconomyClient
25
19
  from .environment import Simulator
20
+ from .environment.sim.person_service import PersonService
26
21
  from .llm import LLM
27
22
  from .memory import Memory
23
+ from .message.messager import Messager
24
+ from .metrics import MlflowClient
25
+ from .utils import DIALOG_SCHEMA, SURVEY_SCHEMA, process_survey_for_llm
28
26
 
29
27
  logger = logging.getLogger("pycityagent")
30
28
 
@@ -55,8 +53,9 @@ class Agent(ABC):
55
53
  economy_client: Optional[EconomyClient] = None,
56
54
  messager: Optional[Messager] = None,
57
55
  simulator: Optional[Simulator] = None,
56
+ mlflow_client: Optional[MlflowClient] = None,
58
57
  memory: Optional[Memory] = None,
59
- avro_file: Optional[Dict[str, str]] = None,
58
+ avro_file: Optional[dict[str, str]] = None,
60
59
  ) -> None:
61
60
  """
62
61
  Initialize the Agent.
@@ -68,8 +67,9 @@ class Agent(ABC):
68
67
  economy_client (EconomyClient): The `EconomySim` client. Defaults to None.
69
68
  messager (Messager, optional): The messager object. Defaults to None.
70
69
  simulator (Simulator, optional): The simulator object. Defaults to None.
70
+ mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
71
71
  memory (Memory, optional): The memory of the agent. Defaults to None.
72
- avro_file (Dict[str, str], optional): The avro file of the agent. Defaults to None.
72
+ avro_file (dict[str, str], optional): The avro file of the agent. Defaults to None.
73
73
  """
74
74
  self._name = name
75
75
  self._type = type
@@ -78,13 +78,14 @@ class Agent(ABC):
78
78
  self._economy_client = economy_client
79
79
  self._messager = messager
80
80
  self._simulator = simulator
81
+ self._mlflow_client = mlflow_client
81
82
  self._memory = memory
82
83
  self._exp_id = -1
83
84
  self._agent_id = -1
84
85
  self._has_bound_to_simulator = False
85
86
  self._has_bound_to_economy = False
86
87
  self._blocked = False
87
- self._interview_history: List[Dict] = [] # 存储采访历史
88
+ self._interview_history: list[dict] = [] # 存储采访历史
88
89
  self._person_template = PersonService.default_dict_person()
89
90
  self._avro_file = avro_file
90
91
 
@@ -112,6 +113,12 @@ class Agent(ABC):
112
113
  """
113
114
  self._simulator = simulator
114
115
 
116
+ def set_mlflow_client(self, mlflow_client: MlflowClient):
117
+ """
118
+ Set the mlflow_client of the agent.
119
+ """
120
+ self._mlflow_client = mlflow_client
121
+
115
122
  def set_economy_client(self, economy_client: EconomyClient):
116
123
  """
117
124
  Set the economy_client of the agent.
@@ -130,7 +137,7 @@ class Agent(ABC):
130
137
  """
131
138
  self._exp_id = exp_id
132
139
 
133
- def set_avro_file(self, avro_file: Dict[str, str]):
140
+ def set_avro_file(self, avro_file: dict[str, str]):
134
141
  """
135
142
  Set the avro file of the agent.
136
143
  """
@@ -164,6 +171,15 @@ class Agent(ABC):
164
171
  )
165
172
  return self._economy_client
166
173
 
174
+ @property
175
+ def mlflow_client(self):
176
+ """The Agent's MlflowClient"""
177
+ if self._mlflow_client is None:
178
+ raise RuntimeError(
179
+ f"MlflowClient access before assignment, please `set_mlflow_client` first!"
180
+ )
181
+ return self._mlflow_client
182
+
167
183
  @property
168
184
  def memory(self):
169
185
  """The Agent's Memory"""
@@ -218,19 +234,21 @@ class Agent(ABC):
218
234
  response = await self._llm_client.atext_request(dialog) # type:ignore
219
235
 
220
236
  return response # type:ignore
221
-
237
+
222
238
  async def _process_survey(self, survey: dict):
223
239
  survey_response = await self.generate_user_survey_response(survey)
224
240
  if self._avro_file is None:
225
241
  return
226
- response_to_avro = [{
227
- "id": self._uuid,
228
- "day": await self.simulator.get_simulator_day(),
229
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
230
- "survey_id": survey["id"],
231
- "result": survey_response,
232
- "created_at": int(datetime.now().timestamp() * 1000),
233
- }]
242
+ response_to_avro = [
243
+ {
244
+ "id": self._uuid,
245
+ "day": await self.simulator.get_simulator_day(),
246
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
247
+ "survey_id": survey["id"],
248
+ "result": survey_response,
249
+ "created_at": int(datetime.now().timestamp() * 1000),
250
+ }
251
+ ]
234
252
  with open(self._avro_file["survey"], "a+b") as f:
235
253
  fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
236
254
 
@@ -270,28 +288,32 @@ class Agent(ABC):
270
288
  response = await self._llm_client.atext_request(dialog) # type:ignore
271
289
 
272
290
  return response # type:ignore
273
-
291
+
274
292
  async def _process_interview(self, payload: dict):
275
- auros = [{
276
- "id": self._uuid,
277
- "day": await self.simulator.get_simulator_day(),
278
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
279
- "type": 2,
280
- "speaker": "user",
281
- "content": payload["content"],
282
- "created_at": int(datetime.now().timestamp() * 1000),
283
- }]
293
+ auros = [
294
+ {
295
+ "id": self._uuid,
296
+ "day": await self.simulator.get_simulator_day(),
297
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
298
+ "type": 2,
299
+ "speaker": "user",
300
+ "content": payload["content"],
301
+ "created_at": int(datetime.now().timestamp() * 1000),
302
+ }
303
+ ]
284
304
  question = payload["content"]
285
305
  response = await self.generate_user_chat_response(question)
286
- auros.append({
287
- "id": self._uuid,
288
- "day": await self.simulator.get_simulator_day(),
289
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
290
- "type": 2,
291
- "speaker": "",
292
- "content": response,
293
- "created_at": int(datetime.now().timestamp() * 1000),
294
- })
306
+ auros.append(
307
+ {
308
+ "id": self._uuid,
309
+ "day": await self.simulator.get_simulator_day(),
310
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
311
+ "type": 2,
312
+ "speaker": "",
313
+ "content": response,
314
+ "created_at": int(datetime.now().timestamp() * 1000),
315
+ }
316
+ )
295
317
  if self._avro_file is None:
296
318
  return
297
319
  with open(self._avro_file["dialog"], "a+b") as f:
@@ -303,15 +325,17 @@ class Agent(ABC):
303
325
  return resp
304
326
 
305
327
  async def _process_agent_chat(self, payload: dict):
306
- auros = [{
307
- "id": self._uuid,
308
- "day": payload["day"],
309
- "t": payload["t"],
310
- "type": 1,
311
- "speaker": payload["from"],
312
- "content": payload["content"],
313
- "created_at": int(datetime.now().timestamp() * 1000),
314
- }]
328
+ auros = [
329
+ {
330
+ "id": self._uuid,
331
+ "day": payload["day"],
332
+ "t": payload["t"],
333
+ "type": 1,
334
+ "speaker": payload["from"],
335
+ "content": payload["content"],
336
+ "created_at": int(datetime.now().timestamp() * 1000),
337
+ }
338
+ ]
315
339
  asyncio.create_task(self.process_agent_chat_response(payload))
316
340
  if self._avro_file is None:
317
341
  return
@@ -341,18 +365,14 @@ class Agent(ABC):
341
365
  raise NotImplementedError
342
366
 
343
367
  # MQTT send message
344
- async def _send_message(
345
- self, to_agent_uuid: str, payload: dict, sub_topic: str
346
- ):
368
+ async def _send_message(self, to_agent_uuid: str, payload: dict, sub_topic: str):
347
369
  """通过 Messager 发送消息"""
348
370
  if self._messager is None:
349
371
  raise RuntimeError("Messager is not set")
350
372
  topic = f"exps/{self._exp_id}/agents/{to_agent_uuid}/{sub_topic}"
351
373
  await self._messager.send_message(topic, payload)
352
374
 
353
- async def send_message_to_agent(
354
- self, to_agent_uuid: str, content: str
355
- ):
375
+ async def send_message_to_agent(self, to_agent_uuid: str, content: str):
356
376
  """通过 Messager 发送消息"""
357
377
  if self._messager is None:
358
378
  raise RuntimeError("Messager is not set")
@@ -364,15 +384,17 @@ class Agent(ABC):
364
384
  "t": await self.simulator.get_simulator_second_from_start_of_day(),
365
385
  }
366
386
  await self._send_message(to_agent_uuid, payload, "agent-chat")
367
- auros = [{
368
- "id": self._uuid,
369
- "day": await self.simulator.get_simulator_day(),
370
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
371
- "type": 1,
372
- "speaker": self._uuid,
373
- "content": content,
374
- "created_at": int(datetime.now().timestamp() * 1000),
375
- }]
387
+ auros = [
388
+ {
389
+ "id": self._uuid,
390
+ "day": await self.simulator.get_simulator_day(),
391
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
392
+ "type": 1,
393
+ "speaker": self._uuid,
394
+ "content": content,
395
+ "created_at": int(datetime.now().timestamp() * 1000),
396
+ }
397
+ ]
376
398
  if self._avro_file is None:
377
399
  return
378
400
  with open(self._avro_file["dialog"], "a+b") as f:
@@ -403,20 +425,22 @@ class CitizenAgent(Agent):
403
425
  name: str,
404
426
  llm_client: Optional[LLM] = None,
405
427
  simulator: Optional[Simulator] = None,
428
+ mlflow_client: Optional[MlflowClient] = None,
406
429
  memory: Optional[Memory] = None,
407
430
  economy_client: Optional[EconomyClient] = None,
408
431
  messager: Optional[Messager] = None,
409
432
  avro_file: Optional[dict] = None,
410
433
  ) -> None:
411
434
  super().__init__(
412
- name,
413
- AgentType.Citizen,
414
- llm_client,
415
- economy_client,
416
- messager,
417
- simulator,
418
- memory,
419
- avro_file,
435
+ name=name,
436
+ type=AgentType.Citizen,
437
+ llm_client=llm_client,
438
+ economy_client=economy_client,
439
+ messager=messager,
440
+ simulator=simulator,
441
+ mlflow_client=mlflow_client,
442
+ memory=memory,
443
+ avro_file=avro_file,
420
444
  )
421
445
 
422
446
  async def bind_to_simulator(self):
@@ -464,9 +488,7 @@ class CitizenAgent(Agent):
464
488
  )
465
489
  person_id = resp["person_id"]
466
490
  await memory.update("id", person_id, protect_llm_read_only_fields=False)
467
- logger.debug(
468
- f"Binding to Person `{person_id}` just added to Simulator"
469
- )
491
+ logger.debug(f"Binding to Person `{person_id}` just added to Simulator")
470
492
  # 防止模拟器还没有到prepare阶段导致get_person出错
471
493
  self._has_bound_to_simulator = True
472
494
  self._agent_id = person_id
@@ -517,24 +539,26 @@ class InstitutionAgent(Agent):
517
539
  name: str,
518
540
  llm_client: Optional[LLM] = None,
519
541
  simulator: Optional[Simulator] = None,
542
+ mlflow_client: Optional[MlflowClient] = None,
520
543
  memory: Optional[Memory] = None,
521
544
  economy_client: Optional[EconomyClient] = None,
522
545
  messager: Optional[Messager] = None,
523
546
  avro_file: Optional[dict] = None,
524
547
  ) -> None:
525
548
  super().__init__(
526
- name,
527
- AgentType.Institution,
528
- llm_client,
529
- economy_client,
530
- messager,
531
- simulator,
532
- memory,
533
- avro_file,
549
+ name=name,
550
+ type=AgentType.Institution,
551
+ llm_client=llm_client,
552
+ economy_client=economy_client,
553
+ mlflow_client=mlflow_client,
554
+ messager=messager,
555
+ simulator=simulator,
556
+ memory=memory,
557
+ avro_file=avro_file,
534
558
  )
535
559
  # 添加响应收集器
536
- self._gather_responses: Dict[str, asyncio.Future] = {}
537
-
560
+ self._gather_responses: dict[str, asyncio.Future] = {}
561
+
538
562
  async def bind_to_simulator(self):
539
563
  await self._bind_to_economy()
540
564
 
@@ -624,31 +648,33 @@ class InstitutionAgent(Agent):
624
648
  """处理收到的消息,识别发送者"""
625
649
  content = payload["content"]
626
650
  sender_id = payload["from"]
627
-
651
+
628
652
  # 将响应存储到对应的Future中
629
653
  response_key = str(sender_id)
630
654
  if response_key in self._gather_responses:
631
- self._gather_responses[response_key].set_result({
632
- "from": sender_id,
633
- "content": content,
634
- })
655
+ self._gather_responses[response_key].set_result(
656
+ {
657
+ "from": sender_id,
658
+ "content": content,
659
+ }
660
+ )
635
661
 
636
- async def gather_messages(self, agent_uuids: list[str], target: str) -> List[dict]:
662
+ async def gather_messages(self, agent_uuids: list[str], target: str) -> list[dict]:
637
663
  """从多个智能体收集消息
638
-
664
+
639
665
  Args:
640
666
  agent_uuids: 目标智能体UUID列表
641
667
  target: 要收集的信息类型
642
-
668
+
643
669
  Returns:
644
- List[dict]: 收集到的所有响应
670
+ list[dict]: 收集到的所有响应
645
671
  """
646
672
  # 为每个agent创建Future
647
673
  futures = {}
648
674
  for agent_uuid in agent_uuids:
649
675
  futures[agent_uuid] = asyncio.Future()
650
676
  self._gather_responses[agent_uuid] = futures[agent_uuid]
651
-
677
+
652
678
  # 发送gather请求
653
679
  payload = {
654
680
  "from": self._uuid,
@@ -656,7 +682,7 @@ class InstitutionAgent(Agent):
656
682
  }
657
683
  for agent_uuid in agent_uuids:
658
684
  await self._send_message(agent_uuid, payload, "gather")
659
-
685
+
660
686
  try:
661
687
  # 等待所有响应
662
688
  responses = await asyncio.gather(*futures.values())
@@ -308,11 +308,48 @@ class EconomyClient:
308
308
  # current agent ids and org ids
309
309
  return (list(response.agent_ids), list(response.org_ids))
310
310
 
311
- async def get_org_entity_ids(self, org_type: economyv2.OrgType)->list[int]:
311
+ async def get_org_entity_ids(self, org_type: economyv2.OrgType) -> list[int]:
312
312
  request = org_service.GetOrgEntityIdsRequest(
313
313
  type=org_type,
314
314
  )
315
315
  response: org_service.GetOrgEntityIdsResponse = (
316
316
  await self._aio_stub.GetOrgEntityIds(request)
317
317
  )
318
- return list(response.org_ids)
318
+ return list(response.org_ids)
319
+
320
+ async def add_delta_value(
321
+ self,
322
+ id: int,
323
+ key: str,
324
+ value: Any,
325
+ ) -> Any:
326
+ """
327
+ Add key-value pair
328
+
329
+ Args:
330
+ - id (int): the id of `Org` or `Agent`.
331
+ - key (str): the attribute to update. Can only be `inventory`, `price`, `interest_rate` and `currency`
332
+
333
+
334
+ Returns:
335
+ - Any
336
+ """
337
+ pascal_key = _snake_to_pascal(key)
338
+ _request_type = getattr(org_service, f"Add{pascal_key}Request")
339
+ _request_func = getattr(self._aio_stub, f"Add{pascal_key}")
340
+ _available_keys = {
341
+ "inventory",
342
+ "price",
343
+ "interest_rate",
344
+ "currency",
345
+ }
346
+ if key not in _available_keys:
347
+ raise ValueError(f"Invalid key `{key}`, can only be {_available_keys}!")
348
+ return await _request_func(
349
+ _request_type(
350
+ **{
351
+ "org_id": id,
352
+ f"delta_{key}": value,
353
+ }
354
+ )
355
+ )
@@ -5,7 +5,7 @@ import logging
5
5
  import os
6
6
  from collections.abc import Sequence
7
7
  from datetime import datetime, timedelta
8
- from typing import Any, Optional, Tuple, Union, cast
8
+ from typing import Any, Optional, Union, cast
9
9
 
10
10
  from mosstool.type import TripMode
11
11
  from mosstool.util.format_converter import coll2pb
@@ -28,7 +28,7 @@ class Simulator:
28
28
  - Simulator Class
29
29
  """
30
30
 
31
- def __init__(self, config, secure: bool = False) -> None:
31
+ def __init__(self, config:dict, secure: bool = False) -> None:
32
32
  self.config = config
33
33
  """
34
34
  - 模拟器配置
@@ -1,9 +1,7 @@
1
- from typing import List
2
-
3
1
  __all__ = ["wrap_feature_collection"]
4
2
 
5
3
 
6
- def wrap_feature_collection(features: List[dict], name: str):
4
+ def wrap_feature_collection(features: list[dict], name: str):
7
5
  """
8
6
  将 GeoJSON Feature 集合包装为 FeatureCollection
9
7
  Wrap GeoJSON Feature collection as FeatureCollection
@@ -1,5 +1,5 @@
1
1
  import math
2
- from typing import Dict, List, Literal, Optional, Tuple, Union
2
+ from typing import Literal, Union
3
3
 
4
4
  import numpy as np
5
5
 
@@ -14,12 +14,12 @@ def point_on_line_given_distance(start_node, end_node, distance):
14
14
  return the coordinates of the point reached after traveling s units along the line, starting from start_point.
15
15
 
16
16
  Args:
17
- start_point (tuple): Tuple of (x, y) representing the starting point on the line.
18
- end_point (tuple): Tuple of (x, y) representing the ending point on the line.
17
+ start_point (tuple): tuple of (x, y) representing the starting point on the line.
18
+ end_point (tuple): tuple of (x, y) representing the ending point on the line.
19
19
  distance (float): Distance to travel along the line, starting from start_point.
20
20
 
21
21
  Returns:
22
- tuple: Tuple of (x, y) representing the new point reached after traveling s units along the line.
22
+ tuple: tuple of (x, y) representing the new point reached after traveling s units along the line.
23
23
  """
24
24
 
25
25
  x1, y1 = start_node["x"], start_node["y"]
@@ -49,7 +49,7 @@ def point_on_line_given_distance(start_node, end_node, distance):
49
49
 
50
50
 
51
51
  def get_key_index_in_lane(
52
- nodes: List[Dict[str, float]],
52
+ nodes: list[dict[str, float]],
53
53
  distance: float,
54
54
  direction: Union[Literal["front"], Literal["back"]],
55
55
  ) -> int:
@@ -61,10 +61,10 @@ def get_key_index_in_lane(
61
61
  _index_offset, _index_factor = len(_nodes) - 1, -1
62
62
  else:
63
63
  raise ValueError(f"Invalid direction type {direction}!")
64
- _lane_points: List[Tuple[float, float, float]] = [
64
+ _lane_points: list[tuple[float, float, float]] = [
65
65
  (n["x"], n["y"], n.get("z", 0)) for n in _nodes
66
66
  ]
67
- _line_lengths: List[float] = [0.0 for _ in range(len(_nodes))]
67
+ _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
68
68
  _s = 0.0
69
69
  for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
70
70
  _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
@@ -82,20 +82,20 @@ def get_key_index_in_lane(
82
82
 
83
83
 
84
84
  def get_xy_in_lane(
85
- nodes: List[Dict[str, float]],
85
+ nodes: list[dict[str, float]],
86
86
  distance: float,
87
87
  direction: Union[Literal["front"], Literal["back"]],
88
- ) -> Tuple[float, float]:
88
+ ) -> tuple[float, float]:
89
89
  if direction == "front":
90
90
  _nodes = [n for n in nodes]
91
91
  elif direction == "back":
92
92
  _nodes = [n for n in nodes[::-1]]
93
93
  else:
94
94
  raise ValueError(f"Invalid direction type {direction}!")
95
- _lane_points: List[Tuple[float, float, float]] = [
95
+ _lane_points: list[tuple[float, float, float]] = [
96
96
  (n["x"], n["y"], n.get("z", 0)) for n in _nodes
97
97
  ]
98
- _line_lengths: List[float] = [0.0 for _ in range(len(_nodes))]
98
+ _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
99
99
  _s = 0.0
100
100
  for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
101
101
  _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
@@ -122,7 +122,7 @@ def get_xy_in_lane(
122
122
 
123
123
 
124
124
  def get_direction_by_s(
125
- nodes: List[Dict[str, float]],
125
+ nodes: list[dict[str, float]],
126
126
  distance: float,
127
127
  direction: Union[Literal["front"], Literal["back"]],
128
128
  ) -> float:
@@ -132,11 +132,11 @@ def get_direction_by_s(
132
132
  _nodes = [n for n in nodes[::-1]]
133
133
  else:
134
134
  raise ValueError(f"Invalid direction type {direction}!")
135
- _lane_points: List[Tuple[float, float, float]] = [
135
+ _lane_points: list[tuple[float, float, float]] = [
136
136
  (n["x"], n["y"], n.get("z", 0)) for n in _nodes
137
137
  ]
138
- _line_lengths: List[float] = [0.0 for _ in range(len(_nodes))]
139
- _line_directions: List[Tuple[float, float]] = []
138
+ _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
139
+ _line_directions: list[tuple[float, float]] = []
140
140
  _s = 0.0
141
141
  for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
142
142
  _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
@@ -1,7 +1,6 @@
1
1
  """简单的基于内存的embedding实现"""
2
2
 
3
3
  import numpy as np
4
- from typing import List, Dict, Optional
5
4
  import hashlib
6
5
  import json
7
6
 
@@ -22,34 +21,34 @@ class SimpleEmbedding:
22
21
  """
23
22
  self.vector_dim = vector_dim
24
23
  self.cache_size = cache_size
25
- self._cache: Dict[str, np.ndarray] = {}
26
- self._vocab: Dict[str, int] = {} # 词汇表
27
- self._idf: Dict[str, float] = {} # 逆文档频率
24
+ self._cache: dict[str, np.ndarray] = {}
25
+ self._vocab: dict[str, int] = {} # 词汇表
26
+ self._idf: dict[str, float] = {} # 逆文档频率
28
27
  self._doc_count = 0 # 文档总数
29
28
 
30
29
  def _text_to_hash(self, text: str) -> str:
31
30
  """将文本转换为hash值"""
32
31
  return hashlib.md5(text.encode()).hexdigest()
33
32
 
34
- def _tokenize(self, text: str) -> List[str]:
33
+ def _tokenize(self, text: str) -> list[str]:
35
34
  """简单的分词"""
36
35
  # 这里使用简单的空格分词,实际应用中可以使用更复杂的分词方法
37
36
  return text.lower().split()
38
37
 
39
- def _update_vocab(self, tokens: List[str]):
38
+ def _update_vocab(self, tokens: list[str]):
40
39
  """更新词汇表"""
41
40
  for token in set(tokens): # 使用set去重
42
41
  if token not in self._vocab:
43
42
  self._vocab[token] = len(self._vocab)
44
43
 
45
- def _update_idf(self, tokens: List[str]):
44
+ def _update_idf(self, tokens: list[str]):
46
45
  """更新IDF值"""
47
46
  self._doc_count += 1
48
47
  unique_tokens = set(tokens)
49
48
  for token in unique_tokens:
50
49
  self._idf[token] = self._idf.get(token, 0) + 1
51
50
 
52
- def _calculate_tf(self, tokens: List[str]) -> Dict[str, float]:
51
+ def _calculate_tf(self, tokens: list[str]) -> dict[str, float]:
53
52
  """计算词频(TF)"""
54
53
  tf = {}
55
54
  total_tokens = len(tokens)
@@ -60,7 +59,7 @@ class SimpleEmbedding:
60
59
  tf[token] /= total_tokens
61
60
  return tf
62
61
 
63
- def _calculate_tfidf(self, tokens: List[str]) -> np.ndarray:
62
+ def _calculate_tfidf(self, tokens: list[str]) -> np.ndarray:
64
63
  """计算TF-IDF向量"""
65
64
  vector = np.zeros(self.vector_dim)
66
65
  tf = self._calculate_tf(tokens)
pycityagent/llm/llm.py CHANGED
@@ -14,7 +14,7 @@ import requests
14
14
  from dashscope import ImageSynthesis
15
15
  from PIL import Image
16
16
  from io import BytesIO
17
- from typing import Any, Optional, Union, List, Dict
17
+ from typing import Any, Optional, Union
18
18
  from .llmconfig import *
19
19
  from .utils import *
20
20
 
@@ -117,8 +117,8 @@ Token Usage:
117
117
  presence_penalty: Optional[float] = None,
118
118
  timeout: int = 300,
119
119
  retries=3,
120
- tools: Optional[List[Dict[str, Any]]] = None,
121
- tool_choice: Optional[Dict[str, Any]] = None,
120
+ tools: Optional[list[dict[str, Any]]] = None,
121
+ tool_choice: Optional[dict[str, Any]] = None,
122
122
  ):
123
123
  """
124
124
  异步版文本请求
@@ -227,9 +227,9 @@ Token Usage:
227
227
  self.prompt_tokens_used += result_response.usage.prompt_tokens # type: ignore
228
228
  self.completion_tokens_used += result_response.usage.completion_tokens # type: ignore
229
229
  self.request_number += 1
230
- if tools and result_response.choices[0].message.tool_calls:
230
+ if tools and result_response.choices[0].message.tool_calls: # type: ignore
231
231
  return json.loads(
232
- result_response.choices[0]
232
+ result_response.choices[0] # type: ignore
233
233
  .message.tool_calls[0]
234
234
  .function.arguments
235
235
  )