vectordb-bench 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +65 -1
- vectordb_bench/backend/clients/api.py +2 -1
- vectordb_bench/backend/clients/chroma/chroma.py +2 -2
- vectordb_bench/backend/clients/clickhouse/cli.py +66 -0
- vectordb_bench/backend/clients/clickhouse/clickhouse.py +156 -0
- vectordb_bench/backend/clients/clickhouse/config.py +60 -0
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +1 -1
- vectordb_bench/backend/clients/mariadb/cli.py +122 -0
- vectordb_bench/backend/clients/mariadb/config.py +73 -0
- vectordb_bench/backend/clients/mariadb/mariadb.py +208 -0
- vectordb_bench/backend/clients/milvus/cli.py +32 -0
- vectordb_bench/backend/clients/milvus/config.py +32 -0
- vectordb_bench/backend/clients/milvus/milvus.py +1 -1
- vectordb_bench/backend/clients/pgvector/cli.py +14 -3
- vectordb_bench/backend/clients/pgvector/config.py +22 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
- vectordb_bench/backend/clients/pinecone/pinecone.py +1 -1
- vectordb_bench/backend/clients/qdrant_cloud/config.py +1 -9
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +1 -1
- vectordb_bench/backend/clients/tidb/cli.py +98 -0
- vectordb_bench/backend/clients/tidb/config.py +46 -0
- vectordb_bench/backend/clients/tidb/tidb.py +233 -0
- vectordb_bench/backend/clients/vespa/cli.py +47 -0
- vectordb_bench/backend/clients/vespa/config.py +51 -0
- vectordb_bench/backend/clients/vespa/util.py +15 -0
- vectordb_bench/backend/clients/vespa/vespa.py +249 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
- vectordb_bench/cli/cli.py +20 -17
- vectordb_bench/cli/vectordbbench.py +8 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +147 -0
- vectordb_bench/frontend/config/styles.py +4 -0
- vectordb_bench/models.py +8 -6
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/METADATA +22 -3
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/RECORD +38 -25
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info/licenses}/LICENSE +0 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,98 @@
|
|
1
|
+
from typing import Annotated, Unpack
|
2
|
+
|
3
|
+
import click
|
4
|
+
from pydantic import SecretStr
|
5
|
+
|
6
|
+
from vectordb_bench.backend.clients import DB
|
7
|
+
|
8
|
+
from ....cli.cli import CommonTypedDict, cli, click_parameter_decorators_from_typed_dict, run
|
9
|
+
|
10
|
+
|
11
|
+
class TiDBTypedDict(CommonTypedDict):
    """Command-line options accepted by the TiDB benchmark command.

    Every field is annotated with its value type plus the ``click.option``
    that declares the matching flag.
    """

    # NOTE(review): the flag is spelled ``--username`` and the command reads
    # ``parameters["username"]``, which differs from this field's name
    # (``user_name``) -- confirm the mismatch is intended.
    user_name: Annotated[
        str,
        click.option("--username", type=str, required=True, default="root", show_default=True, help="Username"),
    ]
    password: Annotated[
        str,
        click.option("--password", type=str, default="", show_default=True, help="Password"),
    ]
    host: Annotated[
        str,
        click.option("--host", type=str, required=True, default="127.0.0.1", show_default=True, help="Db host"),
    ]
    port: Annotated[
        int,
        click.option("--port", type=int, required=True, default=4000, show_default=True, help="Db Port"),
    ]
    db_name: Annotated[
        str,
        click.option("--db-name", type=str, required=True, default="test", show_default=True, help="Db name"),
    ]
    # Boolean on/off pair: ``--ssl`` enables, ``--no-ssl`` disables.
    ssl: Annotated[
        bool,
        click.option(
            "--ssl/--no-ssl",
            is_flag=True,
            default=False,
            show_default=True,
            help="Enable or disable SSL, for TiDB Serverless SSL must be enabled",
        ),
    ]
|
76
|
+
|
77
|
+
|
78
|
+
@cli.command()
@click_parameter_decorators_from_typed_dict(TiDBTypedDict)
def TiDB(
    **parameters: Unpack[TiDBTypedDict],
):
    # Entry point of the TiDB CLI command: turns the parsed options into a
    # TiDBConfig and starts the benchmark through run().
    # (Kept as comments on purpose: a docstring here would become the
    # command's help text.)
    from .config import TiDBConfig, TiDBIndexConfig

    # Connection settings come straight from the CLI options; the password is
    # wrapped in SecretStr as required by the config model.
    connection = TiDBConfig(
        db_label=parameters["db_label"],
        user_name=parameters["username"],
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        port=parameters["port"],
        db_name=parameters["db_name"],
        ssl=parameters["ssl"],
    )

    # The complete option dict is forwarded as well, so run() sees every
    # common option alongside the TiDB-specific ones.
    run(
        db=DB.TiDB,
        db_config=connection,
        db_case_config=TiDBIndexConfig(),
        **parameters,
    )
|
@@ -0,0 +1,46 @@
|
|
1
|
+
from pydantic import BaseModel, SecretStr
|
2
|
+
|
3
|
+
from ..api import DBCaseConfig, DBConfig, MetricType
|
4
|
+
|
5
|
+
|
6
|
+
class TiDBConfig(DBConfig):
    """Connection settings for a TiDB server."""

    user_name: str = "root"
    password: SecretStr
    host: str = "127.0.0.1"
    port: int = 4000
    db_name: str = "test"
    ssl: bool = False

    def to_dict(self) -> dict:
        """Return the settings as a plain dict of connection arguments.

        The password is unwrapped to its clear-text value, and the single
        ``ssl`` switch drives both certificate and identity verification.
        """
        verify = self.ssl
        return {
            "host": self.host,
            "port": self.port,
            "user": self.user_name,
            "password": self.password.get_secret_value(),
            "database": self.db_name,
            "ssl_verify_cert": verify,
            "ssl_verify_identity": verify,
        }
|
25
|
+
|
26
|
+
|
27
|
+
class TiDBIndexConfig(BaseModel, DBCaseConfig):
|
28
|
+
metric_type: MetricType | None = None
|
29
|
+
|
30
|
+
def get_metric_fn(self) -> str:
|
31
|
+
if self.metric_type == MetricType.L2:
|
32
|
+
return "vec_l2_distance"
|
33
|
+
if self.metric_type == MetricType.COSINE:
|
34
|
+
return "vec_cosine_distance"
|
35
|
+
msg = f"Unsupported metric type: {self.metric_type}"
|
36
|
+
raise ValueError(msg)
|
37
|
+
|
38
|
+
def index_param(self) -> dict:
|
39
|
+
return {
|
40
|
+
"metric_fn": self.get_metric_fn(),
|
41
|
+
}
|
42
|
+
|
43
|
+
def search_param(self) -> dict:
|
44
|
+
return {
|
45
|
+
"metric_fn": self.get_metric_fn(),
|
46
|
+
}
|
@@ -0,0 +1,233 @@
|
|
1
|
+
import concurrent.futures
|
2
|
+
import io
|
3
|
+
import logging
|
4
|
+
import time
|
5
|
+
from contextlib import contextmanager
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import pymysql
|
9
|
+
|
10
|
+
from ..api import VectorDB
|
11
|
+
from .config import TiDBIndexConfig
|
12
|
+
|
13
|
+
log = logging.getLogger(__name__)


class TiDB(VectorDB):
    """Benchmark client that stores and searches vectors in a TiDB table via pymysql.

    Every helper opens its own short-lived connection; only
    ``search_embedding`` relies on the connection and cursor that ``init()``
    keeps open.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: TiDBIndexConfig,
        collection_name: str = "vector_bench_test",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store the settings and, when ``drop_old`` is set, recreate the table.

        Args:
            dim: dimension of the ``VECTOR`` column.
            db_config: keyword arguments passed to ``pymysql.connect``; must
                contain ``"database"`` (read by the optimize helpers).
            db_case_config: supplies the distance function via
                ``index_param()`` / ``search_param()``.
            collection_name: name of the table to use.
            drop_old: drop and recreate the table when true; otherwise the
                table is left untouched.
            **kwargs: accepted and ignored.
        """
        self.name = "TiDB"
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim
        self.conn = None  # To be inited by init()
        self.cursor = None  # To be inited by init()

        self.search_fn = db_case_config.search_param()["metric_fn"]

        if drop_old:
            self._drop_table()
            self._create_table()

    @contextmanager
    def init(self):
        """Keep a connection and cursor open on ``self`` for the ``with`` body.

        Both attributes are reset to ``None`` when the body exits, even on error.
        """
        with self._get_connection() as (conn, cursor):
            self.conn = conn
            self.cursor = cursor
            try:
                yield
            finally:
                self.conn = None
                self.cursor = None

    @contextmanager
    def _get_connection(self):
        """Yield ``(connection, cursor)`` for a fresh connection with autocommit off."""
        with pymysql.connect(**self.db_config) as conn:
            # ``autocommit`` is a method on pymysql connections; assigning to
            # it would only shadow the method without changing the mode.
            conn.autocommit(False)
            with conn.cursor() as cursor:
                yield conn, cursor

    def _drop_table(self):
        """Drop the benchmark table if it exists; log and re-raise on failure."""
        try:
            with self._get_connection() as (conn, cursor):
                cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
                conn.commit()
        except Exception as e:
            log.warning("Failed to drop table: %s error: %s", self.table_name, e)
            raise

    def _create_table(self):
        """Create the table with an id column, a vector column and a vector index.

        The index expression uses the distance function from
        ``case_config.index_param()``. Failures are logged and re-raised.
        """
        try:
            index_param = self.case_config.index_param()
            with self._get_connection() as (conn, cursor):
                cursor.execute(
                    f"""
                    CREATE TABLE {self.table_name} (
                        id BIGINT PRIMARY KEY,
                        embedding VECTOR({self.dim}) NOT NULL,
                        VECTOR INDEX (({index_param["metric_fn"]}(embedding)))
                    );
                    """
                )
                conn.commit()
        except Exception as e:
            log.warning("Failed to create table: %s error: %s", self.table_name, e)
            raise

    def ready_to_load(self) -> bool:
        """No preparation is performed; the method implicitly returns ``None``."""
        pass

    def optimize(self, data_size: int | None = None) -> None:
        """Block until the loaded data is replicated, compacted and indexed.

        Steps: poll the TiFlash replica progress until it equals 1, run a
        ``COUNT(*)`` with TiFlash allowed as a read engine, compact the table,
        then poll until no rows are reported as not yet indexed.

        Args:
            data_size: unused.
        """
        while True:
            progress = self._optimize_check_tiflash_replica_progress()
            if progress != 1:
                # NOTE(review): PROGRESS is presumably a fraction up to 1, so
                # it is logged with %s; %d would truncate every intermediate
                # value to 0.
                log.info("Data replication not ready, progress: %s", progress)
                time.sleep(2)
            else:
                break

        log.info("Waiting TiFlash to catch up...")
        self._optimize_wait_tiflash_catch_up()

        log.info("Start compacting TiFlash replica...")
        self._optimize_compact_tiflash()

        log.info("Waiting index build to finish...")
        log_reduce_seq = 0
        while True:
            pending_rows = self._optimize_get_tiflash_index_pending_rows()
            if pending_rows > 0:
                # Polling happens every 2 s; only every 15th poll is logged.
                if log_reduce_seq % 15 == 0:
                    log.info("Index not fully built, pending rows: %d", pending_rows)
                log_reduce_seq += 1
                time.sleep(2)
            else:
                break

        log.info("Index build finished successfully.")

    def _optimize_check_tiflash_replica_progress(self):
        """Return the ``PROGRESS`` value of this table's TiFlash replica.

        Failures (including a missing replica row) are logged and re-raised.
        """
        try:
            database = self.db_config["database"]
            with self._get_connection() as (_, cursor):
                cursor.execute(
                    f"""
                    SELECT PROGRESS FROM information_schema.tiflash_replica
                    WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}"
                    """  # noqa: S608
                )
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to check TiFlash replica progress: %s", e)
            raise

    def _optimize_wait_tiflash_catch_up(self):
        """Run ``COUNT(*)`` with TiFlash enabled as a read engine and return the count.

        Failures are logged and re-raised.
        """
        try:
            with self._get_connection() as (conn, cursor):
                cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"')
                conn.commit()
                cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")  # noqa: S608
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to wait TiFlash to catch up: %s", e)
            raise

    def _optimize_compact_tiflash(self):
        """Issue ``ALTER TABLE ... COMPACT``; log and re-raise on failure."""
        try:
            with self._get_connection() as (conn, cursor):
                cursor.execute(f"ALTER TABLE {self.table_name} COMPACT")
                conn.commit()
        except Exception as e:
            log.warning("Failed to compact table: %s", e)
            raise

    def _optimize_get_tiflash_index_pending_rows(self):
        """Return the summed ``ROWS_STABLE_NOT_INDEXED`` for this table.

        Failures are logged and re-raised.
        """
        try:
            database = self.db_config["database"]
            with self._get_connection() as (_, cursor):
                cursor.execute(
                    f"""
                    SELECT SUM(ROWS_STABLE_NOT_INDEXED)
                    FROM information_schema.tiflash_indexes
                    WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}"
                    """  # noqa: S608
                )
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to read TiFlash index pending rows: %s", e)
            raise

    def _insert_embeddings_serial(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        offset: int,
        size: int,
    ) -> None:
        """Insert rows ``offset .. offset + size - 1`` with one multi-row INSERT.

        Uses its own connection so it can run on a worker thread. Failures are
        logged and re-raised.
        """
        try:
            with self._get_connection() as (conn, cursor):
                buf = io.StringIO()
                buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ")  # noqa: S608
                for i in range(offset, offset + size):
                    if i > offset:
                        buf.write(",")
                    buf.write(f'({metadata[i]}, "{embeddings[i]!s}")')
                cursor.execute(buf.getvalue())
                conn.commit()
        except Exception as e:
            log.warning("Failed to insert data into table: %s", e)
            raise

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> tuple[int, Exception | None]:
        """Insert all embeddings in parallel batches on a thread pool.

        Args:
            embeddings: vectors to insert.
            metadata: ids, index-aligned with ``embeddings``.
            **kwargs: accepted and ignored.

        Returns:
            ``(len(metadata), None)`` on success, ``(0, None)`` for empty input.

        Raises:
            Exception: whatever a failed batch raised.
        """
        if len(embeddings) == 0:
            # Nothing to do; also avoids a zero step in range() below.
            return 0, None

        workers = 10
        # Avoid exceeding MAX_ALLOWED_PACKET (default=64MB)
        max_batch_size = 64 * 1024 * 1024 // 24 // self.dim
        # Clamp to at least 1: fewer vectors than workers (or a very large
        # dim) would otherwise give a batch size of 0 and break range().
        batch_size = max(1, min(len(embeddings) // workers, max_batch_size))
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = []
            for i in range(0, len(embeddings), batch_size):
                offset = i
                size = min(batch_size, len(embeddings) - i)
                future = executor.submit(self._insert_embeddings_serial, embeddings, metadata, offset, size)
                futures.append(future)
            done, pending = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION)
            executor.shutdown(wait=False)
            # Cancel queued batches before surfacing a failure; result() below
            # re-raises, so cancelling afterwards would never run.
            for future in pending:
                future.cancel()
            for future in done:
                future.result()
        return len(metadata), None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> list[int]:
        """Return the ids of the ``k`` rows closest to ``query``.

        Must be called inside ``init()``, which provides ``self.cursor``.
        ``filters``, ``timeout`` and ``**kwargs`` are not used.
        """
        self.cursor.execute(
            f"""
            SELECT id FROM {self.table_name}
            ORDER BY {self.search_fn}(embedding, "{query!s}") LIMIT {k};
            """  # noqa: S608
        )
        result = self.cursor.fetchall()
        return [int(i[0]) for i in result]
|
@@ -0,0 +1,47 @@
|
|
1
|
+
from typing import Annotated, Unpack
|
2
|
+
|
3
|
+
import click
|
4
|
+
from pydantic import SecretStr
|
5
|
+
|
6
|
+
from vectordb_bench.backend.clients import DB
|
7
|
+
from vectordb_bench.cli.cli import (
|
8
|
+
CommonTypedDict,
|
9
|
+
HNSWFlavor1,
|
10
|
+
cli,
|
11
|
+
click_parameter_decorators_from_typed_dict,
|
12
|
+
run,
|
13
|
+
)
|
14
|
+
|
15
|
+
|
16
|
+
class VespaTypedDict(CommonTypedDict, HNSWFlavor1):
    """Command-line options for the Vespa benchmark command.

    Adds the connection and quantization flags declared here to the options
    inherited from ``CommonTypedDict`` and ``HNSWFlavor1``.
    """

    uri: Annotated[
        str,
        click.option(
            "--uri",
            "-u",
            type=str,
            default="http://127.0.0.1",
            help="uri connection string",
        ),
    ]
    port: Annotated[
        int,
        click.option(
            "--port",
            "-p",
            type=int,
            default=8080,
            help="connection port",
        ),
    ]
    # Choice matching is case-insensitive.
    quantization: Annotated[
        str,
        click.option(
            "--quantization",
            type=click.Choice(["none", "binary"], case_sensitive=False),
            default="none",
        ),
    ]
|
28
|
+
|
29
|
+
|
30
|
+
@cli.command()
@click_parameter_decorators_from_typed_dict(VespaTypedDict)
def Vespa(**params: Unpack[VespaTypedDict]):
    # Entry point of the Vespa CLI command: maps the parsed options onto the
    # Vespa config models and starts the benchmark through run().
    # (Kept as comments on purpose: a docstring here would become the
    # command's help text.)
    from .config import VespaConfig, VespaHNSWConfig

    # (config field, CLI option key) pairs for the case configuration.
    field_to_option = (
        ("quantization_type", "quantization"),
        ("M", "m"),
        ("efConstruction", "ef_construction"),
        ("ef", "ef_search"),
    )
    # Falsy option values are left out so the model's own defaults apply.
    overrides = {field: params[option] for field, option in field_to_option if params[option]}

    run(
        db=DB.Vespa,
        db_config=VespaConfig(url=SecretStr(params["uri"]), port=params["port"]),
        db_case_config=VespaHNSWConfig(**overrides),
        **params,
    )
|
@@ -0,0 +1,51 @@
|
|
1
|
+
from typing import Literal, TypeAlias
|
2
|
+
|
3
|
+
from pydantic import BaseModel, SecretStr
|
4
|
+
|
5
|
+
from ..api import DBCaseConfig, DBConfig, MetricType
|
6
|
+
|
7
|
+
VespaMetric: TypeAlias = Literal["euclidean", "angular", "dotproduct", "prenormalized-angular", "hamming", "geodegrees"]
|
8
|
+
|
9
|
+
VespaQuantizationType: TypeAlias = Literal["none", "binary"]
|
10
|
+
|
11
|
+
|
12
|
+
class VespaConfig(DBConfig):
    """Connection settings for a Vespa endpoint."""

    # The default has to be a SecretStr instance already: pydantic does not
    # validate (or coerce) field defaults unless ``validate_default`` is
    # enabled, so a plain ``str`` default would reach ``to_dict()`` as-is and
    # fail on ``get_secret_value()``.
    # NOTE(review): DBConfig's model config is not visible here; a SecretStr
    # default is valid either way.
    url: SecretStr = SecretStr("http://127.0.0.1")
    port: int = 8080

    def to_dict(self):
        """Return the settings as a plain dict with the URL unwrapped to ``str``."""
        return {
            "url": self.url.get_secret_value(),
            "port": self.port,
        }
|
21
|
+
|
22
|
+
|
23
|
+
class VespaHNSWConfig(BaseModel, DBCaseConfig):
    """HNSW case configuration for Vespa."""

    metric_type: MetricType = MetricType.COSINE
    quantization_type: VespaQuantizationType = "none"
    M: int = 16
    efConstruction: int = 200
    ef: int = 100

    def index_param(self) -> dict:
        """Return the index settings derived from this configuration."""
        distance_metric = self.parse_metric(self.metric_type)
        return {
            "distance_metric": distance_metric,
            "max_links_per_node": self.M,
            "neighbors_to_explore_at_insert": self.efConstruction,
        }

    def search_param(self) -> dict:
        """Return the search settings; always an empty dict."""
        return {}

    def parse_metric(self, metric_type: MetricType) -> VespaMetric:
        """Translate a benchmark metric into the matching Vespa metric name.

        Raises:
            NotImplementedError: for any metric without a mapping below.
        """
        if metric_type == MetricType.COSINE:
            return "angular"
        if metric_type == MetricType.L2:
            return "euclidean"
        # DP and IP both map onto Vespa's dot product metric.
        if metric_type == MetricType.DP or metric_type == MetricType.IP:
            return "dotproduct"
        if metric_type == MetricType.HAMMING:
            return "hamming"
        raise NotImplementedError
|
@@ -0,0 +1,15 @@
|
|
1
|
+
"""Utility functions for supporting binary quantization
|
2
|
+
|
3
|
+
From https://docs.vespa.ai/en/binarizing-vectors.html#appendix-conversion-to-int8
|
4
|
+
"""
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
|
9
|
+
def binarize_tensor(tensor: list[float]) -> list[int]:
    """Turn a float vector into packed sign bits.

    A value strictly greater than zero becomes bit 1, anything else bit 0.
    The bits are packed along axis 0 into bytes, which are returned
    reinterpreted as signed 8-bit integers.
    """
    values = np.array(tensor)
    sign_bits = np.where(values > 0, 1, 0)
    packed = np.packbits(sign_bits, axis=0)
    return packed.astype(np.int8).tolist()
|