pycityagent 2.0.0a18__py3-none-any.whl → 2.0.0a20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +126 -100
- pycityagent/economy/econ_client.py +39 -2
- pycityagent/environment/simulator.py +2 -2
- pycityagent/environment/utils/geojson.py +1 -3
- pycityagent/environment/utils/map_utils.py +15 -15
- pycityagent/llm/embedding.py +8 -9
- pycityagent/llm/llm.py +5 -5
- pycityagent/memory/memory.py +23 -22
- pycityagent/metrics/__init__.py +6 -0
- pycityagent/metrics/mlflow_client.py +147 -0
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent/simulation/agentgroup.py +58 -21
- pycityagent/simulation/simulation.py +114 -38
- pycityagent/utils/parsers/parser_base.py +1 -1
- pycityagent/workflow/__init__.py +5 -3
- pycityagent/workflow/block.py +2 -3
- pycityagent/workflow/prompt.py +6 -6
- pycityagent/workflow/tool.py +53 -4
- pycityagent/workflow/trigger.py +2 -2
- {pycityagent-2.0.0a18.dist-info → pycityagent-2.0.0a20.dist-info}/METADATA +4 -2
- {pycityagent-2.0.0a18.dist-info → pycityagent-2.0.0a20.dist-info}/RECORD +22 -19
- {pycityagent-2.0.0a18.dist-info → pycityagent-2.0.0a20.dist-info}/WHEEL +0 -0
pycityagent/agent.py
CHANGED
@@ -1,30 +1,28 @@
|
|
1
1
|
"""智能体模板类及其定义"""
|
2
2
|
|
3
|
-
from abc import ABC, abstractmethod
|
4
3
|
import asyncio
|
5
|
-
from uuid import UUID
|
6
|
-
from copy import deepcopy
|
7
|
-
from datetime import datetime
|
8
|
-
from enum import Enum
|
9
4
|
import logging
|
10
5
|
import random
|
11
6
|
import uuid
|
12
|
-
from
|
7
|
+
from abc import ABC, abstractmethod
|
8
|
+
from copy import deepcopy
|
9
|
+
from datetime import datetime
|
10
|
+
from enum import Enum
|
11
|
+
from typing import Any, Optional
|
12
|
+
from uuid import UUID
|
13
13
|
|
14
14
|
import fastavro
|
15
|
-
|
16
|
-
from pycityagent.environment.sim.person_service import PersonService
|
17
15
|
from mosstool.util.format_converter import dict2pb
|
18
16
|
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
19
|
-
from pycityagent.utils import process_survey_for_llm
|
20
|
-
|
21
|
-
from pycityagent.message.messager import Messager
|
22
|
-
from pycityagent.utils import SURVEY_SCHEMA, DIALOG_SCHEMA
|
23
17
|
|
24
18
|
from .economy import EconomyClient
|
25
19
|
from .environment import Simulator
|
20
|
+
from .environment.sim.person_service import PersonService
|
26
21
|
from .llm import LLM
|
27
22
|
from .memory import Memory
|
23
|
+
from .message.messager import Messager
|
24
|
+
from .metrics import MlflowClient
|
25
|
+
from .utils import DIALOG_SCHEMA, SURVEY_SCHEMA, process_survey_for_llm
|
28
26
|
|
29
27
|
logger = logging.getLogger("pycityagent")
|
30
28
|
|
@@ -55,8 +53,9 @@ class Agent(ABC):
|
|
55
53
|
economy_client: Optional[EconomyClient] = None,
|
56
54
|
messager: Optional[Messager] = None,
|
57
55
|
simulator: Optional[Simulator] = None,
|
56
|
+
mlflow_client: Optional[MlflowClient] = None,
|
58
57
|
memory: Optional[Memory] = None,
|
59
|
-
avro_file: Optional[
|
58
|
+
avro_file: Optional[dict[str, str]] = None,
|
60
59
|
) -> None:
|
61
60
|
"""
|
62
61
|
Initialize the Agent.
|
@@ -68,8 +67,9 @@ class Agent(ABC):
|
|
68
67
|
economy_client (EconomyClient): The `EconomySim` client. Defaults to None.
|
69
68
|
messager (Messager, optional): The messager object. Defaults to None.
|
70
69
|
simulator (Simulator, optional): The simulator object. Defaults to None.
|
70
|
+
mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
|
71
71
|
memory (Memory, optional): The memory of the agent. Defaults to None.
|
72
|
-
avro_file (
|
72
|
+
avro_file (dict[str, str], optional): The avro file of the agent. Defaults to None.
|
73
73
|
"""
|
74
74
|
self._name = name
|
75
75
|
self._type = type
|
@@ -78,13 +78,14 @@ class Agent(ABC):
|
|
78
78
|
self._economy_client = economy_client
|
79
79
|
self._messager = messager
|
80
80
|
self._simulator = simulator
|
81
|
+
self._mlflow_client = mlflow_client
|
81
82
|
self._memory = memory
|
82
83
|
self._exp_id = -1
|
83
84
|
self._agent_id = -1
|
84
85
|
self._has_bound_to_simulator = False
|
85
86
|
self._has_bound_to_economy = False
|
86
87
|
self._blocked = False
|
87
|
-
self._interview_history:
|
88
|
+
self._interview_history: list[dict] = [] # 存储采访历史
|
88
89
|
self._person_template = PersonService.default_dict_person()
|
89
90
|
self._avro_file = avro_file
|
90
91
|
|
@@ -112,6 +113,12 @@ class Agent(ABC):
|
|
112
113
|
"""
|
113
114
|
self._simulator = simulator
|
114
115
|
|
116
|
+
def set_mlflow_client(self, mlflow_client: MlflowClient):
|
117
|
+
"""
|
118
|
+
Set the mlflow_client of the agent.
|
119
|
+
"""
|
120
|
+
self._mlflow_client = mlflow_client
|
121
|
+
|
115
122
|
def set_economy_client(self, economy_client: EconomyClient):
|
116
123
|
"""
|
117
124
|
Set the economy_client of the agent.
|
@@ -130,7 +137,7 @@ class Agent(ABC):
|
|
130
137
|
"""
|
131
138
|
self._exp_id = exp_id
|
132
139
|
|
133
|
-
def set_avro_file(self, avro_file:
|
140
|
+
def set_avro_file(self, avro_file: dict[str, str]):
|
134
141
|
"""
|
135
142
|
Set the avro file of the agent.
|
136
143
|
"""
|
@@ -164,6 +171,15 @@ class Agent(ABC):
|
|
164
171
|
)
|
165
172
|
return self._economy_client
|
166
173
|
|
174
|
+
@property
|
175
|
+
def mlflow_client(self):
|
176
|
+
"""The Agent's MlflowClient"""
|
177
|
+
if self._mlflow_client is None:
|
178
|
+
raise RuntimeError(
|
179
|
+
f"MlflowClient access before assignment, please `set_mlflow_client` first!"
|
180
|
+
)
|
181
|
+
return self._mlflow_client
|
182
|
+
|
167
183
|
@property
|
168
184
|
def memory(self):
|
169
185
|
"""The Agent's Memory"""
|
@@ -218,19 +234,21 @@ class Agent(ABC):
|
|
218
234
|
response = await self._llm_client.atext_request(dialog) # type:ignore
|
219
235
|
|
220
236
|
return response # type:ignore
|
221
|
-
|
237
|
+
|
222
238
|
async def _process_survey(self, survey: dict):
|
223
239
|
survey_response = await self.generate_user_survey_response(survey)
|
224
240
|
if self._avro_file is None:
|
225
241
|
return
|
226
|
-
response_to_avro = [
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
242
|
+
response_to_avro = [
|
243
|
+
{
|
244
|
+
"id": self._uuid,
|
245
|
+
"day": await self.simulator.get_simulator_day(),
|
246
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
247
|
+
"survey_id": survey["id"],
|
248
|
+
"result": survey_response,
|
249
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
250
|
+
}
|
251
|
+
]
|
234
252
|
with open(self._avro_file["survey"], "a+b") as f:
|
235
253
|
fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
|
236
254
|
|
@@ -270,28 +288,32 @@ class Agent(ABC):
|
|
270
288
|
response = await self._llm_client.atext_request(dialog) # type:ignore
|
271
289
|
|
272
290
|
return response # type:ignore
|
273
|
-
|
291
|
+
|
274
292
|
async def _process_interview(self, payload: dict):
|
275
|
-
auros = [
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
293
|
+
auros = [
|
294
|
+
{
|
295
|
+
"id": self._uuid,
|
296
|
+
"day": await self.simulator.get_simulator_day(),
|
297
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
298
|
+
"type": 2,
|
299
|
+
"speaker": "user",
|
300
|
+
"content": payload["content"],
|
301
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
302
|
+
}
|
303
|
+
]
|
284
304
|
question = payload["content"]
|
285
305
|
response = await self.generate_user_chat_response(question)
|
286
|
-
auros.append(
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
306
|
+
auros.append(
|
307
|
+
{
|
308
|
+
"id": self._uuid,
|
309
|
+
"day": await self.simulator.get_simulator_day(),
|
310
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
311
|
+
"type": 2,
|
312
|
+
"speaker": "",
|
313
|
+
"content": response,
|
314
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
315
|
+
}
|
316
|
+
)
|
295
317
|
if self._avro_file is None:
|
296
318
|
return
|
297
319
|
with open(self._avro_file["dialog"], "a+b") as f:
|
@@ -303,15 +325,17 @@ class Agent(ABC):
|
|
303
325
|
return resp
|
304
326
|
|
305
327
|
async def _process_agent_chat(self, payload: dict):
|
306
|
-
auros = [
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
328
|
+
auros = [
|
329
|
+
{
|
330
|
+
"id": self._uuid,
|
331
|
+
"day": payload["day"],
|
332
|
+
"t": payload["t"],
|
333
|
+
"type": 1,
|
334
|
+
"speaker": payload["from"],
|
335
|
+
"content": payload["content"],
|
336
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
337
|
+
}
|
338
|
+
]
|
315
339
|
asyncio.create_task(self.process_agent_chat_response(payload))
|
316
340
|
if self._avro_file is None:
|
317
341
|
return
|
@@ -341,18 +365,14 @@ class Agent(ABC):
|
|
341
365
|
raise NotImplementedError
|
342
366
|
|
343
367
|
# MQTT send message
|
344
|
-
async def _send_message(
|
345
|
-
self, to_agent_uuid: str, payload: dict, sub_topic: str
|
346
|
-
):
|
368
|
+
async def _send_message(self, to_agent_uuid: str, payload: dict, sub_topic: str):
|
347
369
|
"""通过 Messager 发送消息"""
|
348
370
|
if self._messager is None:
|
349
371
|
raise RuntimeError("Messager is not set")
|
350
372
|
topic = f"exps/{self._exp_id}/agents/{to_agent_uuid}/{sub_topic}"
|
351
373
|
await self._messager.send_message(topic, payload)
|
352
374
|
|
353
|
-
async def send_message_to_agent(
|
354
|
-
self, to_agent_uuid: str, content: str
|
355
|
-
):
|
375
|
+
async def send_message_to_agent(self, to_agent_uuid: str, content: str):
|
356
376
|
"""通过 Messager 发送消息"""
|
357
377
|
if self._messager is None:
|
358
378
|
raise RuntimeError("Messager is not set")
|
@@ -364,15 +384,17 @@ class Agent(ABC):
|
|
364
384
|
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
365
385
|
}
|
366
386
|
await self._send_message(to_agent_uuid, payload, "agent-chat")
|
367
|
-
auros = [
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
387
|
+
auros = [
|
388
|
+
{
|
389
|
+
"id": self._uuid,
|
390
|
+
"day": await self.simulator.get_simulator_day(),
|
391
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
392
|
+
"type": 1,
|
393
|
+
"speaker": self._uuid,
|
394
|
+
"content": content,
|
395
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
396
|
+
}
|
397
|
+
]
|
376
398
|
if self._avro_file is None:
|
377
399
|
return
|
378
400
|
with open(self._avro_file["dialog"], "a+b") as f:
|
@@ -403,20 +425,22 @@ class CitizenAgent(Agent):
|
|
403
425
|
name: str,
|
404
426
|
llm_client: Optional[LLM] = None,
|
405
427
|
simulator: Optional[Simulator] = None,
|
428
|
+
mlflow_client: Optional[MlflowClient] = None,
|
406
429
|
memory: Optional[Memory] = None,
|
407
430
|
economy_client: Optional[EconomyClient] = None,
|
408
431
|
messager: Optional[Messager] = None,
|
409
432
|
avro_file: Optional[dict] = None,
|
410
433
|
) -> None:
|
411
434
|
super().__init__(
|
412
|
-
name,
|
413
|
-
AgentType.Citizen,
|
414
|
-
llm_client,
|
415
|
-
economy_client,
|
416
|
-
messager,
|
417
|
-
simulator,
|
418
|
-
|
419
|
-
|
435
|
+
name=name,
|
436
|
+
type=AgentType.Citizen,
|
437
|
+
llm_client=llm_client,
|
438
|
+
economy_client=economy_client,
|
439
|
+
messager=messager,
|
440
|
+
simulator=simulator,
|
441
|
+
mlflow_client=mlflow_client,
|
442
|
+
memory=memory,
|
443
|
+
avro_file=avro_file,
|
420
444
|
)
|
421
445
|
|
422
446
|
async def bind_to_simulator(self):
|
@@ -464,9 +488,7 @@ class CitizenAgent(Agent):
|
|
464
488
|
)
|
465
489
|
person_id = resp["person_id"]
|
466
490
|
await memory.update("id", person_id, protect_llm_read_only_fields=False)
|
467
|
-
logger.debug(
|
468
|
-
f"Binding to Person `{person_id}` just added to Simulator"
|
469
|
-
)
|
491
|
+
logger.debug(f"Binding to Person `{person_id}` just added to Simulator")
|
470
492
|
# 防止模拟器还没有到prepare阶段导致get_person出错
|
471
493
|
self._has_bound_to_simulator = True
|
472
494
|
self._agent_id = person_id
|
@@ -517,24 +539,26 @@ class InstitutionAgent(Agent):
|
|
517
539
|
name: str,
|
518
540
|
llm_client: Optional[LLM] = None,
|
519
541
|
simulator: Optional[Simulator] = None,
|
542
|
+
mlflow_client: Optional[MlflowClient] = None,
|
520
543
|
memory: Optional[Memory] = None,
|
521
544
|
economy_client: Optional[EconomyClient] = None,
|
522
545
|
messager: Optional[Messager] = None,
|
523
546
|
avro_file: Optional[dict] = None,
|
524
547
|
) -> None:
|
525
548
|
super().__init__(
|
526
|
-
name,
|
527
|
-
AgentType.Institution,
|
528
|
-
llm_client,
|
529
|
-
economy_client,
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
549
|
+
name=name,
|
550
|
+
type=AgentType.Institution,
|
551
|
+
llm_client=llm_client,
|
552
|
+
economy_client=economy_client,
|
553
|
+
mlflow_client=mlflow_client,
|
554
|
+
messager=messager,
|
555
|
+
simulator=simulator,
|
556
|
+
memory=memory,
|
557
|
+
avro_file=avro_file,
|
534
558
|
)
|
535
559
|
# 添加响应收集器
|
536
|
-
self._gather_responses:
|
537
|
-
|
560
|
+
self._gather_responses: dict[str, asyncio.Future] = {}
|
561
|
+
|
538
562
|
async def bind_to_simulator(self):
|
539
563
|
await self._bind_to_economy()
|
540
564
|
|
@@ -624,31 +648,33 @@ class InstitutionAgent(Agent):
|
|
624
648
|
"""处理收到的消息,识别发送者"""
|
625
649
|
content = payload["content"]
|
626
650
|
sender_id = payload["from"]
|
627
|
-
|
651
|
+
|
628
652
|
# 将响应存储到对应的Future中
|
629
653
|
response_key = str(sender_id)
|
630
654
|
if response_key in self._gather_responses:
|
631
|
-
self._gather_responses[response_key].set_result(
|
632
|
-
|
633
|
-
|
634
|
-
|
655
|
+
self._gather_responses[response_key].set_result(
|
656
|
+
{
|
657
|
+
"from": sender_id,
|
658
|
+
"content": content,
|
659
|
+
}
|
660
|
+
)
|
635
661
|
|
636
|
-
async def gather_messages(self, agent_uuids: list[str], target: str) ->
|
662
|
+
async def gather_messages(self, agent_uuids: list[str], target: str) -> list[dict]:
|
637
663
|
"""从多个智能体收集消息
|
638
|
-
|
664
|
+
|
639
665
|
Args:
|
640
666
|
agent_uuids: 目标智能体UUID列表
|
641
667
|
target: 要收集的信息类型
|
642
|
-
|
668
|
+
|
643
669
|
Returns:
|
644
|
-
|
670
|
+
list[dict]: 收集到的所有响应
|
645
671
|
"""
|
646
672
|
# 为每个agent创建Future
|
647
673
|
futures = {}
|
648
674
|
for agent_uuid in agent_uuids:
|
649
675
|
futures[agent_uuid] = asyncio.Future()
|
650
676
|
self._gather_responses[agent_uuid] = futures[agent_uuid]
|
651
|
-
|
677
|
+
|
652
678
|
# 发送gather请求
|
653
679
|
payload = {
|
654
680
|
"from": self._uuid,
|
@@ -656,7 +682,7 @@ class InstitutionAgent(Agent):
|
|
656
682
|
}
|
657
683
|
for agent_uuid in agent_uuids:
|
658
684
|
await self._send_message(agent_uuid, payload, "gather")
|
659
|
-
|
685
|
+
|
660
686
|
try:
|
661
687
|
# 等待所有响应
|
662
688
|
responses = await asyncio.gather(*futures.values())
|
@@ -308,11 +308,48 @@ class EconomyClient:
|
|
308
308
|
# current agent ids and org ids
|
309
309
|
return (list(response.agent_ids), list(response.org_ids))
|
310
310
|
|
311
|
-
async def get_org_entity_ids(self, org_type: economyv2.OrgType)->list[int]:
|
311
|
+
async def get_org_entity_ids(self, org_type: economyv2.OrgType) -> list[int]:
|
312
312
|
request = org_service.GetOrgEntityIdsRequest(
|
313
313
|
type=org_type,
|
314
314
|
)
|
315
315
|
response: org_service.GetOrgEntityIdsResponse = (
|
316
316
|
await self._aio_stub.GetOrgEntityIds(request)
|
317
317
|
)
|
318
|
-
return list(response.org_ids)
|
318
|
+
return list(response.org_ids)
|
319
|
+
|
320
|
+
async def add_delta_value(
|
321
|
+
self,
|
322
|
+
id: int,
|
323
|
+
key: str,
|
324
|
+
value: Any,
|
325
|
+
) -> Any:
|
326
|
+
"""
|
327
|
+
Add key-value pair
|
328
|
+
|
329
|
+
Args:
|
330
|
+
- id (int): the id of `Org` or `Agent`.
|
331
|
+
- key (str): the attribute to update. Can only be `inventory`, `price`, `interest_rate` and `currency`
|
332
|
+
|
333
|
+
|
334
|
+
Returns:
|
335
|
+
- Any
|
336
|
+
"""
|
337
|
+
pascal_key = _snake_to_pascal(key)
|
338
|
+
_request_type = getattr(org_service, f"Add{pascal_key}Request")
|
339
|
+
_request_func = getattr(self._aio_stub, f"Add{pascal_key}")
|
340
|
+
_available_keys = {
|
341
|
+
"inventory",
|
342
|
+
"price",
|
343
|
+
"interest_rate",
|
344
|
+
"currency",
|
345
|
+
}
|
346
|
+
if key not in _available_keys:
|
347
|
+
raise ValueError(f"Invalid key `{key}`, can only be {_available_keys}!")
|
348
|
+
return await _request_func(
|
349
|
+
_request_type(
|
350
|
+
**{
|
351
|
+
"org_id": id,
|
352
|
+
f"delta_{key}": value,
|
353
|
+
}
|
354
|
+
)
|
355
|
+
)
|
@@ -5,7 +5,7 @@ import logging
|
|
5
5
|
import os
|
6
6
|
from collections.abc import Sequence
|
7
7
|
from datetime import datetime, timedelta
|
8
|
-
from typing import Any, Optional,
|
8
|
+
from typing import Any, Optional, Union, cast
|
9
9
|
|
10
10
|
from mosstool.type import TripMode
|
11
11
|
from mosstool.util.format_converter import coll2pb
|
@@ -28,7 +28,7 @@ class Simulator:
|
|
28
28
|
- Simulator Class
|
29
29
|
"""
|
30
30
|
|
31
|
-
def __init__(self, config, secure: bool = False) -> None:
|
31
|
+
def __init__(self, config:dict, secure: bool = False) -> None:
|
32
32
|
self.config = config
|
33
33
|
"""
|
34
34
|
- 模拟器配置
|
@@ -1,9 +1,7 @@
|
|
1
|
-
from typing import List
|
2
|
-
|
3
1
|
__all__ = ["wrap_feature_collection"]
|
4
2
|
|
5
3
|
|
6
|
-
def wrap_feature_collection(features:
|
4
|
+
def wrap_feature_collection(features: list[dict], name: str):
|
7
5
|
"""
|
8
6
|
将 GeoJSON Feature 集合包装为 FeatureCollection
|
9
7
|
Wrap GeoJSON Feature collection as FeatureCollection
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import math
|
2
|
-
from typing import
|
2
|
+
from typing import Literal, Union
|
3
3
|
|
4
4
|
import numpy as np
|
5
5
|
|
@@ -14,12 +14,12 @@ def point_on_line_given_distance(start_node, end_node, distance):
|
|
14
14
|
return the coordinates of the point reached after traveling s units along the line, starting from start_point.
|
15
15
|
|
16
16
|
Args:
|
17
|
-
start_point (tuple):
|
18
|
-
end_point (tuple):
|
17
|
+
start_point (tuple): tuple of (x, y) representing the starting point on the line.
|
18
|
+
end_point (tuple): tuple of (x, y) representing the ending point on the line.
|
19
19
|
distance (float): Distance to travel along the line, starting from start_point.
|
20
20
|
|
21
21
|
Returns:
|
22
|
-
tuple:
|
22
|
+
tuple: tuple of (x, y) representing the new point reached after traveling s units along the line.
|
23
23
|
"""
|
24
24
|
|
25
25
|
x1, y1 = start_node["x"], start_node["y"]
|
@@ -49,7 +49,7 @@ def point_on_line_given_distance(start_node, end_node, distance):
|
|
49
49
|
|
50
50
|
|
51
51
|
def get_key_index_in_lane(
|
52
|
-
nodes:
|
52
|
+
nodes: list[dict[str, float]],
|
53
53
|
distance: float,
|
54
54
|
direction: Union[Literal["front"], Literal["back"]],
|
55
55
|
) -> int:
|
@@ -61,10 +61,10 @@ def get_key_index_in_lane(
|
|
61
61
|
_index_offset, _index_factor = len(_nodes) - 1, -1
|
62
62
|
else:
|
63
63
|
raise ValueError(f"Invalid direction type {direction}!")
|
64
|
-
_lane_points:
|
64
|
+
_lane_points: list[tuple[float, float, float]] = [
|
65
65
|
(n["x"], n["y"], n.get("z", 0)) for n in _nodes
|
66
66
|
]
|
67
|
-
_line_lengths:
|
67
|
+
_line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
|
68
68
|
_s = 0.0
|
69
69
|
for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
|
70
70
|
_s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
|
@@ -82,20 +82,20 @@ def get_key_index_in_lane(
|
|
82
82
|
|
83
83
|
|
84
84
|
def get_xy_in_lane(
|
85
|
-
nodes:
|
85
|
+
nodes: list[dict[str, float]],
|
86
86
|
distance: float,
|
87
87
|
direction: Union[Literal["front"], Literal["back"]],
|
88
|
-
) ->
|
88
|
+
) -> tuple[float, float]:
|
89
89
|
if direction == "front":
|
90
90
|
_nodes = [n for n in nodes]
|
91
91
|
elif direction == "back":
|
92
92
|
_nodes = [n for n in nodes[::-1]]
|
93
93
|
else:
|
94
94
|
raise ValueError(f"Invalid direction type {direction}!")
|
95
|
-
_lane_points:
|
95
|
+
_lane_points: list[tuple[float, float, float]] = [
|
96
96
|
(n["x"], n["y"], n.get("z", 0)) for n in _nodes
|
97
97
|
]
|
98
|
-
_line_lengths:
|
98
|
+
_line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
|
99
99
|
_s = 0.0
|
100
100
|
for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
|
101
101
|
_s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
|
@@ -122,7 +122,7 @@ def get_xy_in_lane(
|
|
122
122
|
|
123
123
|
|
124
124
|
def get_direction_by_s(
|
125
|
-
nodes:
|
125
|
+
nodes: list[dict[str, float]],
|
126
126
|
distance: float,
|
127
127
|
direction: Union[Literal["front"], Literal["back"]],
|
128
128
|
) -> float:
|
@@ -132,11 +132,11 @@ def get_direction_by_s(
|
|
132
132
|
_nodes = [n for n in nodes[::-1]]
|
133
133
|
else:
|
134
134
|
raise ValueError(f"Invalid direction type {direction}!")
|
135
|
-
_lane_points:
|
135
|
+
_lane_points: list[tuple[float, float, float]] = [
|
136
136
|
(n["x"], n["y"], n.get("z", 0)) for n in _nodes
|
137
137
|
]
|
138
|
-
_line_lengths:
|
139
|
-
_line_directions:
|
138
|
+
_line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
|
139
|
+
_line_directions: list[tuple[float, float]] = []
|
140
140
|
_s = 0.0
|
141
141
|
for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
|
142
142
|
_s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
|
pycityagent/llm/embedding.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
"""简单的基于内存的embedding实现"""
|
2
2
|
|
3
3
|
import numpy as np
|
4
|
-
from typing import List, Dict, Optional
|
5
4
|
import hashlib
|
6
5
|
import json
|
7
6
|
|
@@ -22,34 +21,34 @@ class SimpleEmbedding:
|
|
22
21
|
"""
|
23
22
|
self.vector_dim = vector_dim
|
24
23
|
self.cache_size = cache_size
|
25
|
-
self._cache:
|
26
|
-
self._vocab:
|
27
|
-
self._idf:
|
24
|
+
self._cache: dict[str, np.ndarray] = {}
|
25
|
+
self._vocab: dict[str, int] = {} # 词汇表
|
26
|
+
self._idf: dict[str, float] = {} # 逆文档频率
|
28
27
|
self._doc_count = 0 # 文档总数
|
29
28
|
|
30
29
|
def _text_to_hash(self, text: str) -> str:
|
31
30
|
"""将文本转换为hash值"""
|
32
31
|
return hashlib.md5(text.encode()).hexdigest()
|
33
32
|
|
34
|
-
def _tokenize(self, text: str) ->
|
33
|
+
def _tokenize(self, text: str) -> list[str]:
|
35
34
|
"""简单的分词"""
|
36
35
|
# 这里使用简单的空格分词,实际应用中可以使用更复杂的分词方法
|
37
36
|
return text.lower().split()
|
38
37
|
|
39
|
-
def _update_vocab(self, tokens:
|
38
|
+
def _update_vocab(self, tokens: list[str]):
|
40
39
|
"""更新词汇表"""
|
41
40
|
for token in set(tokens): # 使用set去重
|
42
41
|
if token not in self._vocab:
|
43
42
|
self._vocab[token] = len(self._vocab)
|
44
43
|
|
45
|
-
def _update_idf(self, tokens:
|
44
|
+
def _update_idf(self, tokens: list[str]):
|
46
45
|
"""更新IDF值"""
|
47
46
|
self._doc_count += 1
|
48
47
|
unique_tokens = set(tokens)
|
49
48
|
for token in unique_tokens:
|
50
49
|
self._idf[token] = self._idf.get(token, 0) + 1
|
51
50
|
|
52
|
-
def _calculate_tf(self, tokens:
|
51
|
+
def _calculate_tf(self, tokens: list[str]) -> dict[str, float]:
|
53
52
|
"""计算词频(TF)"""
|
54
53
|
tf = {}
|
55
54
|
total_tokens = len(tokens)
|
@@ -60,7 +59,7 @@ class SimpleEmbedding:
|
|
60
59
|
tf[token] /= total_tokens
|
61
60
|
return tf
|
62
61
|
|
63
|
-
def _calculate_tfidf(self, tokens:
|
62
|
+
def _calculate_tfidf(self, tokens: list[str]) -> np.ndarray:
|
64
63
|
"""计算TF-IDF向量"""
|
65
64
|
vector = np.zeros(self.vector_dim)
|
66
65
|
tf = self._calculate_tf(tokens)
|
pycityagent/llm/llm.py
CHANGED
@@ -14,7 +14,7 @@ import requests
|
|
14
14
|
from dashscope import ImageSynthesis
|
15
15
|
from PIL import Image
|
16
16
|
from io import BytesIO
|
17
|
-
from typing import Any, Optional, Union
|
17
|
+
from typing import Any, Optional, Union
|
18
18
|
from .llmconfig import *
|
19
19
|
from .utils import *
|
20
20
|
|
@@ -117,8 +117,8 @@ Token Usage:
|
|
117
117
|
presence_penalty: Optional[float] = None,
|
118
118
|
timeout: int = 300,
|
119
119
|
retries=3,
|
120
|
-
tools: Optional[
|
121
|
-
tool_choice: Optional[
|
120
|
+
tools: Optional[list[dict[str, Any]]] = None,
|
121
|
+
tool_choice: Optional[dict[str, Any]] = None,
|
122
122
|
):
|
123
123
|
"""
|
124
124
|
异步版文本请求
|
@@ -227,9 +227,9 @@ Token Usage:
|
|
227
227
|
self.prompt_tokens_used += result_response.usage.prompt_tokens # type: ignore
|
228
228
|
self.completion_tokens_used += result_response.usage.completion_tokens # type: ignore
|
229
229
|
self.request_number += 1
|
230
|
-
if tools and result_response.choices[0].message.tool_calls:
|
230
|
+
if tools and result_response.choices[0].message.tool_calls: # type: ignore
|
231
231
|
return json.loads(
|
232
|
-
result_response.choices[0]
|
232
|
+
result_response.choices[0] # type: ignore
|
233
233
|
.message.tool_calls[0]
|
234
234
|
.function.arguments
|
235
235
|
)
|