pyobvector 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyobvector/__init__.py +3 -0
- pyobvector/client/collection_schema.py +6 -6
- pyobvector/client/exceptions.py +4 -4
- pyobvector/client/fts_index_param.py +2 -3
- pyobvector/client/hybrid_search.py +81 -0
- pyobvector/client/index_param.py +21 -8
- pyobvector/client/milvus_like_client.py +124 -88
- pyobvector/client/ob_client.py +459 -0
- pyobvector/client/ob_vec_client.py +153 -493
- pyobvector/client/schema_type.py +4 -2
- pyobvector/schema/__init__.py +3 -0
- pyobvector/schema/dialect.py +3 -0
- pyobvector/schema/reflection.py +1 -1
- pyobvector/schema/sparse_vector.py +35 -0
- pyobvector/schema/vector_index.py +1 -1
- pyobvector/util/__init__.py +3 -1
- pyobvector/util/ob_version.py +1 -1
- pyobvector/util/sparse_vector.py +48 -0
- pyobvector/util/vector.py +10 -4
- {pyobvector-0.2.16.dist-info → pyobvector-0.2.18.dist-info}/METADATA +69 -7
- pyobvector-0.2.18.dist-info/RECORD +40 -0
- {pyobvector-0.2.16.dist-info → pyobvector-0.2.18.dist-info}/WHEEL +1 -1
- pyobvector-0.2.16.dist-info/RECORD +0 -36
- {pyobvector-0.2.16.dist-info → pyobvector-0.2.18.dist-info/licenses}/LICENSE +0 -0
|
@@ -1,52 +1,35 @@
|
|
|
1
1
|
"""OceanBase Vector Store Client."""
|
|
2
|
-
|
|
3
2
|
import logging
|
|
4
|
-
from typing import List, Optional,
|
|
3
|
+
from typing import List, Optional, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
5
6
|
from sqlalchemy import (
|
|
6
|
-
create_engine,
|
|
7
|
-
MetaData,
|
|
8
7
|
Table,
|
|
9
8
|
Column,
|
|
10
9
|
Index,
|
|
11
10
|
select,
|
|
12
|
-
delete,
|
|
13
|
-
update,
|
|
14
|
-
insert,
|
|
15
11
|
text,
|
|
16
|
-
inspect,
|
|
17
|
-
and_,
|
|
18
12
|
)
|
|
19
|
-
from sqlalchemy.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
import numpy as np
|
|
23
|
-
from urllib.parse import quote
|
|
24
|
-
from .index_param import IndexParams, IndexParam
|
|
13
|
+
from sqlalchemy.schema import CreateTable
|
|
14
|
+
|
|
15
|
+
from .exceptions import ClusterVersionException, ErrorCode, ExceptionsMessage
|
|
25
16
|
from .fts_index_param import FtsIndexParam
|
|
17
|
+
from .index_param import IndexParams, IndexParam
|
|
18
|
+
from .ob_client import ObClient
|
|
19
|
+
from .partitions import ObPartition
|
|
26
20
|
from ..schema import (
|
|
27
21
|
ObTable,
|
|
28
22
|
VectorIndex,
|
|
29
|
-
l2_distance,
|
|
30
|
-
cosine_distance,
|
|
31
|
-
inner_product,
|
|
32
|
-
negative_inner_product,
|
|
33
|
-
ST_GeomFromText,
|
|
34
|
-
st_distance,
|
|
35
|
-
st_dwithin,
|
|
36
|
-
st_astext,
|
|
37
|
-
ReplaceStmt,
|
|
38
23
|
FtsIndex,
|
|
39
24
|
)
|
|
40
25
|
from ..util import ObVersion
|
|
41
|
-
from .partitions import *
|
|
42
|
-
from .exceptions import *
|
|
43
26
|
|
|
44
27
|
logger = logging.getLogger(__name__)
|
|
45
28
|
logger.setLevel(logging.DEBUG)
|
|
46
29
|
|
|
47
30
|
|
|
48
|
-
class ObVecClient:
|
|
49
|
-
"""The OceanBase Client"""
|
|
31
|
+
class ObVecClient(ObClient):
|
|
32
|
+
"""The OceanBase Vector Client"""
|
|
50
33
|
|
|
51
34
|
def __init__(
|
|
52
35
|
self,
|
|
@@ -56,119 +39,23 @@ class ObVecClient:
|
|
|
56
39
|
db_name: str = "test",
|
|
57
40
|
**kwargs,
|
|
58
41
|
):
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
# ischema_names["VECTOR"] = VECTOR
|
|
62
|
-
setattr(func_mod, "l2_distance", l2_distance)
|
|
63
|
-
setattr(func_mod, "cosine_distance", cosine_distance)
|
|
64
|
-
setattr(func_mod, "inner_product", inner_product)
|
|
65
|
-
setattr(func_mod, "negative_inner_product", negative_inner_product)
|
|
66
|
-
setattr(func_mod, "ST_GeomFromText", ST_GeomFromText)
|
|
67
|
-
setattr(func_mod, "st_distance", st_distance)
|
|
68
|
-
setattr(func_mod, "st_dwithin", st_dwithin)
|
|
69
|
-
setattr(func_mod, "st_astext", st_astext)
|
|
70
|
-
|
|
71
|
-
user = quote(user, safe="")
|
|
72
|
-
password = quote(password, safe="")
|
|
73
|
-
|
|
74
|
-
connection_str = (
|
|
75
|
-
f"mysql+oceanbase://{user}:{password}@{uri}/{db_name}?charset=utf8mb4"
|
|
76
|
-
)
|
|
77
|
-
self.engine = create_engine(connection_str, **kwargs)
|
|
78
|
-
self.metadata_obj = MetaData()
|
|
79
|
-
self.metadata_obj.reflect(bind=self.engine)
|
|
80
|
-
|
|
81
|
-
with self.engine.connect() as conn:
|
|
82
|
-
with conn.begin():
|
|
83
|
-
res = conn.execute(text("SELECT OB_VERSION() FROM DUAL"))
|
|
84
|
-
version = [r[0] for r in res][0]
|
|
85
|
-
ob_version = ObVersion.from_db_version_string(version)
|
|
86
|
-
if ob_version < ObVersion.from_db_version_nums(4, 3, 3, 0):
|
|
87
|
-
raise ClusterVersionException(
|
|
88
|
-
code=ErrorCode.NOT_SUPPORTED,
|
|
89
|
-
message=ExceptionsMessage.ClusterVersionIsLow,
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
def refresh_metadata(self, tables: Optional[list[str]] = None):
|
|
93
|
-
"""Reload metadata from the database.
|
|
94
|
-
|
|
95
|
-
Args:
|
|
96
|
-
tables (Optional[list[str]]): names of the tables to refresh. If None, refresh all tables.
|
|
97
|
-
"""
|
|
98
|
-
if tables is not None:
|
|
99
|
-
for table_name in tables:
|
|
100
|
-
if table_name in self.metadata_obj.tables:
|
|
101
|
-
self.metadata_obj.remove(Table(table_name, self.metadata_obj))
|
|
102
|
-
self.metadata_obj.reflect(bind=self.engine, only=tables, extend_existing=True)
|
|
103
|
-
else:
|
|
104
|
-
self.metadata_obj.clear()
|
|
105
|
-
self.metadata_obj.reflect(bind=self.engine, extend_existing=True)
|
|
106
|
-
|
|
107
|
-
def _insert_partition_hint_for_query_sql(self, sql: str, partition_hint: str):
|
|
108
|
-
from_index = sql.find("FROM")
|
|
109
|
-
assert from_index != -1
|
|
110
|
-
first_space_after_from = sql.find(" ", from_index + len("FROM") + 1)
|
|
111
|
-
if first_space_after_from == -1:
|
|
112
|
-
return sql + " " + partition_hint
|
|
113
|
-
return (
|
|
114
|
-
sql[:first_space_after_from]
|
|
115
|
-
+ " "
|
|
116
|
-
+ partition_hint
|
|
117
|
-
+ sql[first_space_after_from:]
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
def check_table_exists(self, table_name: str):
|
|
121
|
-
"""check if table exists.
|
|
42
|
+
super().__init__(uri, user, password, db_name, **kwargs)
|
|
122
43
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
44
|
+
if self.ob_version < ObVersion.from_db_version_nums(4, 3, 3, 0):
|
|
45
|
+
raise ClusterVersionException(
|
|
46
|
+
code=ErrorCode.NOT_SUPPORTED,
|
|
47
|
+
message=ExceptionsMessage.ClusterVersionIsLow % ("Vector Store", "4.3.3.0"),
|
|
48
|
+
)
|
|
128
49
|
|
|
129
|
-
def
|
|
130
|
-
self,
|
|
131
|
-
table_name: str,
|
|
132
|
-
columns: List[Column],
|
|
133
|
-
indexes: Optional[List[Index]] = None,
|
|
134
|
-
partitions: Optional[ObPartition] = None,
|
|
50
|
+
def _get_sparse_vector_index_params(
|
|
51
|
+
self, vidxs: Optional[IndexParams]
|
|
135
52
|
):
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
partitions (Optional[ObPartition]) : optional partition strategy
|
|
143
|
-
"""
|
|
144
|
-
with self.engine.connect() as conn:
|
|
145
|
-
with conn.begin():
|
|
146
|
-
if indexes is not None:
|
|
147
|
-
table = ObTable(
|
|
148
|
-
table_name,
|
|
149
|
-
self.metadata_obj,
|
|
150
|
-
*columns,
|
|
151
|
-
*indexes,
|
|
152
|
-
extend_existing=True,
|
|
153
|
-
)
|
|
154
|
-
else:
|
|
155
|
-
table = ObTable(
|
|
156
|
-
table_name,
|
|
157
|
-
self.metadata_obj,
|
|
158
|
-
*columns,
|
|
159
|
-
extend_existing=True,
|
|
160
|
-
)
|
|
161
|
-
table.create(self.engine, checkfirst=True)
|
|
162
|
-
# do partition
|
|
163
|
-
if partitions is not None:
|
|
164
|
-
conn.execute(
|
|
165
|
-
text(f"ALTER TABLE `{table_name}` {partitions.do_compile()}")
|
|
166
|
-
)
|
|
167
|
-
|
|
168
|
-
@classmethod
|
|
169
|
-
def prepare_index_params(cls):
|
|
170
|
-
"""Create `IndexParams` to hold index configuration."""
|
|
171
|
-
return IndexParams()
|
|
53
|
+
if vidxs is None:
|
|
54
|
+
return None
|
|
55
|
+
return [
|
|
56
|
+
vidx for vidx in vidxs
|
|
57
|
+
if vidx.is_index_type_sparse_vector()
|
|
58
|
+
]
|
|
172
59
|
|
|
173
60
|
def create_table_with_index_params(
|
|
174
61
|
self,
|
|
@@ -182,12 +69,14 @@ class ObVecClient:
|
|
|
182
69
|
"""Create table with optional index_params.
|
|
183
70
|
|
|
184
71
|
Args:
|
|
185
|
-
table_name (string)
|
|
186
|
-
columns (List[Column])
|
|
187
|
-
indexes (Optional[List[Index]])
|
|
188
|
-
|
|
189
|
-
|
|
72
|
+
table_name (string): table name
|
|
73
|
+
columns (List[Column]): column schema
|
|
74
|
+
indexes (Optional[List[Index]]): optional common index schema
|
|
75
|
+
vidxs (Optional[IndexParams]): optional vector index schema
|
|
76
|
+
fts_idxs (Optional[List[FtsIndexParam]]): optional full-text search index schema
|
|
77
|
+
partitions (Optional[ObPartition]): optional partition strategy
|
|
190
78
|
"""
|
|
79
|
+
sparse_vidxs = self._get_sparse_vector_index_params(vidxs)
|
|
191
80
|
with self.engine.connect() as conn:
|
|
192
81
|
with conn.begin():
|
|
193
82
|
# create table with common index
|
|
@@ -206,7 +95,15 @@ class ObVecClient:
|
|
|
206
95
|
*columns,
|
|
207
96
|
extend_existing=True,
|
|
208
97
|
)
|
|
209
|
-
|
|
98
|
+
if sparse_vidxs is not None and len(sparse_vidxs) > 0:
|
|
99
|
+
create_table_sql = str(CreateTable(table).compile(self.engine))
|
|
100
|
+
new_sql = create_table_sql[:create_table_sql.rfind(')')]
|
|
101
|
+
for sparse_vidx in sparse_vidxs:
|
|
102
|
+
new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)"
|
|
103
|
+
new_sql += "\n)"
|
|
104
|
+
conn.execute(text(new_sql))
|
|
105
|
+
else:
|
|
106
|
+
table.create(self.engine, checkfirst=True)
|
|
210
107
|
# do partition
|
|
211
108
|
if partitions is not None:
|
|
212
109
|
conn.execute(
|
|
@@ -215,6 +112,8 @@ class ObVecClient:
|
|
|
215
112
|
# create vector indexes
|
|
216
113
|
if vidxs is not None:
|
|
217
114
|
for vidx in vidxs:
|
|
115
|
+
if vidx.is_index_type_sparse_vector():
|
|
116
|
+
continue
|
|
218
117
|
vidx = VectorIndex(
|
|
219
118
|
vidx.index_name,
|
|
220
119
|
table.c[vidx.field_name],
|
|
@@ -244,12 +143,12 @@ class ObVecClient:
|
|
|
244
143
|
"""Create common index or vector index.
|
|
245
144
|
|
|
246
145
|
Args:
|
|
247
|
-
table_name (string)
|
|
248
|
-
is_vec_index (bool)
|
|
249
|
-
index_name (string)
|
|
250
|
-
column_names (List[string])
|
|
251
|
-
vidx_params (Optional[str])
|
|
252
|
-
|
|
146
|
+
table_name (string): table name
|
|
147
|
+
is_vec_index (bool): common index or vector index
|
|
148
|
+
index_name (string): index name
|
|
149
|
+
column_names (List[string]): create index on which columns
|
|
150
|
+
vidx_params (Optional[str]): vector index params, for example 'distance=l2, type=hnsw, lib=vsag'
|
|
151
|
+
**kw: additional keyword arguments
|
|
253
152
|
"""
|
|
254
153
|
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
255
154
|
columns = [table.c[column_name] for column_name in column_names]
|
|
@@ -270,8 +169,8 @@ class ObVecClient:
|
|
|
270
169
|
"""Create vector index with vector index parameter.
|
|
271
170
|
|
|
272
171
|
Args:
|
|
273
|
-
table_name (string)
|
|
274
|
-
vidx_param (IndexParam)
|
|
172
|
+
table_name (string): table name
|
|
173
|
+
vidx_param (IndexParam): vector index parameter
|
|
275
174
|
"""
|
|
276
175
|
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
277
176
|
with self.engine.connect() as conn:
|
|
@@ -291,8 +190,8 @@ class ObVecClient:
|
|
|
291
190
|
"""Create fts index with fts index parameter.
|
|
292
191
|
|
|
293
192
|
Args:
|
|
294
|
-
table_name (string)
|
|
295
|
-
fts_idx_param (FtsIndexParam)
|
|
193
|
+
table_name (string): table name
|
|
194
|
+
fts_idx_param (FtsIndexParam): fts index parameter
|
|
296
195
|
"""
|
|
297
196
|
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
298
197
|
with self.engine.connect() as conn:
|
|
@@ -305,26 +204,6 @@ class ObVecClient:
|
|
|
305
204
|
)
|
|
306
205
|
fts_idx.create(self.engine, checkfirst=True)
|
|
307
206
|
|
|
308
|
-
def drop_table_if_exist(self, table_name: str):
|
|
309
|
-
"""Drop table if exists."""
|
|
310
|
-
try:
|
|
311
|
-
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
312
|
-
except NoSuchTableError:
|
|
313
|
-
return
|
|
314
|
-
with self.engine.connect() as conn:
|
|
315
|
-
with conn.begin():
|
|
316
|
-
table.drop(self.engine, checkfirst=True)
|
|
317
|
-
self.metadata_obj.remove(table)
|
|
318
|
-
|
|
319
|
-
def drop_index(self, table_name: str, index_name: str):
|
|
320
|
-
"""drop index on specified table.
|
|
321
|
-
|
|
322
|
-
If the index not exists, SQL ERROR 1091 will raise.
|
|
323
|
-
"""
|
|
324
|
-
with self.engine.connect() as conn:
|
|
325
|
-
with conn.begin():
|
|
326
|
-
conn.execute(text(f"DROP INDEX `{index_name}` ON `{table_name}`"))
|
|
327
|
-
|
|
328
207
|
def refresh_index(
|
|
329
208
|
self,
|
|
330
209
|
table_name: str,
|
|
@@ -334,10 +213,9 @@ class ObVecClient:
|
|
|
334
213
|
"""Refresh vector index for performance.
|
|
335
214
|
|
|
336
215
|
Args:
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
If delta_buffer_table row count is greater than `trigger_threshold`,
|
|
216
|
+
table_name (string): table name
|
|
217
|
+
index_name (string): vector index name
|
|
218
|
+
trigger_threshold (int): If delta_buffer_table row count is greater than `trigger_threshold`,
|
|
341
219
|
refreshing is actually triggered.
|
|
342
220
|
"""
|
|
343
221
|
with self.engine.connect() as conn:
|
|
@@ -358,9 +236,9 @@ class ObVecClient:
|
|
|
358
236
|
"""Rebuild vector index for performance.
|
|
359
237
|
|
|
360
238
|
Args:
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
239
|
+
table_name (string): table name
|
|
240
|
+
index_name (string): vector index name
|
|
241
|
+
trigger_threshold (float): threshold value for rebuilding index
|
|
364
242
|
"""
|
|
365
243
|
with self.engine.connect() as conn:
|
|
366
244
|
with conn.begin():
|
|
@@ -371,221 +249,6 @@ class ObVecClient:
|
|
|
371
249
|
)
|
|
372
250
|
)
|
|
373
251
|
|
|
374
|
-
def insert(
|
|
375
|
-
self,
|
|
376
|
-
table_name: str,
|
|
377
|
-
data: Union[Dict, List[Dict]],
|
|
378
|
-
partition_name: Optional[str] = "",
|
|
379
|
-
):
|
|
380
|
-
"""Insert data into table.
|
|
381
|
-
|
|
382
|
-
Args:
|
|
383
|
-
table_name (string) : table name
|
|
384
|
-
data (Union[Dict, List[Dict]]) : data that will be inserted
|
|
385
|
-
partition_names (Optional[str]) : limit the query to certain partition
|
|
386
|
-
"""
|
|
387
|
-
if isinstance(data, Dict):
|
|
388
|
-
data = [data]
|
|
389
|
-
|
|
390
|
-
if len(data) == 0:
|
|
391
|
-
return
|
|
392
|
-
|
|
393
|
-
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
394
|
-
|
|
395
|
-
with self.engine.connect() as conn:
|
|
396
|
-
with conn.begin():
|
|
397
|
-
if partition_name is None or partition_name == "":
|
|
398
|
-
conn.execute(insert(table).values(data))
|
|
399
|
-
else:
|
|
400
|
-
conn.execute(
|
|
401
|
-
insert(table)
|
|
402
|
-
.with_hint(f"PARTITION({partition_name})")
|
|
403
|
-
.values(data)
|
|
404
|
-
)
|
|
405
|
-
|
|
406
|
-
def upsert(
|
|
407
|
-
self,
|
|
408
|
-
table_name: str,
|
|
409
|
-
data: Union[Dict, List[Dict]],
|
|
410
|
-
partition_name: Optional[str] = "",
|
|
411
|
-
):
|
|
412
|
-
"""Update data in table. If primary key is duplicated, replace it.
|
|
413
|
-
|
|
414
|
-
Args:
|
|
415
|
-
table_name (string) : table name
|
|
416
|
-
data (Union[Dict, List[Dict]]) : data that will be upserted
|
|
417
|
-
partition_names (Optional[str]) : limit the query to certain partition
|
|
418
|
-
"""
|
|
419
|
-
if isinstance(data, Dict):
|
|
420
|
-
data = [data]
|
|
421
|
-
|
|
422
|
-
if len(data) == 0:
|
|
423
|
-
return
|
|
424
|
-
|
|
425
|
-
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
426
|
-
|
|
427
|
-
with self.engine.connect() as conn:
|
|
428
|
-
with conn.begin():
|
|
429
|
-
upsert_stmt = (
|
|
430
|
-
ReplaceStmt(table).with_hint(f"PARTITION({partition_name})")
|
|
431
|
-
if partition_name is not None and partition_name != ""
|
|
432
|
-
else ReplaceStmt(table)
|
|
433
|
-
)
|
|
434
|
-
upsert_stmt = upsert_stmt.values(data)
|
|
435
|
-
conn.execute(upsert_stmt)
|
|
436
|
-
|
|
437
|
-
def update(
|
|
438
|
-
self,
|
|
439
|
-
table_name: str,
|
|
440
|
-
values_clause,
|
|
441
|
-
where_clause=None,
|
|
442
|
-
partition_name: Optional[str] = "",
|
|
443
|
-
):
|
|
444
|
-
"""Update data in table.
|
|
445
|
-
|
|
446
|
-
Args:
|
|
447
|
-
table_name (string) : table name
|
|
448
|
-
values_clause: update values clause
|
|
449
|
-
where_clause: update with filter
|
|
450
|
-
partition_name (Optional[str]) : limit the query to certain partition
|
|
451
|
-
|
|
452
|
-
Example:
|
|
453
|
-
.. code-block:: python
|
|
454
|
-
|
|
455
|
-
data = [
|
|
456
|
-
{"id": 112, "embedding": [1, 2, 3], "meta": {'doc':'hhh1'}},
|
|
457
|
-
{"id": 190, "embedding": [0.13, 0.123, 1.213], "meta": {'doc':'hhh2'}},
|
|
458
|
-
]
|
|
459
|
-
client.insert(collection_name=test_collection_name, data=data)
|
|
460
|
-
client.update(
|
|
461
|
-
table_name=test_collection_name,
|
|
462
|
-
values_clause=[{'meta':{'doc':'HHH'}}],
|
|
463
|
-
where_clause=[text("id=112")]
|
|
464
|
-
)
|
|
465
|
-
"""
|
|
466
|
-
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
467
|
-
|
|
468
|
-
with self.engine.connect() as conn:
|
|
469
|
-
with conn.begin():
|
|
470
|
-
update_stmt = (
|
|
471
|
-
update(table).with_hint(f"PARTITION({partition_name})")
|
|
472
|
-
if partition_name is not None and partition_name != ""
|
|
473
|
-
else update(table)
|
|
474
|
-
)
|
|
475
|
-
if where_clause is not None:
|
|
476
|
-
update_stmt = update_stmt.where(*where_clause).values(
|
|
477
|
-
*values_clause
|
|
478
|
-
)
|
|
479
|
-
else:
|
|
480
|
-
update_stmt = update_stmt.values(*values_clause)
|
|
481
|
-
conn.execute(update_stmt)
|
|
482
|
-
|
|
483
|
-
def delete(
|
|
484
|
-
self,
|
|
485
|
-
table_name: str,
|
|
486
|
-
ids: Optional[Union[list, str, int]] = None,
|
|
487
|
-
where_clause=None,
|
|
488
|
-
partition_name: Optional[str] = "",
|
|
489
|
-
):
|
|
490
|
-
"""Delete data in table.
|
|
491
|
-
|
|
492
|
-
Args:
|
|
493
|
-
table_name (string) : table name
|
|
494
|
-
where_clause : delete with filter
|
|
495
|
-
partition_names (Optional[str]) : limit the query to certain partition
|
|
496
|
-
"""
|
|
497
|
-
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
498
|
-
where_in_clause = None
|
|
499
|
-
if ids is not None:
|
|
500
|
-
primary_keys = table.primary_key
|
|
501
|
-
pkey_names = [column.name for column in primary_keys]
|
|
502
|
-
if len(pkey_names) == 1:
|
|
503
|
-
if isinstance(ids, list):
|
|
504
|
-
where_in_clause = table.c[pkey_names[0]].in_(ids)
|
|
505
|
-
elif isinstance(ids, (str, int)):
|
|
506
|
-
where_in_clause = table.c[pkey_names[0]].in_([ids])
|
|
507
|
-
else:
|
|
508
|
-
raise TypeError("'ids' is not a list/str/int")
|
|
509
|
-
|
|
510
|
-
with self.engine.connect() as conn:
|
|
511
|
-
with conn.begin():
|
|
512
|
-
delete_stmt = (
|
|
513
|
-
delete(table).with_hint(f"PARTITION({partition_name})")
|
|
514
|
-
if partition_name is not None and partition_name != ""
|
|
515
|
-
else delete(table)
|
|
516
|
-
)
|
|
517
|
-
if where_in_clause is None and where_clause is None:
|
|
518
|
-
conn.execute(delete_stmt)
|
|
519
|
-
elif where_in_clause is not None and where_clause is None:
|
|
520
|
-
conn.execute(delete_stmt.where(where_in_clause))
|
|
521
|
-
elif where_in_clause is None and where_clause is not None:
|
|
522
|
-
conn.execute(delete_stmt.where(*where_clause))
|
|
523
|
-
else:
|
|
524
|
-
conn.execute(
|
|
525
|
-
delete_stmt.where(and_(where_in_clause, *where_clause))
|
|
526
|
-
)
|
|
527
|
-
|
|
528
|
-
def get(
|
|
529
|
-
self,
|
|
530
|
-
table_name: str,
|
|
531
|
-
ids: Optional[Union[list, str, int]],
|
|
532
|
-
where_clause = None,
|
|
533
|
-
output_column_name: Optional[List[str]] = None,
|
|
534
|
-
partition_names: Optional[List[str]] = None,
|
|
535
|
-
n_limits: Optional[int] = None,
|
|
536
|
-
):
|
|
537
|
-
"""get records with specified primary field `ids`.
|
|
538
|
-
|
|
539
|
-
Args:
|
|
540
|
-
:param table_name (string) : table name
|
|
541
|
-
:param ids : specified primary field values
|
|
542
|
-
:param where_clause : SQL filter
|
|
543
|
-
:param output_column_name (Optional[List[str]]) : output fields name
|
|
544
|
-
:param partition_names (List[str]) : limit the query to certain partitions
|
|
545
|
-
"""
|
|
546
|
-
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
547
|
-
if output_column_name is not None:
|
|
548
|
-
columns = [table.c[column_name] for column_name in output_column_name]
|
|
549
|
-
stmt = select(*columns)
|
|
550
|
-
else:
|
|
551
|
-
stmt = select(table)
|
|
552
|
-
primary_keys = table.primary_key
|
|
553
|
-
pkey_names = [column.name for column in primary_keys]
|
|
554
|
-
where_in_clause = None
|
|
555
|
-
if ids is not None and len(pkey_names) == 1:
|
|
556
|
-
if isinstance(ids, list):
|
|
557
|
-
where_in_clause = table.c[pkey_names[0]].in_(ids)
|
|
558
|
-
elif isinstance(ids, (str, int)):
|
|
559
|
-
where_in_clause = table.c[pkey_names[0]].in_([ids])
|
|
560
|
-
else:
|
|
561
|
-
raise TypeError("'ids' is not a list/str/int")
|
|
562
|
-
|
|
563
|
-
if where_in_clause is not None and where_clause is None:
|
|
564
|
-
stmt = stmt.where(where_in_clause)
|
|
565
|
-
elif where_in_clause is None and where_clause is not None:
|
|
566
|
-
stmt = stmt.where(*where_clause)
|
|
567
|
-
elif where_in_clause is not None and where_clause is not None:
|
|
568
|
-
stmt = stmt.where(and_(where_in_clause, *where_clause))
|
|
569
|
-
|
|
570
|
-
if n_limits is not None:
|
|
571
|
-
stmt = stmt.limit(n_limits)
|
|
572
|
-
|
|
573
|
-
with self.engine.connect() as conn:
|
|
574
|
-
with conn.begin():
|
|
575
|
-
if partition_names is None:
|
|
576
|
-
execute_res = conn.execute(stmt)
|
|
577
|
-
else:
|
|
578
|
-
stmt_str = str(stmt.compile(
|
|
579
|
-
dialect=self.engine.dialect,
|
|
580
|
-
compile_kwargs={"literal_binds": True}
|
|
581
|
-
))
|
|
582
|
-
stmt_str = self._insert_partition_hint_for_query_sql(
|
|
583
|
-
stmt_str, f"PARTITION({', '.join(partition_names)})"
|
|
584
|
-
)
|
|
585
|
-
logging.debug(stmt_str)
|
|
586
|
-
execute_res = conn.execute(text(stmt_str))
|
|
587
|
-
return execute_res
|
|
588
|
-
|
|
589
252
|
def set_ob_hnsw_ef_search(self, ob_hnsw_ef_search: int):
|
|
590
253
|
"""Set ob_hnsw_ef_search system variable."""
|
|
591
254
|
with self.engine.connect() as conn:
|
|
@@ -602,49 +265,74 @@ class ObVecClient:
|
|
|
602
265
|
def ann_search(
|
|
603
266
|
self,
|
|
604
267
|
table_name: str,
|
|
605
|
-
vec_data: list,
|
|
268
|
+
vec_data: Union[list, dict],
|
|
606
269
|
vec_column_name: str,
|
|
607
270
|
distance_func,
|
|
608
271
|
with_dist: bool = False,
|
|
609
272
|
topk: int = 10,
|
|
610
273
|
output_column_names: Optional[List[str]] = None,
|
|
274
|
+
output_columns: Optional[Union[List, tuple]] = None,
|
|
611
275
|
extra_output_cols: Optional[List] = None,
|
|
612
276
|
where_clause=None,
|
|
613
277
|
partition_names: Optional[List[str]] = None,
|
|
614
278
|
idx_name_hint: Optional[List[str]] = None,
|
|
279
|
+
distance_threshold: Optional[float] = None,
|
|
615
280
|
**kwargs,
|
|
616
|
-
):
|
|
617
|
-
"""
|
|
281
|
+
): # pylint: disable=unused-argument
|
|
282
|
+
"""Perform ann search.
|
|
618
283
|
|
|
619
284
|
Args:
|
|
620
|
-
table_name (string)
|
|
621
|
-
vec_data (list)
|
|
622
|
-
vec_column_name (string)
|
|
623
|
-
distance_func
|
|
624
|
-
with_dist (bool)
|
|
625
|
-
topk (int)
|
|
626
|
-
output_column_names (Optional[List[str]])
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
285
|
+
table_name (string): table name
|
|
286
|
+
vec_data (Union[list, dict]): the vector/sparse_vector data to search
|
|
287
|
+
vec_column_name (string): which vector field to search
|
|
288
|
+
distance_func: function to calculate distance between vectors
|
|
289
|
+
with_dist (bool): return result with distance
|
|
290
|
+
topk (int): top K
|
|
291
|
+
output_column_names (Optional[List[str]]): output fields
|
|
292
|
+
output_columns (Optional[Union[List, tuple]]): output columns as SQLAlchemy Column objects
|
|
293
|
+
or expressions. Similar to SQLAlchemy's select() function arguments.
|
|
294
|
+
If provided, takes precedence over output_column_names.
|
|
295
|
+
extra_output_cols (Optional[List]): additional output columns
|
|
296
|
+
where_clause: do ann search with filter
|
|
297
|
+
partition_names (Optional[List[str]]): limit the query to certain partitions
|
|
298
|
+
idx_name_hint (Optional[List[str]]): post-filtering enabled if vector index name is specified
|
|
299
|
+
Or pre-filtering enabled
|
|
300
|
+
distance_threshold (Optional[float]): filter results where distance <= threshold.
|
|
301
|
+
**kwargs: additional arguments
|
|
630
302
|
"""
|
|
303
|
+
if not (isinstance(vec_data, list) or isinstance(vec_data, dict)):
|
|
304
|
+
raise ValueError("'vec_data' type must be in 'list'/'dict'")
|
|
305
|
+
|
|
631
306
|
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
632
307
|
|
|
633
|
-
|
|
308
|
+
columns = []
|
|
309
|
+
if output_columns:
|
|
310
|
+
if isinstance(output_columns, (list, tuple)):
|
|
311
|
+
columns = list(output_columns)
|
|
312
|
+
else:
|
|
313
|
+
columns = [output_columns]
|
|
314
|
+
elif output_column_names:
|
|
634
315
|
columns = [table.c[column_name] for column_name in output_column_names]
|
|
635
316
|
else:
|
|
636
317
|
columns = [table.c[column.name] for column in table.columns]
|
|
637
318
|
|
|
638
|
-
if extra_output_cols
|
|
319
|
+
if extra_output_cols:
|
|
639
320
|
columns.extend(extra_output_cols)
|
|
640
321
|
|
|
641
322
|
if with_dist:
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
323
|
+
if isinstance(vec_data, list):
|
|
324
|
+
columns.append(
|
|
325
|
+
distance_func(
|
|
326
|
+
table.c[vec_column_name],
|
|
327
|
+
"[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]",
|
|
328
|
+
)
|
|
329
|
+
)
|
|
330
|
+
else:
|
|
331
|
+
columns.append(
|
|
332
|
+
distance_func(
|
|
333
|
+
table.c[vec_column_name], f"{vec_data}"
|
|
334
|
+
)
|
|
646
335
|
)
|
|
647
|
-
)
|
|
648
336
|
# if idx_name_hint is not None:
|
|
649
337
|
# stmt = select(*columns).with_hint(
|
|
650
338
|
# table,
|
|
@@ -657,12 +345,32 @@ class ObVecClient:
|
|
|
657
345
|
if where_clause is not None:
|
|
658
346
|
stmt = stmt.where(*where_clause)
|
|
659
347
|
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
348
|
+
# Add distance threshold filter in SQL WHERE clause
|
|
349
|
+
if distance_threshold is not None:
|
|
350
|
+
if isinstance(vec_data, list):
|
|
351
|
+
dist_expr = distance_func(
|
|
352
|
+
table.c[vec_column_name],
|
|
353
|
+
"[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]",
|
|
354
|
+
)
|
|
355
|
+
else:
|
|
356
|
+
dist_expr = distance_func(
|
|
357
|
+
table.c[vec_column_name], f"{vec_data}"
|
|
358
|
+
)
|
|
359
|
+
stmt = stmt.where(dist_expr <= distance_threshold)
|
|
360
|
+
|
|
361
|
+
if isinstance(vec_data, list):
|
|
362
|
+
stmt = stmt.order_by(
|
|
363
|
+
distance_func(
|
|
364
|
+
table.c[vec_column_name],
|
|
365
|
+
"[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]",
|
|
366
|
+
)
|
|
367
|
+
)
|
|
368
|
+
else:
|
|
369
|
+
stmt = stmt.order_by(
|
|
370
|
+
distance_func(
|
|
371
|
+
table.c[vec_column_name], f"{vec_data}"
|
|
372
|
+
)
|
|
664
373
|
)
|
|
665
|
-
)
|
|
666
374
|
stmt_str = (
|
|
667
375
|
str(stmt.compile(
|
|
668
376
|
dialect=self.engine.dialect,
|
|
@@ -697,18 +405,22 @@ class ObVecClient:
|
|
|
697
405
|
partition_names: Optional[List[str]] = None,
|
|
698
406
|
str_list: Optional[List[str]] = None,
|
|
699
407
|
**kwargs,
|
|
700
|
-
):
|
|
701
|
-
"""
|
|
408
|
+
): # pylint: disable=unused-argument
|
|
409
|
+
"""Perform post ann search.
|
|
702
410
|
|
|
703
411
|
Args:
|
|
704
|
-
table_name (string)
|
|
705
|
-
vec_data (list)
|
|
706
|
-
vec_column_name (string)
|
|
707
|
-
distance_func
|
|
708
|
-
with_dist (bool)
|
|
709
|
-
topk (int)
|
|
710
|
-
output_column_names (Optional[List[str]])
|
|
711
|
-
|
|
412
|
+
table_name (string): table name
|
|
413
|
+
vec_data (list): the vector data to search
|
|
414
|
+
vec_column_name (string): which vector field to search
|
|
415
|
+
distance_func: function to calculate distance between vectors
|
|
416
|
+
with_dist (bool): return result with distance
|
|
417
|
+
topk (int): top K
|
|
418
|
+
output_column_names (Optional[List[str]]): output fields
|
|
419
|
+
extra_output_cols (Optional[List]): additional output columns
|
|
420
|
+
where_clause: do ann search with filter
|
|
421
|
+
partition_names (Optional[List[str]]): limit the query to certain partitions
|
|
422
|
+
str_list (Optional[List[str]]): list to append SQL string to
|
|
423
|
+
**kwargs: additional arguments
|
|
712
424
|
"""
|
|
713
425
|
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
714
426
|
|
|
@@ -770,17 +482,18 @@ class ObVecClient:
|
|
|
770
482
|
output_column_names: Optional[List[str]] = None,
|
|
771
483
|
where_clause=None,
|
|
772
484
|
**kwargs,
|
|
773
|
-
):
|
|
774
|
-
"""
|
|
485
|
+
): # pylint: disable=unused-argument
|
|
486
|
+
"""Perform precise vector search.
|
|
775
487
|
|
|
776
488
|
Args:
|
|
777
|
-
table_name (string)
|
|
778
|
-
vec_data (list)
|
|
779
|
-
vec_column_name (string)
|
|
780
|
-
distance_func
|
|
781
|
-
topk (int)
|
|
782
|
-
output_column_names (Optional[List[str]])
|
|
783
|
-
where_clause
|
|
489
|
+
table_name (string): table name
|
|
490
|
+
vec_data (list): the vector data to search
|
|
491
|
+
vec_column_name (string): which vector field to search
|
|
492
|
+
distance_func: function to calculate distance between vectors
|
|
493
|
+
topk (int): top K
|
|
494
|
+
output_column_names (Optional[List[str]]): output column names
|
|
495
|
+
where_clause: do ann search with filter
|
|
496
|
+
**kwargs: additional arguments
|
|
784
497
|
"""
|
|
785
498
|
table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
|
|
786
499
|
|
|
@@ -807,56 +520,3 @@ class ObVecClient:
|
|
|
807
520
|
with self.engine.connect() as conn:
|
|
808
521
|
with conn.begin():
|
|
809
522
|
return conn.execute(stmt)
|
|
810
|
-
|
|
811
|
-
def perform_raw_text_sql(
|
|
812
|
-
self,
|
|
813
|
-
text_sql: str,
|
|
814
|
-
):
|
|
815
|
-
"""Execute raw text SQL."""
|
|
816
|
-
with self.engine.connect() as conn:
|
|
817
|
-
with conn.begin():
|
|
818
|
-
return conn.execute(text(text_sql))
|
|
819
|
-
|
|
820
|
-
def add_columns(
|
|
821
|
-
self,
|
|
822
|
-
table_name: str,
|
|
823
|
-
columns: list[Column],
|
|
824
|
-
):
|
|
825
|
-
"""Add multiple columns to an existing table.
|
|
826
|
-
|
|
827
|
-
Args:
|
|
828
|
-
table_name (string): table name
|
|
829
|
-
columns (list[Column]): list of SQLAlchemy Column objects representing the new columns
|
|
830
|
-
"""
|
|
831
|
-
compiler = self.engine.dialect.ddl_compiler(self.engine.dialect, None)
|
|
832
|
-
column_specs = [compiler.get_column_specification(column) for column in columns]
|
|
833
|
-
columns_ddl = ", ".join(f"ADD COLUMN {spec}" for spec in column_specs)
|
|
834
|
-
|
|
835
|
-
with self.engine.connect() as conn:
|
|
836
|
-
with conn.begin():
|
|
837
|
-
conn.execute(
|
|
838
|
-
text(f"ALTER TABLE `{table_name}` {columns_ddl}")
|
|
839
|
-
)
|
|
840
|
-
|
|
841
|
-
self.refresh_metadata([table_name])
|
|
842
|
-
|
|
843
|
-
def drop_columns(
|
|
844
|
-
self,
|
|
845
|
-
table_name: str,
|
|
846
|
-
column_names: list[str],
|
|
847
|
-
):
|
|
848
|
-
"""Drop multiple columns from an existing table.
|
|
849
|
-
|
|
850
|
-
Args:
|
|
851
|
-
table_name (string): table name
|
|
852
|
-
column_names (list[str]): names of the columns to drop
|
|
853
|
-
"""
|
|
854
|
-
columns_ddl = ", ".join(f"DROP COLUMN `{name}`" for name in column_names)
|
|
855
|
-
|
|
856
|
-
with self.engine.connect() as conn:
|
|
857
|
-
with conn.begin():
|
|
858
|
-
conn.execute(
|
|
859
|
-
text(f"ALTER TABLE `{table_name}` {columns_ddl}")
|
|
860
|
-
)
|
|
861
|
-
|
|
862
|
-
self.refresh_metadata([table_name])
|