pycityagent 2.0.0a19__py3-none-any.whl → 2.0.0a21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +173 -70
- pycityagent/economy/econ_client.py +37 -0
- pycityagent/environment/utils/geojson.py +1 -3
- pycityagent/environment/utils/map_utils.py +15 -15
- pycityagent/llm/embedding.py +8 -9
- pycityagent/llm/llm.py +5 -5
- pycityagent/memory/memory.py +23 -22
- pycityagent/metrics/__init__.py +2 -1
- pycityagent/metrics/mlflow_client.py +72 -34
- pycityagent/simulation/__init__.py +2 -1
- pycityagent/simulation/agentgroup.py +131 -3
- pycityagent/simulation/simulation.py +67 -24
- pycityagent/simulation/storage/pg.py +139 -0
- pycityagent/utils/parsers/parser_base.py +1 -1
- pycityagent/utils/pg_query.py +80 -0
- pycityagent/workflow/prompt.py +6 -6
- pycityagent/workflow/tool.py +33 -25
- pycityagent/workflow/trigger.py +2 -2
- {pycityagent-2.0.0a19.dist-info → pycityagent-2.0.0a21.dist-info}/METADATA +3 -2
- {pycityagent-2.0.0a19.dist-info → pycityagent-2.0.0a21.dist-info}/RECORD +21 -19
- {pycityagent-2.0.0a19.dist-info → pycityagent-2.0.0a21.dist-info}/WHEEL +0 -0
pycityagent/agent.py
CHANGED
@@ -1,17 +1,19 @@
|
|
1
1
|
"""智能体模板类及其定义"""
|
2
2
|
|
3
3
|
import asyncio
|
4
|
+
import json
|
4
5
|
import logging
|
5
6
|
import random
|
6
7
|
import uuid
|
7
8
|
from abc import ABC, abstractmethod
|
8
9
|
from copy import deepcopy
|
9
|
-
from datetime import datetime
|
10
|
+
from datetime import datetime, timezone
|
10
11
|
from enum import Enum
|
11
|
-
from typing import Any,
|
12
|
+
from typing import Any, Optional
|
12
13
|
from uuid import UUID
|
13
14
|
|
14
15
|
import fastavro
|
16
|
+
import ray
|
15
17
|
from mosstool.util.format_converter import dict2pb
|
16
18
|
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
17
19
|
|
@@ -55,7 +57,8 @@ class Agent(ABC):
|
|
55
57
|
simulator: Optional[Simulator] = None,
|
56
58
|
mlflow_client: Optional[MlflowClient] = None,
|
57
59
|
memory: Optional[Memory] = None,
|
58
|
-
avro_file: Optional[
|
60
|
+
avro_file: Optional[dict[str, str]] = None,
|
61
|
+
copy_writer: Optional[ray.ObjectRef] = None,
|
59
62
|
) -> None:
|
60
63
|
"""
|
61
64
|
Initialize the Agent.
|
@@ -69,7 +72,8 @@ class Agent(ABC):
|
|
69
72
|
simulator (Simulator, optional): The simulator object. Defaults to None.
|
70
73
|
mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
|
71
74
|
memory (Memory, optional): The memory of the agent. Defaults to None.
|
72
|
-
avro_file (
|
75
|
+
avro_file (dict[str, str], optional): The avro file of the agent. Defaults to None.
|
76
|
+
copy_writer (ray.ObjectRef): The copy_writer of the agent. Defaults to None.
|
73
77
|
"""
|
74
78
|
self._name = name
|
75
79
|
self._type = type
|
@@ -85,9 +89,11 @@ class Agent(ABC):
|
|
85
89
|
self._has_bound_to_simulator = False
|
86
90
|
self._has_bound_to_economy = False
|
87
91
|
self._blocked = False
|
88
|
-
self._interview_history:
|
92
|
+
self._interview_history: list[dict] = [] # 存储采访历史
|
89
93
|
self._person_template = PersonService.default_dict_person()
|
90
94
|
self._avro_file = avro_file
|
95
|
+
self._pgsql_writer = copy_writer
|
96
|
+
self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
91
97
|
|
92
98
|
def __getstate__(self):
|
93
99
|
state = self.__dict__.copy()
|
@@ -137,12 +143,18 @@ class Agent(ABC):
|
|
137
143
|
"""
|
138
144
|
self._exp_id = exp_id
|
139
145
|
|
140
|
-
def set_avro_file(self, avro_file:
|
146
|
+
def set_avro_file(self, avro_file: dict[str, str]):
|
141
147
|
"""
|
142
148
|
Set the avro file of the agent.
|
143
149
|
"""
|
144
150
|
self._avro_file = avro_file
|
145
151
|
|
152
|
+
def set_pgsql_writer(self, pgsql_writer: ray.ObjectRef):
|
153
|
+
"""
|
154
|
+
Set the PostgreSQL copy writer of the agent.
|
155
|
+
"""
|
156
|
+
self._pgsql_writer = pgsql_writer
|
157
|
+
|
146
158
|
@property
|
147
159
|
def uuid(self):
|
148
160
|
"""The Agent's UUID"""
|
@@ -198,6 +210,15 @@ class Agent(ABC):
|
|
198
210
|
)
|
199
211
|
return self._simulator
|
200
212
|
|
213
|
+
@property
|
214
|
+
def copy_writer(self):
|
215
|
+
"""Pg Copy Writer"""
|
216
|
+
if self._pgsql_writer is None:
|
217
|
+
raise RuntimeError(
|
218
|
+
f"Copy Writer access before assignment, please `set_pgsql_writer` first!"
|
219
|
+
)
|
220
|
+
return self._pgsql_writer
|
221
|
+
|
201
222
|
async def generate_user_survey_response(self, survey: dict) -> str:
|
202
223
|
"""生成回答 —— 可重写
|
203
224
|
基于智能体的记忆和当前状态,生成对问卷调查的回答。
|
@@ -237,8 +258,8 @@ class Agent(ABC):
|
|
237
258
|
|
238
259
|
async def _process_survey(self, survey: dict):
|
239
260
|
survey_response = await self.generate_user_survey_response(survey)
|
240
|
-
|
241
|
-
|
261
|
+
_date_time = datetime.now(timezone.utc)
|
262
|
+
# Avro
|
242
263
|
response_to_avro = [
|
243
264
|
{
|
244
265
|
"id": self._uuid,
|
@@ -246,11 +267,41 @@ class Agent(ABC):
|
|
246
267
|
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
247
268
|
"survey_id": survey["id"],
|
248
269
|
"result": survey_response,
|
249
|
-
"created_at": int(
|
270
|
+
"created_at": int(_date_time.timestamp() * 1000),
|
250
271
|
}
|
251
272
|
]
|
252
|
-
|
253
|
-
|
273
|
+
if self._avro_file is not None:
|
274
|
+
with open(self._avro_file["survey"], "a+b") as f:
|
275
|
+
fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
|
276
|
+
# Pg
|
277
|
+
if self._pgsql_writer is not None:
|
278
|
+
if self._last_asyncio_pg_task is not None:
|
279
|
+
await self._last_asyncio_pg_task
|
280
|
+
_keys = [
|
281
|
+
"id",
|
282
|
+
"day",
|
283
|
+
"t",
|
284
|
+
"survey_id",
|
285
|
+
"result",
|
286
|
+
]
|
287
|
+
_data_tuples: list[tuple] = []
|
288
|
+
# str to json
|
289
|
+
for _dict in response_to_avro:
|
290
|
+
res = _dict["result"]
|
291
|
+
_dict["result"] = json.dumps(
|
292
|
+
{
|
293
|
+
"result": res,
|
294
|
+
}
|
295
|
+
)
|
296
|
+
_data_list = [_dict[k] for k in _keys]
|
297
|
+
# created_at
|
298
|
+
_data_list.append(_date_time)
|
299
|
+
_data_tuples.append(tuple(_data_list))
|
300
|
+
self._last_asyncio_pg_task = (
|
301
|
+
self._pgsql_writer.async_write_survey.remote( # type:ignore
|
302
|
+
_data_tuples
|
303
|
+
)
|
304
|
+
)
|
254
305
|
|
255
306
|
async def generate_user_chat_response(self, question: str) -> str:
|
256
307
|
"""生成回答 —— 可重写
|
@@ -290,34 +341,52 @@ class Agent(ABC):
|
|
290
341
|
return response # type:ignore
|
291
342
|
|
292
343
|
async def _process_interview(self, payload: dict):
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
344
|
+
pg_list: list[tuple[dict, datetime]] = []
|
345
|
+
auros: list[dict] = []
|
346
|
+
_date_time = datetime.now(timezone.utc)
|
347
|
+
_interview_dict = {
|
348
|
+
"id": self._uuid,
|
349
|
+
"day": await self.simulator.get_simulator_day(),
|
350
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
351
|
+
"type": 2,
|
352
|
+
"speaker": "user",
|
353
|
+
"content": payload["content"],
|
354
|
+
"created_at": int(_date_time.timestamp() * 1000),
|
355
|
+
}
|
356
|
+
auros.append(_interview_dict)
|
357
|
+
pg_list.append((_interview_dict, _date_time))
|
304
358
|
question = payload["content"]
|
305
359
|
response = await self.generate_user_chat_response(question)
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
)
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
360
|
+
_date_time = datetime.now(timezone.utc)
|
361
|
+
_interview_dict = {
|
362
|
+
"id": self._uuid,
|
363
|
+
"day": await self.simulator.get_simulator_day(),
|
364
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
365
|
+
"type": 2,
|
366
|
+
"speaker": "",
|
367
|
+
"content": response,
|
368
|
+
"created_at": int(_date_time.timestamp() * 1000),
|
369
|
+
}
|
370
|
+
auros.append(_interview_dict)
|
371
|
+
pg_list.append((_interview_dict, _date_time))
|
372
|
+
# Avro
|
373
|
+
if self._avro_file is not None:
|
374
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
375
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
376
|
+
# Pg
|
377
|
+
if self._pgsql_writer is not None:
|
378
|
+
if self._last_asyncio_pg_task is not None:
|
379
|
+
await self._last_asyncio_pg_task
|
380
|
+
_keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
|
381
|
+
_data = [
|
382
|
+
tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
|
383
|
+
for _dict, _date_time in pg_list
|
384
|
+
]
|
385
|
+
self._last_asyncio_pg_task = (
|
386
|
+
self._pgsql_writer.async_write_dialog.remote( # type:ignore
|
387
|
+
_data
|
388
|
+
)
|
389
|
+
)
|
321
390
|
|
322
391
|
async def process_agent_chat_response(self, payload: dict) -> str:
|
323
392
|
resp = f"Agent {self._uuid} received agent chat response: {payload}"
|
@@ -325,22 +394,39 @@ class Agent(ABC):
|
|
325
394
|
return resp
|
326
395
|
|
327
396
|
async def _process_agent_chat(self, payload: dict):
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
397
|
+
pg_list: list[tuple[dict, datetime]] = []
|
398
|
+
auros: list[dict] = []
|
399
|
+
_date_time = datetime.now(timezone.utc)
|
400
|
+
_chat_dict = {
|
401
|
+
"id": self._uuid,
|
402
|
+
"day": payload["day"],
|
403
|
+
"t": payload["t"],
|
404
|
+
"type": 1,
|
405
|
+
"speaker": payload["from"],
|
406
|
+
"content": payload["content"],
|
407
|
+
"created_at": int(_date_time.timestamp() * 1000),
|
408
|
+
}
|
409
|
+
auros.append(_chat_dict)
|
410
|
+
pg_list.append((_chat_dict, _date_time))
|
339
411
|
asyncio.create_task(self.process_agent_chat_response(payload))
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
412
|
+
# Avro
|
413
|
+
if self._avro_file is not None:
|
414
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
415
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
416
|
+
# Pg
|
417
|
+
if self._pgsql_writer is not None:
|
418
|
+
if self._last_asyncio_pg_task is not None:
|
419
|
+
await self._last_asyncio_pg_task
|
420
|
+
_keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
|
421
|
+
_data = [
|
422
|
+
tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
|
423
|
+
for _dict, _date_time in pg_list
|
424
|
+
]
|
425
|
+
self._last_asyncio_pg_task = (
|
426
|
+
self._pgsql_writer.async_write_dialog.remote( # type:ignore
|
427
|
+
_data
|
428
|
+
)
|
429
|
+
)
|
344
430
|
|
345
431
|
# Callback functions for MQTT message
|
346
432
|
async def handle_agent_chat_message(self, payload: dict):
|
@@ -384,21 +470,38 @@ class Agent(ABC):
|
|
384
470
|
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
385
471
|
}
|
386
472
|
await self._send_message(to_agent_uuid, payload, "agent-chat")
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
473
|
+
pg_list: list[tuple[dict, datetime]] = []
|
474
|
+
auros: list[dict] = []
|
475
|
+
_date_time = datetime.now(timezone.utc)
|
476
|
+
_message_dict = {
|
477
|
+
"id": self._uuid,
|
478
|
+
"day": await self.simulator.get_simulator_day(),
|
479
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
480
|
+
"type": 1,
|
481
|
+
"speaker": self._uuid,
|
482
|
+
"content": content,
|
483
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
484
|
+
}
|
485
|
+
auros.append(_message_dict)
|
486
|
+
pg_list.append((_message_dict, _date_time))
|
487
|
+
# Avro
|
488
|
+
if self._avro_file is not None:
|
489
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
490
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
491
|
+
# Pg
|
492
|
+
if self._pgsql_writer is not None:
|
493
|
+
if self._last_asyncio_pg_task is not None:
|
494
|
+
await self._last_asyncio_pg_task
|
495
|
+
_keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
|
496
|
+
_data = [
|
497
|
+
tuple([_dict[k] if k != "created_at" else _date_time for k in _keys])
|
498
|
+
for _dict, _date_time in pg_list
|
499
|
+
]
|
500
|
+
self._last_asyncio_pg_task = (
|
501
|
+
self._pgsql_writer.async_write_dialog.remote( # type:ignore
|
502
|
+
_data
|
503
|
+
)
|
504
|
+
)
|
402
505
|
|
403
506
|
# Agent logic
|
404
507
|
@abstractmethod
|
@@ -557,7 +660,7 @@ class InstitutionAgent(Agent):
|
|
557
660
|
avro_file=avro_file,
|
558
661
|
)
|
559
662
|
# 添加响应收集器
|
560
|
-
self._gather_responses:
|
663
|
+
self._gather_responses: dict[str, asyncio.Future] = {}
|
561
664
|
|
562
665
|
async def bind_to_simulator(self):
|
563
666
|
await self._bind_to_economy()
|
@@ -659,7 +762,7 @@ class InstitutionAgent(Agent):
|
|
659
762
|
}
|
660
763
|
)
|
661
764
|
|
662
|
-
async def gather_messages(self, agent_uuids: list[str], target: str) ->
|
765
|
+
async def gather_messages(self, agent_uuids: list[str], target: str) -> list[dict]:
|
663
766
|
"""从多个智能体收集消息
|
664
767
|
|
665
768
|
Args:
|
@@ -667,7 +770,7 @@ class InstitutionAgent(Agent):
|
|
667
770
|
target: 要收集的信息类型
|
668
771
|
|
669
772
|
Returns:
|
670
|
-
|
773
|
+
list[dict]: 收集到的所有响应
|
671
774
|
"""
|
672
775
|
# 为每个agent创建Future
|
673
776
|
futures = {}
|
@@ -316,3 +316,40 @@ class EconomyClient:
|
|
316
316
|
await self._aio_stub.GetOrgEntityIds(request)
|
317
317
|
)
|
318
318
|
return list(response.org_ids)
|
319
|
+
|
320
|
+
async def add_delta_value(
|
321
|
+
self,
|
322
|
+
id: int,
|
323
|
+
key: str,
|
324
|
+
value: Any,
|
325
|
+
) -> Any:
|
326
|
+
"""
|
327
|
+
Add key-value pair
|
328
|
+
|
329
|
+
Args:
|
330
|
+
- id (int): the id of `Org` or `Agent`.
|
331
|
+
- key (str): the attribute to update. Can only be `inventory`, `price`, `interest_rate` and `currency`
|
332
|
+
|
333
|
+
|
334
|
+
Returns:
|
335
|
+
- Any
|
336
|
+
"""
|
337
|
+
pascal_key = _snake_to_pascal(key)
|
338
|
+
_request_type = getattr(org_service, f"Add{pascal_key}Request")
|
339
|
+
_request_func = getattr(self._aio_stub, f"Add{pascal_key}")
|
340
|
+
_available_keys = {
|
341
|
+
"inventory",
|
342
|
+
"price",
|
343
|
+
"interest_rate",
|
344
|
+
"currency",
|
345
|
+
}
|
346
|
+
if key not in _available_keys:
|
347
|
+
raise ValueError(f"Invalid key `{key}`, can only be {_available_keys}!")
|
348
|
+
return await _request_func(
|
349
|
+
_request_type(
|
350
|
+
**{
|
351
|
+
"org_id": id,
|
352
|
+
f"delta_{key}": value,
|
353
|
+
}
|
354
|
+
)
|
355
|
+
)
|
@@ -1,9 +1,7 @@
|
|
1
|
-
from typing import List
|
2
|
-
|
3
1
|
__all__ = ["wrap_feature_collection"]
|
4
2
|
|
5
3
|
|
6
|
-
def wrap_feature_collection(features:
|
4
|
+
def wrap_feature_collection(features: list[dict], name: str):
|
7
5
|
"""
|
8
6
|
将 GeoJSON Feature 集合包装为 FeatureCollection
|
9
7
|
Wrap GeoJSON Feature collection as FeatureCollection
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import math
|
2
|
-
from typing import
|
2
|
+
from typing import Literal, Union
|
3
3
|
|
4
4
|
import numpy as np
|
5
5
|
|
@@ -14,12 +14,12 @@ def point_on_line_given_distance(start_node, end_node, distance):
|
|
14
14
|
return the coordinates of the point reached after traveling s units along the line, starting from start_point.
|
15
15
|
|
16
16
|
Args:
|
17
|
-
start_point (tuple):
|
18
|
-
end_point (tuple):
|
17
|
+
start_point (tuple): tuple of (x, y) representing the starting point on the line.
|
18
|
+
end_point (tuple): tuple of (x, y) representing the ending point on the line.
|
19
19
|
distance (float): Distance to travel along the line, starting from start_point.
|
20
20
|
|
21
21
|
Returns:
|
22
|
-
tuple:
|
22
|
+
tuple: tuple of (x, y) representing the new point reached after traveling s units along the line.
|
23
23
|
"""
|
24
24
|
|
25
25
|
x1, y1 = start_node["x"], start_node["y"]
|
@@ -49,7 +49,7 @@ def point_on_line_given_distance(start_node, end_node, distance):
|
|
49
49
|
|
50
50
|
|
51
51
|
def get_key_index_in_lane(
|
52
|
-
nodes:
|
52
|
+
nodes: list[dict[str, float]],
|
53
53
|
distance: float,
|
54
54
|
direction: Union[Literal["front"], Literal["back"]],
|
55
55
|
) -> int:
|
@@ -61,10 +61,10 @@ def get_key_index_in_lane(
|
|
61
61
|
_index_offset, _index_factor = len(_nodes) - 1, -1
|
62
62
|
else:
|
63
63
|
raise ValueError(f"Invalid direction type {direction}!")
|
64
|
-
_lane_points:
|
64
|
+
_lane_points: list[tuple[float, float, float]] = [
|
65
65
|
(n["x"], n["y"], n.get("z", 0)) for n in _nodes
|
66
66
|
]
|
67
|
-
_line_lengths:
|
67
|
+
_line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
|
68
68
|
_s = 0.0
|
69
69
|
for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
|
70
70
|
_s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
|
@@ -82,20 +82,20 @@ def get_key_index_in_lane(
|
|
82
82
|
|
83
83
|
|
84
84
|
def get_xy_in_lane(
|
85
|
-
nodes:
|
85
|
+
nodes: list[dict[str, float]],
|
86
86
|
distance: float,
|
87
87
|
direction: Union[Literal["front"], Literal["back"]],
|
88
|
-
) ->
|
88
|
+
) -> tuple[float, float]:
|
89
89
|
if direction == "front":
|
90
90
|
_nodes = [n for n in nodes]
|
91
91
|
elif direction == "back":
|
92
92
|
_nodes = [n for n in nodes[::-1]]
|
93
93
|
else:
|
94
94
|
raise ValueError(f"Invalid direction type {direction}!")
|
95
|
-
_lane_points:
|
95
|
+
_lane_points: list[tuple[float, float, float]] = [
|
96
96
|
(n["x"], n["y"], n.get("z", 0)) for n in _nodes
|
97
97
|
]
|
98
|
-
_line_lengths:
|
98
|
+
_line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
|
99
99
|
_s = 0.0
|
100
100
|
for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
|
101
101
|
_s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
|
@@ -122,7 +122,7 @@ def get_xy_in_lane(
|
|
122
122
|
|
123
123
|
|
124
124
|
def get_direction_by_s(
|
125
|
-
nodes:
|
125
|
+
nodes: list[dict[str, float]],
|
126
126
|
distance: float,
|
127
127
|
direction: Union[Literal["front"], Literal["back"]],
|
128
128
|
) -> float:
|
@@ -132,11 +132,11 @@ def get_direction_by_s(
|
|
132
132
|
_nodes = [n for n in nodes[::-1]]
|
133
133
|
else:
|
134
134
|
raise ValueError(f"Invalid direction type {direction}!")
|
135
|
-
_lane_points:
|
135
|
+
_lane_points: list[tuple[float, float, float]] = [
|
136
136
|
(n["x"], n["y"], n.get("z", 0)) for n in _nodes
|
137
137
|
]
|
138
|
-
_line_lengths:
|
139
|
-
_line_directions:
|
138
|
+
_line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
|
139
|
+
_line_directions: list[tuple[float, float]] = []
|
140
140
|
_s = 0.0
|
141
141
|
for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
|
142
142
|
_s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
|
pycityagent/llm/embedding.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
"""简单的基于内存的embedding实现"""
|
2
2
|
|
3
3
|
import numpy as np
|
4
|
-
from typing import List, Dict, Optional
|
5
4
|
import hashlib
|
6
5
|
import json
|
7
6
|
|
@@ -22,34 +21,34 @@ class SimpleEmbedding:
|
|
22
21
|
"""
|
23
22
|
self.vector_dim = vector_dim
|
24
23
|
self.cache_size = cache_size
|
25
|
-
self._cache:
|
26
|
-
self._vocab:
|
27
|
-
self._idf:
|
24
|
+
self._cache: dict[str, np.ndarray] = {}
|
25
|
+
self._vocab: dict[str, int] = {} # 词汇表
|
26
|
+
self._idf: dict[str, float] = {} # 逆文档频率
|
28
27
|
self._doc_count = 0 # 文档总数
|
29
28
|
|
30
29
|
def _text_to_hash(self, text: str) -> str:
|
31
30
|
"""将文本转换为hash值"""
|
32
31
|
return hashlib.md5(text.encode()).hexdigest()
|
33
32
|
|
34
|
-
def _tokenize(self, text: str) ->
|
33
|
+
def _tokenize(self, text: str) -> list[str]:
|
35
34
|
"""简单的分词"""
|
36
35
|
# 这里使用简单的空格分词,实际应用中可以使用更复杂的分词方法
|
37
36
|
return text.lower().split()
|
38
37
|
|
39
|
-
def _update_vocab(self, tokens:
|
38
|
+
def _update_vocab(self, tokens: list[str]):
|
40
39
|
"""更新词汇表"""
|
41
40
|
for token in set(tokens): # 使用set去重
|
42
41
|
if token not in self._vocab:
|
43
42
|
self._vocab[token] = len(self._vocab)
|
44
43
|
|
45
|
-
def _update_idf(self, tokens:
|
44
|
+
def _update_idf(self, tokens: list[str]):
|
46
45
|
"""更新IDF值"""
|
47
46
|
self._doc_count += 1
|
48
47
|
unique_tokens = set(tokens)
|
49
48
|
for token in unique_tokens:
|
50
49
|
self._idf[token] = self._idf.get(token, 0) + 1
|
51
50
|
|
52
|
-
def _calculate_tf(self, tokens:
|
51
|
+
def _calculate_tf(self, tokens: list[str]) -> dict[str, float]:
|
53
52
|
"""计算词频(TF)"""
|
54
53
|
tf = {}
|
55
54
|
total_tokens = len(tokens)
|
@@ -60,7 +59,7 @@ class SimpleEmbedding:
|
|
60
59
|
tf[token] /= total_tokens
|
61
60
|
return tf
|
62
61
|
|
63
|
-
def _calculate_tfidf(self, tokens:
|
62
|
+
def _calculate_tfidf(self, tokens: list[str]) -> np.ndarray:
|
64
63
|
"""计算TF-IDF向量"""
|
65
64
|
vector = np.zeros(self.vector_dim)
|
66
65
|
tf = self._calculate_tf(tokens)
|
pycityagent/llm/llm.py
CHANGED
@@ -14,7 +14,7 @@ import requests
|
|
14
14
|
from dashscope import ImageSynthesis
|
15
15
|
from PIL import Image
|
16
16
|
from io import BytesIO
|
17
|
-
from typing import Any, Optional, Union
|
17
|
+
from typing import Any, Optional, Union
|
18
18
|
from .llmconfig import *
|
19
19
|
from .utils import *
|
20
20
|
|
@@ -117,8 +117,8 @@ Token Usage:
|
|
117
117
|
presence_penalty: Optional[float] = None,
|
118
118
|
timeout: int = 300,
|
119
119
|
retries=3,
|
120
|
-
tools: Optional[
|
121
|
-
tool_choice: Optional[
|
120
|
+
tools: Optional[list[dict[str, Any]]] = None,
|
121
|
+
tool_choice: Optional[dict[str, Any]] = None,
|
122
122
|
):
|
123
123
|
"""
|
124
124
|
异步版文本请求
|
@@ -227,9 +227,9 @@ Token Usage:
|
|
227
227
|
self.prompt_tokens_used += result_response.usage.prompt_tokens # type: ignore
|
228
228
|
self.completion_tokens_used += result_response.usage.completion_tokens # type: ignore
|
229
229
|
self.request_number += 1
|
230
|
-
if tools and result_response.choices[0].message.tool_calls:
|
230
|
+
if tools and result_response.choices[0].message.tool_calls: # type: ignore
|
231
231
|
return json.loads(
|
232
|
-
result_response.choices[0]
|
232
|
+
result_response.choices[0] # type: ignore
|
233
233
|
.message.tool_calls[0]
|
234
234
|
.function.arguments
|
235
235
|
)
|