pycityagent 2.0.0a43__cp39-cp39-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +23 -0
- pycityagent/agent.py +833 -0
- pycityagent/cli/wrapper.py +44 -0
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +355 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +198 -0
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +416 -0
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +29 -0
- pycityagent/environment/sim/aoi_service.py +39 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +44 -0
- pycityagent/environment/sim/economy_services.py +192 -0
- pycityagent/environment/sim/lane_service.py +111 -0
- pycityagent/environment/sim/light_service.py +122 -0
- pycityagent/environment/sim/person_service.py +295 -0
- pycityagent/environment/sim/road_service.py +39 -0
- pycityagent/environment/sim/sim_env.py +145 -0
- pycityagent/environment/sim/social_service.py +59 -0
- pycityagent/environment/simulator.py +331 -0
- pycityagent/environment/utils/__init__.py +14 -0
- pycityagent/environment/utils/base64.py +16 -0
- pycityagent/environment/utils/const.py +244 -0
- pycityagent/environment/utils/geojson.py +24 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/port.py +11 -0
- pycityagent/environment/utils/protobuf.py +41 -0
- pycityagent/llm/__init__.py +11 -0
- pycityagent/llm/embeddings.py +231 -0
- pycityagent/llm/llm.py +377 -0
- pycityagent/llm/llmconfig.py +13 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +13 -0
- pycityagent/memory/const.py +43 -0
- pycityagent/memory/faiss_query.py +302 -0
- pycityagent/memory/memory.py +448 -0
- pycityagent/memory/memory_base.py +170 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +28 -0
- pycityagent/message/__init__.py +3 -0
- pycityagent/message/messager.py +88 -0
- pycityagent/metrics/__init__.py +6 -0
- pycityagent/metrics/mlflow_client.py +147 -0
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +8 -0
- pycityagent/simulation/agentgroup.py +580 -0
- pycityagent/simulation/simulation.py +634 -0
- pycityagent/simulation/storage/pg.py +184 -0
- pycityagent/survey/__init__.py +4 -0
- pycityagent/survey/manager.py +54 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +11 -0
- pycityagent/utils/avro_schema.py +109 -0
- pycityagent/utils/decorators.py +99 -0
- pycityagent/utils/parsers/__init__.py +13 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/utils/pg_query.py +92 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/__init__.py +26 -0
- pycityagent/workflow/block.py +211 -0
- pycityagent/workflow/prompt.py +79 -0
- pycityagent/workflow/tool.py +240 -0
- pycityagent/workflow/trigger.py +163 -0
- pycityagent-2.0.0a43.dist-info/LICENSE +21 -0
- pycityagent-2.0.0a43.dist-info/METADATA +235 -0
- pycityagent-2.0.0a43.dist-info/RECORD +81 -0
- pycityagent-2.0.0a43.dist-info/WHEEL +5 -0
- pycityagent-2.0.0a43.dist-info/entry_points.txt +3 -0
- pycityagent-2.0.0a43.dist-info/top_level.txt +3 -0
pycityagent/environment/utils/map_utils.py
@@ -0,0 +1,157 @@
import math
from typing import Literal, Union

import numpy as np


def get_angle(x, y):
    """Angle of the vector (x, y) in degrees."""
    return math.atan2(y, x) * 180 / math.pi


def point_on_line_given_distance(start_node, end_node, distance):
    """
    Given two nodes (start_node and end_node) defining a line, and a distance to
    travel along that line starting from start_node, return the coordinates of
    the point reached.

    Args:
        start_node (dict): dict with "x" and "y" keys for the starting point.
        end_node (dict): dict with "x" and "y" keys for the ending point.
        distance (float): Signed distance to travel along the line from start_node.

    Returns:
        tuple: (x, y) of the point reached after traveling the given distance.
    """
    x1, y1 = start_node["x"], start_node["y"]
    x2, y2 = end_node["x"], end_node["y"]

    if x1 == x2:
        # Vertical line: the signed distance moves along the y-axis only.
        return (x1, y1 + distance)
    else:
        # Unit direction vector (dx, dy) along the line
        length = math.hypot(x2 - x1, y2 - y1)
        dx = (x2 - x1) / length
        dy = (y2 - y1) / length

        # Scale the direction vector by the given distance
        x = x1 + dx * distance
        y = y1 + dy * distance

        return (x, y)


def get_key_index_in_lane(
    nodes: list[dict[str, float]],
    distance: float,
    direction: Union[Literal["front"], Literal["back"]],
) -> int:
    """Return the index of the lane node just before the given arc-length distance."""
    if direction == "front":
        _nodes = list(nodes)
        _index_offset, _index_factor = 0, 1
    elif direction == "back":
        _nodes = nodes[::-1]
        _index_offset, _index_factor = len(_nodes) - 1, -1
    else:
        raise ValueError(f"Invalid direction type {direction}!")
    _lane_points: list[tuple[float, float, float]] = [
        (n["x"], n["y"], n.get("z", 0)) for n in _nodes
    ]
    # Cumulative arc length of the polyline at each node
    _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
    _s = 0.0
    for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
        _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
        _line_lengths[i + 1] = _s
    s = np.clip(distance, _line_lengths[0], _line_lengths[-1])
    _key_index = 0
    for prev_s, cur_s in zip(_line_lengths[:-1], _line_lengths[1:]):
        if prev_s <= s < cur_s:
            break
        _key_index += 1
    return _index_offset + _index_factor * _key_index


def get_xy_in_lane(
    nodes: list[dict[str, float]],
    distance: float,
    direction: Union[Literal["front"], Literal["back"]],
) -> tuple[float, float]:
    """Return the (x, y) position at the given arc-length distance along the lane."""
    if direction == "front":
        _nodes = list(nodes)
    elif direction == "back":
        _nodes = nodes[::-1]
    else:
        raise ValueError(f"Invalid direction type {direction}!")
    _lane_points: list[tuple[float, float, float]] = [
        (n["x"], n["y"], n.get("z", 0)) for n in _nodes
    ]
    _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
    _s = 0.0
    for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
        _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
        _line_lengths[i + 1] = _s
    s = np.clip(distance, _line_lengths[0], _line_lengths[-1])
    for prev_s, prev_p, cur_s, cur_p in zip(
        _line_lengths[:-1],
        _lane_points[:-1],
        _line_lengths[1:],
        _lane_points[1:],
    ):
        if prev_s <= s < cur_s:
            # Linearly interpolate between the two nodes bounding s
            _ratio = (s - prev_s) / (cur_s - prev_s)
            return (
                prev_p[0] + _ratio * (cur_p[0] - prev_p[0]),
                prev_p[1] + _ratio * (cur_p[1] - prev_p[1]),
            )
    return _lane_points[-1][:2]


def get_direction_by_s(
    nodes: list[dict[str, float]],
    distance: float,
    direction: Union[Literal["front"], Literal["back"]],
) -> float:
    """Return the heading (yaw, in degrees) of the lane at the given arc-length distance."""
    if direction == "front":
        _nodes = list(nodes)
    elif direction == "back":
        _nodes = nodes[::-1]
    else:
        raise ValueError(f"Invalid direction type {direction}!")
    _lane_points: list[tuple[float, float, float]] = [
        (n["x"], n["y"], n.get("z", 0)) for n in _nodes
    ]
    _line_lengths: list[float] = [0.0 for _ in range(len(_nodes))]
    _line_directions: list[tuple[float, float]] = []
    _s = 0.0
    for i, (cur_p, next_p) in enumerate(zip(_lane_points[:-1], _lane_points[1:])):
        _s += math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1])
        _line_lengths[i + 1] = _s
    for cur_p, next_p in zip(_lane_points[:-1], _lane_points[1:]):
        # Yaw in the xy-plane and pitch against the z-axis, both in degrees
        _direction = math.atan2(next_p[1] - cur_p[1], next_p[0] - cur_p[0])
        _pitch = math.atan2(
            next_p[2] - cur_p[2],
            math.hypot(next_p[0] - cur_p[0], next_p[1] - cur_p[1]),
        )
        _line_directions.append((_direction / math.pi * 180, _pitch / math.pi * 180))
    s = np.clip(distance, _line_lengths[0], _line_lengths[-1])
    for prev_s, cur_s, direcs in zip(
        _line_lengths[:-1], _line_lengths[1:], _line_directions
    ):
        if prev_s <= s < cur_s:
            return direcs[0]
    return _line_directions[-1][0]
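A quick sanity check of the lane helpers above (illustrative only; the three-node polyline and the expected values are assumptions for demonstration, not taken from the package's tests):

```python
from pycityagent.environment.utils.map_utils import (
    get_direction_by_s,
    get_key_index_in_lane,
    get_xy_in_lane,
)

# A straight two-segment polyline: 10 m along +x, then 10 m along +y.
nodes = [
    {"x": 0.0, "y": 0.0},
    {"x": 10.0, "y": 0.0},
    {"x": 10.0, "y": 10.0},
]

# Index of the node starting the segment that contains s = 12 m from the front.
print(get_key_index_in_lane(nodes, 12.0, "front"))  # 1

# Interpolated position 12 m along the lane: 2 m up the second segment.
print(get_xy_in_lane(nodes, 12.0, "front"))         # (10.0, 2.0)

# Heading there, in degrees; the second segment points along +y.
print(get_direction_by_s(nodes, 12.0, "front"))     # 90.0
```

Note that distances are clipped to the polyline's total length, so out-of-range values snap to the endpoints rather than raising.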
pycityagent/environment/utils/port.py
@@ -0,0 +1,11 @@
import socket
from contextlib import closing

__all__ = ["find_free_port"]


def find_free_port():
    """Bind to port 0 so the OS assigns an unused port, and return that port."""
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # Set SO_REUSEADDR before binding so the probed port can be reused
        # immediately after this socket closes.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]
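Usage sketch (illustrative, not from the package):

```python
from pycityagent.environment.utils.port import find_free_port

# Ask the OS for an unused TCP port, e.g. to launch a local simulator process.
port = find_free_port()
print(f"starting local service on 127.0.0.1:{port}")
# Caveat: the port is only known to be free at probe time; another process can
# claim it before the caller binds it, so callers should be prepared to retry.
```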
pycityagent/environment/utils/protobuf.py
@@ -0,0 +1,41 @@
from collections.abc import Awaitable
from typing import Any, TypeVar, Union

from google.protobuf.json_format import MessageToDict
from google.protobuf.message import Message

__all__ = ["parse", "async_parse"]

T = TypeVar("T", bound=Message)


def parse(res: T, dict_return: bool) -> Union[dict[str, Any], T]:
    """
    Convert a Protobuf return value to a dict, or pass it through unchanged.
    """
    if dict_return:
        return MessageToDict(
            res,
            including_default_value_fields=True,
            preserving_proto_field_name=True,
            use_integers_for_enums=True,
        )
    else:
        return res


async def async_parse(res: Awaitable[T], dict_return: bool) -> Union[dict[str, Any], T]:
    """
    Await a Protobuf return value, then convert it to a dict or pass it through unchanged.
    """
    if dict_return:
        return MessageToDict(
            await res,
            including_default_value_fields=True,
            preserving_proto_field_name=True,
            use_integers_for_enums=True,
        )
    else:
        return await res
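A sketch of how the two helpers behave; `Int32Value` is just a convenient stand-in message, while real call sites pass gRPC responses from the sim services. (As an aside, `including_default_value_fields` is deprecated in newer protobuf releases, so this module presumably targets whatever protobuf version the wheel pins.)

```python
import asyncio

from google.protobuf.wrappers_pb2 import Int32Value

from pycityagent.environment.utils.protobuf import async_parse, parse

msg = Int32Value(value=5)
print(parse(msg, dict_return=True))   # {'value': 5}
print(parse(msg, dict_return=False))  # returns the message unchanged


async def fake_rpc() -> Int32Value:
    # Stands in for an awaitable gRPC stub call.
    return Int32Value(value=7)


print(asyncio.run(async_parse(fake_rpc(), dict_return=True)))  # {'value': 7}
```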
pycityagent/llm/embeddings.py
@@ -0,0 +1,231 @@
import hashlib
import json
import os
from typing import Optional, Union

import numpy as np
import torch
from langchain_core.embeddings import Embeddings
from transformers import AutoModel, AutoTokenizer

__all__ = [
    "SentenceEmbedding",
    "SimpleEmbedding",
]


class SentenceEmbedding(Embeddings):
    def __init__(
        self,
        pretrained_model_name_or_path: Union[str, os.PathLike] = "BAAI/bge-m3",
        max_seq_len: int = 8192,
        auto_cuda: bool = False,
        local_files_only: bool = False,
        cache_dir: str = "./cache",
        proxies: Optional[dict] = None,
    ):
        os.makedirs(cache_dir, exist_ok=True)
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            proxies=proxies,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.model = AutoModel.from_pretrained(
            pretrained_model_name_or_path,
            proxies=proxies,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self._cuda = auto_cuda and torch.cuda.is_available()

        if self._cuda:
            self.model = self.model.cuda()

        self.model.eval()
        self.max_seq_len = max_seq_len

    def _embed(self, texts: list[str]) -> list[list[float]]:
        # Tokenize sentences
        encoded_input = self.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        )
        # For s2p (short query to long passage) retrieval tasks, add an instruction
        # to the query (do not add the instruction to passages):
        # encoded_input = tokenizer([instruction + q for q in queries], padding=True, truncation=True, return_tensors='pt')

        # Check the input length against the model's limit
        assert encoded_input["input_ids"].shape[1] <= self.max_seq_len  # type: ignore

        if self._cuda:
            encoded_input = {k: v.cuda() for k, v in encoded_input.items()}
        # Compute token embeddings
        with torch.no_grad():
            model_output = self.model(**encoded_input)
            # Perform pooling. In this case, CLS pooling.
            sentence_embeddings = model_output[0][:, 0]
        # Normalize embeddings
        sentence_embeddings = torch.nn.functional.normalize(
            sentence_embeddings, p=2, dim=1
        )
        if self._cuda:
            sentence_embeddings = sentence_embeddings.cpu()
        return sentence_embeddings.tolist()

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed documents."""
        return self._embed(texts)

    def embed_query(self, text: str) -> list[float]:
        """Embed query text."""
        return self._embed([text])[0]


class SimpleEmbedding(Embeddings):
    """A simple in-memory embedding implementation.

    Uses a bag-of-words model with TF-IDF weighting to generate vector
    representations of text. All vectors are kept in memory, which makes this
    suitable for small-scale applications.
    """

    def __init__(self, vector_dim: int = 128, cache_size: int = 1000):
        """Initialize the embedding.

        Args:
            vector_dim: Dimension of the output vectors.
            cache_size: Cache capacity; once exceeded, the oldest entry is evicted.
        """
        self.vector_dim = vector_dim
        self.cache_size = cache_size
        self._cache: dict[str, list[float]] = {}
        self._vocab: dict[str, int] = {}  # vocabulary
        self._idf: dict[str, float] = {}  # inverse document frequency
        self._doc_count = 0  # total number of documents

    def _text_to_hash(self, text: str) -> str:
        """Hash the text to use as a cache key."""
        return hashlib.md5(text.encode()).hexdigest()

    def _tokenize(self, text: str) -> list[str]:
        """Simple tokenization."""
        # Plain whitespace tokenization; a real application could use a more
        # sophisticated tokenizer.
        return text.lower().split()

    def _update_vocab(self, tokens: list[str]):
        """Update the vocabulary."""
        for token in set(tokens):  # deduplicate with a set
            if token not in self._vocab:
                self._vocab[token] = len(self._vocab)

    def _update_idf(self, tokens: list[str]):
        """Update the document-frequency counts used for IDF."""
        self._doc_count += 1
        unique_tokens = set(tokens)
        for token in unique_tokens:
            self._idf[token] = self._idf.get(token, 0) + 1

    def _calculate_tf(self, tokens: list[str]) -> dict[str, float]:
        """Compute term frequencies (TF)."""
        tf = {}
        total_tokens = len(tokens)
        for token in tokens:
            tf[token] = tf.get(token, 0) + 1
        # Normalize by document length
        for token in tf:
            tf[token] /= total_tokens
        return tf

    def _calculate_tfidf(self, tokens: list[str]) -> list[float]:
        """Compute the TF-IDF vector."""
        vector = np.zeros(self.vector_dim)
        tf = self._calculate_tf(tokens)

        for token, tf_value in tf.items():
            if token in self._idf:
                idf = np.log(self._doc_count / self._idf[token])
                idx = self._vocab[token] % self.vector_dim  # hash into a fixed dimension via modulo
                vector[idx] += tf_value * idf

        # L2 normalization
        norm = np.linalg.norm(vector)
        if norm > 0:
            vector /= norm

        return list(vector)

    def _embed(self, text: str) -> list[float]:
        """Generate the vector representation of a text.

        Args:
            text: Input text.

        Returns:
            list[float]: Vector representation of the text.
        """
        # Check the cache first
        text_hash = self._text_to_hash(text)
        if text_hash in self._cache:
            return self._cache[text_hash]

        # Tokenize
        tokens = self._tokenize(text)
        if not tokens:
            return list(np.zeros(self.vector_dim))

        # Update the vocabulary and IDF statistics
        self._update_vocab(tokens)
        self._update_idf(tokens)

        # Compute the vector
        vector = self._calculate_tfidf(tokens)

        # Update the cache
        if len(self._cache) >= self.cache_size:
            # Evict the oldest entry
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]
        self._cache[text_hash] = vector

        return list(vector)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed documents."""
        return [self._embed(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        """Embed query text."""
        return self._embed(text)

    # def save(self, file_path: str):
    #     """Save the model state."""
    #     state = {
    #         "vector_dim": self.vector_dim,
    #         "cache_size": self.cache_size,
    #         "vocab": self._vocab,
    #         "idf": self._idf,
    #         "doc_count": self._doc_count,
    #     }
    #     with open(file_path, "w") as f:
    #         json.dump(state, f)

    # def load(self, file_path: str):
    #     """Load the model state."""
    #     with open(file_path, "r") as f:
    #         state = json.load(f)
    #     self.vector_dim = state["vector_dim"]
    #     self.cache_size = state["cache_size"]
    #     self._vocab = state["vocab"]
    #     self._idf = state["idf"]
    #     self._doc_count = state["doc_count"]
    #     self._cache = {}  # clear the cache


if __name__ == "__main__":
    # se = SentenceEmbedding(
    #     pretrained_model_name_or_path="ignore/BAAI--bge-m3", cache_dir="ignore"
    # )
    se = SimpleEmbedding()
    print(se.embed_query("hello world"))
    print(se.embed_query("hello world"))
    print(se.embed_query("hello world"))
    print(se.embed_query("hello world"))
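Both classes implement the `langchain_core` `Embeddings` interface, so either can back a LangChain vector store. A minimal sketch, assuming `langchain_community` and `faiss` are installed and using made-up sample texts (the wheel's `memory/faiss_query.py` suggests FAISS is the intended consumer):

```python
from langchain_community.vectorstores import FAISS

from pycityagent.llm.embeddings import SimpleEmbedding

texts = [
    "the agent walks to the office",
    "the agent buys groceries",
    "traffic is heavy downtown",
]
store = FAISS.from_texts(texts, embedding=SimpleEmbedding(vector_dim=128))
print(store.similarity_search("agent shopping", k=1))
```

One caveat of the toy TF-IDF scheme: IDF statistics are updated online as texts are embedded, so the very first document always maps to a zero vector (log(1/1) = 0), and vectors only become discriminative as more texts are seen. `SentenceEmbedding` avoids this by using a pretrained model.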