pyobvector 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyobvector/__init__.py +6 -5
- pyobvector/client/__init__.py +5 -4
- pyobvector/client/collection_schema.py +5 -1
- pyobvector/client/enum.py +1 -1
- pyobvector/client/exceptions.py +9 -7
- pyobvector/client/fts_index_param.py +8 -4
- pyobvector/client/hybrid_search.py +14 -4
- pyobvector/client/index_param.py +56 -41
- pyobvector/client/milvus_like_client.py +71 -54
- pyobvector/client/ob_client.py +20 -16
- pyobvector/client/ob_vec_client.py +45 -41
- pyobvector/client/ob_vec_json_table_client.py +366 -274
- pyobvector/client/partitions.py +81 -39
- pyobvector/client/schema_type.py +3 -1
- pyobvector/json_table/__init__.py +4 -3
- pyobvector/json_table/json_value_returning_func.py +12 -10
- pyobvector/json_table/oceanbase_dialect.py +15 -8
- pyobvector/json_table/virtual_data_type.py +47 -28
- pyobvector/schema/__init__.py +7 -1
- pyobvector/schema/array.py +6 -2
- pyobvector/schema/dialect.py +4 -0
- pyobvector/schema/full_text_index.py +8 -3
- pyobvector/schema/geo_srid_point.py +5 -2
- pyobvector/schema/gis_func.py +23 -11
- pyobvector/schema/match_against_func.py +10 -5
- pyobvector/schema/ob_table.py +2 -0
- pyobvector/schema/reflection.py +25 -8
- pyobvector/schema/replace_stmt.py +4 -0
- pyobvector/schema/sparse_vector.py +7 -4
- pyobvector/schema/vec_dist_func.py +22 -9
- pyobvector/schema/vector.py +3 -1
- pyobvector/schema/vector_index.py +7 -3
- pyobvector/util/__init__.py +1 -0
- pyobvector/util/ob_version.py +2 -0
- pyobvector/util/sparse_vector.py +9 -6
- pyobvector/util/vector.py +2 -0
- {pyobvector-0.2.21.dist-info → pyobvector-0.2.23.dist-info}/METADATA +13 -14
- pyobvector-0.2.23.dist-info/RECORD +40 -0
- {pyobvector-0.2.21.dist-info → pyobvector-0.2.23.dist-info}/licenses/LICENSE +1 -1
- pyobvector-0.2.21.dist-info/RECORD +0 -40
- {pyobvector-0.2.21.dist-info → pyobvector-0.2.23.dist-info}/WHEEL +0 -0
pyobvector/client/milvus_like_client.py CHANGED
@@ -1,4 +1,5 @@
 """Milvus Like Client."""
+
 import logging
 import json
 from typing import Optional, Union
@@ -62,11 +63,11 @@ class MilvusLikeClient(Client):
         index_params: Optional[IndexParams] = None, # Used for custom setup
         max_length: int = 16384,
         **kwargs,
-    ):
-        """Create a collection.
+    ): # pylint: disable=unused-argument
+        """Create a collection.
         If `schema` is not None, `dimension`, `primary_field_name`, `id_type`, `vector_field_name`,
         `metric_type`, `auto_id` will be ignored.
-
+
         Args:
             collection_name (string): collection name
             dimension (Optional[int]): vector data dimension
@@ -146,10 +147,12 @@ class MilvusLikeClient(Client):
         )

     def get_collection_stats(
-        self,
+        self,
+        collection_name: str,
+        timeout: Optional[float] = None, # pylint: disable=unused-argument
     ) -> dict:
         """Get collection row count.
-
+
         Args:
             collection_name (string): collection name
             timeout (Optional[float]): not used in OceanBase
@@ -166,8 +169,10 @@ class MilvusLikeClient(Client):
         return {"row_count": cnt}

     def has_collection(
-        self,
-
+        self,
+        collection_name: str,
+        timeout: Optional[float] = None, # pylint: disable=unused-argument
+    ) -> bool: # pylint: disable=unused-argument
         """Check if collection exists.

         Args:
@@ -181,17 +186,20 @@ class MilvusLikeClient(Client):

     def drop_collection(self, collection_name: str) -> None:
         """Drop collection if exists.
-
+
         Args:
             collection_name (string): collection name
         """
         self.drop_table_if_exist(collection_name)

     def rename_collection(
-        self,
+        self,
+        old_name: str,
+        new_name: str,
+        timeout: Optional[float] = None, # pylint: disable=unused-argument
     ) -> None:
         """Rename collection.
-
+
         Args:
             old_name (string): old collection name
             new_name (string): new collection name
@@ -206,7 +214,7 @@ class MilvusLikeClient(Client):
         collection_name: str,
     ):
         """Load table into SQLAlchemy metadata.
-
+
         Args:
             collection_name (string): which collection to load

@@ -230,9 +238,9 @@ class MilvusLikeClient(Client):
         index_params: IndexParams,
         timeout: Optional[float] = None,
         **kwargs,
-    ):
+    ): # pylint: disable=unused-argument
         """Create vector index with index params.
-
+
         Args:
             collection_name (string): which collection to create vector index
             index_params (IndexParams): the vector index parameters
@@ -263,9 +271,9 @@ class MilvusLikeClient(Client):
         index_name: str,
         timeout: Optional[float] = None,
         **kwargs,
-    ):
+    ): # pylint: disable=unused-argument
         """Drop index on specified collection.
-
+
         If the index not exists, SQL ERROR 1091 will raise.

         Args:
@@ -283,7 +291,7 @@ class MilvusLikeClient(Client):
         trigger_threshold: int = 10000,
     ):
         """Refresh vector index for performance.
-
+
         Args:
             collection_name (string): collection name
             index_name (string): vector index name
@@ -303,7 +311,7 @@ class MilvusLikeClient(Client):
         trigger_threshold: float = 0.2,
     ):
         """Rebuild vector index for performance.
-
+
         Args:
             collection_name (string): collection name
             index_name (string): vector index name
@@ -356,13 +364,13 @@ class MilvusLikeClient(Client):
         limit: int = 10,
         output_fields: Optional[list[str]] = None,
         search_params: Optional[dict] = None,
-        timeout: Optional[float] = None,
+        timeout: Optional[float] = None, # pylint: disable=unused-argument
         partition_names: Optional[list[str]] = None,
-        **kwargs,
+        **kwargs, # pylint: disable=unused-argument
     ) -> list[dict]:
         """Perform ann search.
         Note: OceanBase does not support batch search now. `data` & the return value is not a batch.
-
+
         Args:
             collection_name (string): collection name
             data (list): the vector/sparse_vector data to search
@@ -392,9 +400,7 @@ class MilvusLikeClient(Client):
                 message=ExceptionsMessage.MetricTypeParamTypeInvalid,
             )
         lower_metric_type_str = search_params["metric_type"].lower()
-        if lower_metric_type_str not in (
-            "l2", "neg_ip", "cosine", "ip"
-        ):
+        if lower_metric_type_str not in ("l2", "neg_ip", "cosine", "ip"):
             raise VectorMetricTypeException(
                 code=ErrorCode.INVALID_ARGUMENT,
                 message=ExceptionsMessage.MetricTypeValueInvalid,
@@ -416,25 +422,34 @@ class MilvusLikeClient(Client):

         if with_dist:
             if isinstance(data, list):
-                columns.append(
-
+                columns.append(
+                    distance_func(
+                        table.c[anns_field],
+                        "[" + ",".join([str(np.float32(v)) for v in data]) + "]",
+                    )
+                )
             else:
                 columns.append(distance_func(table.c[anns_field], f"{data}"))
         stmt = select(*columns)

         if flter is not None:
             stmt = stmt.where(*flter)
-
+
         if isinstance(data, list):
-            stmt = stmt.order_by(
-
+            stmt = stmt.order_by(
+                distance_func(
+                    table.c[anns_field],
+                    "[" + ",".join([str(np.float32(v)) for v in data]) + "]",
+                )
+            )
         else:
             stmt = stmt.order_by(distance_func(table.c[anns_field], f"{data}"))
         stmt_str = (
-            str(
-
-
-
+            str(
+                stmt.compile(
+                    dialect=self.engine.dialect, compile_kwargs={"literal_binds": True}
+                )
+            )
             + f" APPROXIMATE limit {limit}"
         )

@@ -468,12 +483,12 @@ class MilvusLikeClient(Client):
         collection_name: str,
         flter=None,
         output_fields: Optional[list[str]] = None,
-        timeout: Optional[float] = None,
+        timeout: Optional[float] = None, # pylint: disable=unused-argument
         partition_names: Optional[list[str]] = None,
-        **kwargs,
+        **kwargs, # pylint: disable=unused-argument
     ) -> list[dict]:
         """Query records.
-
+
         Args:
             collection_name (string): collection name
             flter: do ann search with filter (note: parameter name is intentionally 'flter' to distinguish it from the built-in function)
@@ -508,10 +523,12 @@ class MilvusLikeClient(Client):
             if partition_names is None:
                 execute_res = conn.execute(stmt)
             else:
-                stmt_str = str(
-
-
-
+                stmt_str = str(
+                    stmt.compile(
+                        dialect=self.engine.dialect,
+                        compile_kwargs={"literal_binds": True},
+                    )
+                )
                 stmt_str = self._insert_partition_hint_for_query_sql(
                     stmt_str, f"PARTITION({', '.join(partition_names)})"
                 )
@@ -534,12 +551,12 @@ class MilvusLikeClient(Client):
         collection_name: str,
         ids: Union[list, str, int] = None,
         output_fields: Optional[list[str]] = None,
-        timeout: Optional[float] = None,
+        timeout: Optional[float] = None, # pylint: disable=unused-argument
         partition_names: Optional[list[str]] = None,
-        **kwargs,
+        **kwargs, # pylint: disable=unused-argument
     ) -> list[dict]:
         """Get records with specified primary field `ids`.
-
+
         Args:
             collection_name (string): collection name
             ids (Union[list, str, int]): specified primary field values
@@ -586,10 +603,12 @@ class MilvusLikeClient(Client):
             if partition_names is None:
                 execute_res = conn.execute(stmt)
             else:
-                stmt_str = str(
-
-
-
+                stmt_str = str(
+                    stmt.compile(
+                        dialect=self.engine.dialect,
+                        compile_kwargs={"literal_binds": True},
+                    )
+                )
                 stmt_str = self._insert_partition_hint_for_query_sql(
                     stmt_str, f"PARTITION({', '.join(partition_names)})"
                 )
@@ -611,10 +630,10 @@ class MilvusLikeClient(Client):
         self,
         collection_name: str,
         ids: Optional[Union[list, str, int]] = None,
-        timeout: Optional[float] = None,
+        timeout: Optional[float] = None, # pylint: disable=unused-argument
        flter=None,
         partition_name: Optional[str] = "",
-        **kwargs,
+        **kwargs, # pylint: disable=unused-argument
     ) -> dict:
         """Delete data in collection.

@@ -675,11 +694,9 @@ class MilvusLikeClient(Client):
         data: Union[dict, list[dict]],
         timeout: Optional[float] = None,
         partition_name: Optional[str] = "",
-    ) ->
-        None
-    ): # pylint: disable=unused-argument
+    ) -> None: # pylint: disable=unused-argument
         """Insert data into collection.
-
+
         Args:
             collection_name (string): collection name
             data (Union[Dict, List[Dict]]): data that will be inserted
@@ -701,11 +718,11 @@ class MilvusLikeClient(Client):
         self,
         collection_name: str,
         data: Union[dict, list[dict]],
-        timeout: Optional[float] = None,
+        timeout: Optional[float] = None, # pylint: disable=unused-argument
         partition_name: Optional[str] = "",
     ) -> list[Union[str, int]]:
         """Update data in table. If primary key is duplicated, replace it.
-
+
         Args:
             collection_name (string): collection name
             data (Union[Dict, List[Dict]]): data that will be upserted
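
The milvus_like_client.py changes are mostly formatting and pylint cleanups, but they make the signatures of `has_collection`, `get_collection_stats`, and `search` explicit. A minimal usage sketch based only on the parameter names visible in the hunks above; the constructor arguments, collection name, and field names are illustrative assumptions:

```python
from pyobvector import MilvusLikeClient

# Connection arguments below are assumptions for illustration, not part of the diff.
client = MilvusLikeClient(uri="127.0.0.1:2881", user="root@test")

if client.has_collection(collection_name="items"):
    print(client.get_collection_stats(collection_name="items"))  # {"row_count": ...}

# `data` is a single query vector (no batch search), and metric_type must be one of
# "l2", "neg_ip", "cosine", "ip" per the validation shown above.
hits = client.search(
    collection_name="items",
    data=[0.1, 0.2, 0.3],
    anns_field="embedding",
    limit=5,
    output_fields=["id"],
    search_params={"metric_type": "l2"},
)
```
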
pyobvector/client/ob_client.py CHANGED
@@ -51,7 +51,9 @@ class ObClient:
         db_name: str = "test",
         **kwargs,
     ):
-        registry.register(
+        registry.register(
+            "mysql.oceanbase", "pyobvector.schema.dialect", "OceanBaseDialect"
+        )

         setattr(func_mod, "l2_distance", l2_distance)
         setattr(func_mod, "cosine_distance", cosine_distance)
@@ -88,27 +90,31 @@ class ObClient:
             for table_name in tables:
                 if table_name in self.metadata_obj.tables:
                     self.metadata_obj.remove(Table(table_name, self.metadata_obj))
-            self.metadata_obj.reflect(
+            self.metadata_obj.reflect(
+                bind=self.engine, only=tables, extend_existing=True
+            )
         else:
             self.metadata_obj.clear()
             self.metadata_obj.reflect(bind=self.engine, extend_existing=True)

     def _is_seekdb(self) -> bool:
         """Check if the database is SeekDB by querying version.
-
+
         Returns:
             bool: True if database is SeekDB, False otherwise
         """
         is_seekdb = False
         try:
-            if hasattr(self,
+            if hasattr(self, "_is_seekdb_cached"):
                 return self._is_seekdb_cached
             with self.engine.connect() as conn:
                 result = conn.execute(text("SELECT VERSION()"))
                 version_str = [r[0] for r in result][0]
-                is_seekdb = "
+                is_seekdb = "seekdb" in version_str.lower()
                 self._is_seekdb_cached = is_seekdb
-                logger.debug(
+                logger.debug(
+                    f"Version query result: {version_str}, is_seekdb: {is_seekdb}"
+                )
         except Exception as e:
             logger.warning(f"Failed to query version: {e}")
         return is_seekdb
@@ -414,10 +420,12 @@ class ObClient:
             if partition_names is None:
                 execute_res = conn.execute(stmt)
             else:
-                stmt_str = str(
-
-
-
+                stmt_str = str(
+                    stmt.compile(
+                        dialect=self.engine.dialect,
+                        compile_kwargs={"literal_binds": True},
+                    )
+                )
                 stmt_str = self._insert_partition_hint_for_query_sql(
                     stmt_str, f"PARTITION({', '.join(partition_names)})"
                 )
@@ -451,9 +459,7 @@ class ObClient:

         with self.engine.connect() as conn:
             with conn.begin():
-                conn.execute(
-                    text(f"ALTER TABLE `{table_name}` {columns_ddl}")
-                )
+                conn.execute(text(f"ALTER TABLE `{table_name}` {columns_ddl}"))

         self.refresh_metadata([table_name])

@@ -472,8 +478,6 @@ class ObClient:

         with self.engine.connect() as conn:
             with conn.begin():
-                conn.execute(
-                    text(f"ALTER TABLE `{table_name}` {columns_ddl}")
-                )
+                conn.execute(text(f"ALTER TABLE `{table_name}` {columns_ddl}"))

         self.refresh_metadata([table_name])
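
The first ob_client.py hunk only reflows the `registry.register(...)` call, but that call is what wires pyobvector's custom dialect into SQLAlchemy. A hedged sketch of what the registration enables; the connection string is an illustrative assumption, not part of the diff:

```python
from sqlalchemy import create_engine
from sqlalchemy.dialects import registry

# Maps the "mysql+oceanbase" dialect name to the class shipped with pyobvector.
registry.register("mysql.oceanbase", "pyobvector.schema.dialect", "OceanBaseDialect")

# Example DSN (assumed): host, port, and credentials depend on the OceanBase deployment.
engine = create_engine("mysql+oceanbase://root:@127.0.0.1:2881/test")
```
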
pyobvector/client/ob_vec_client.py CHANGED
@@ -1,4 +1,5 @@
 """OceanBase Vector Store Client."""
+
 import logging
 from typing import Optional, Union

@@ -44,18 +45,14 @@ class ObVecClient(ObClient):
         if self.ob_version < ObVersion.from_db_version_nums(4, 3, 3, 0):
             raise ClusterVersionException(
                 code=ErrorCode.NOT_SUPPORTED,
-                message=ExceptionsMessage.ClusterVersionIsLow
+                message=ExceptionsMessage.ClusterVersionIsLow
+                % ("Vector Store", "4.3.3.0"),
             )

-    def _get_sparse_vector_index_params(
-        self, vidxs: Optional[IndexParams]
-    ):
+    def _get_sparse_vector_index_params(self, vidxs: Optional[IndexParams]):
         if vidxs is None:
             return None
-        return [
-            vidx for vidx in vidxs
-            if vidx.is_index_type_sparse_vector()
-        ]
+        return [vidx for vidx in vidxs if vidx.is_index_type_sparse_vector()]

     def create_table_with_index_params(
         self,
@@ -65,6 +62,7 @@ class ObVecClient(ObClient):
         vidxs: Optional[IndexParams] = None,
         fts_idxs: Optional[list[FtsIndexParam]] = None,
         partitions: Optional[ObPartition] = None,
+        **kwargs,
     ):
         """Create table with optional index_params.

@@ -75,8 +73,10 @@ class ObVecClient(ObClient):
             vidxs (Optional[IndexParams]): optional vector index schema
             fts_idxs (Optional[List[FtsIndexParam]]): optional full-text search index schema
             partitions (Optional[ObPartition]): optional partition strategy
+            **kwargs: additional keyword arguments (e.g., mysql_organization='heap')
         """
         sparse_vidxs = self._get_sparse_vector_index_params(vidxs)
+        kwargs.setdefault("extend_existing", True)
         with self.engine.connect() as conn:
             with conn.begin():
                 # create table with common index
@@ -86,21 +86,21 @@ class ObVecClient(ObClient):
                         self.metadata_obj,
                         *columns,
                         *indexes,
-
+                        **kwargs,
                     )
                 else:
                     table = ObTable(
                         table_name,
                         self.metadata_obj,
                         *columns,
-
+                        **kwargs,
                     )
                 if sparse_vidxs is not None and len(sparse_vidxs) > 0:
                     create_table_sql = str(CreateTable(table).compile(self.engine))
-                    new_sql = create_table_sql[:create_table_sql.rfind(
+                    new_sql = create_table_sql[: create_table_sql.rfind(")")]
                     for sparse_vidx in sparse_vidxs:
                         sparse_params = sparse_vidx._parse_kwargs()
-                        if
+                        if "type" in sparse_params:
                             new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (type={sparse_params['type']}, distance=inner_product)"
                         else:
                             new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)"
@@ -127,7 +127,9 @@ class ObVecClient(ObClient):
                 # create fts indexes
                 if fts_idxs is not None:
                     for fts_idx in fts_idxs:
-                        idx_cols = [
+                        idx_cols = [
+                            table.c[field_name] for field_name in fts_idx.field_names
+                        ]
                         fts_idx = FtsIndex(
                             fts_idx.index_name,
                             fts_idx.param_str(),
@@ -192,7 +194,7 @@ class ObVecClient(ObClient):
         fts_idx_param: FtsIndexParam,
     ):
         """Create fts index with fts index parameter.
-
+
         Args:
             table_name (string): table name
             fts_idx_param (FtsIndexParam): fts index parameter
@@ -200,7 +202,9 @@ class ObVecClient(ObClient):
         table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
         with self.engine.connect() as conn:
             with conn.begin():
-                idx_cols = [
+                idx_cols = [
+                    table.c[field_name] for field_name in fts_idx_param.field_names
+                ]
                 fts_idx = FtsIndex(
                     fts_idx_param.index_name,
                     fts_idx_param.param_str(),
@@ -332,11 +336,7 @@ class ObVecClient(ObClient):
                 )
             )
         else:
-            columns.append(
-                distance_func(
-                    table.c[vec_column_name], f"{vec_data}"
-                )
-            )
+            columns.append(distance_func(table.c[vec_column_name], f"{vec_data}"))
         # if idx_name_hint is not None:
         #     stmt = select(*columns).with_hint(
         #         table,
@@ -357,9 +357,7 @@ class ObVecClient(ObClient):
                     "[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]",
                 )
             else:
-                dist_expr = distance_func(
-                    table.c[vec_column_name], f"{vec_data}"
-                )
+                dist_expr = distance_func(table.c[vec_column_name], f"{vec_data}")
             stmt = stmt.where(dist_expr <= distance_threshold)

         if isinstance(vec_data, list):
@@ -370,23 +368,23 @@ class ObVecClient(ObClient):
                 )
             )
         else:
-            stmt = stmt.order_by(
-
-
+            stmt = stmt.order_by(distance_func(table.c[vec_column_name], f"{vec_data}"))
+        stmt_str = (
+            str(
+                stmt.compile(
+                    dialect=self.engine.dialect, compile_kwargs={"literal_binds": True}
                 )
             )
-        stmt_str = (
-            str(stmt.compile(
-                dialect=self.engine.dialect,
-                compile_kwargs={"literal_binds": True}
-            ))
             + f" APPROXIMATE limit {topk}"
         )
         with self.engine.connect() as conn:
             with conn.begin():
                 if idx_name_hint is not None:
                     idx = stmt_str.find("SELECT ")
-                    stmt_str =
+                    stmt_str = (
+                        f"SELECT /*+ index({table_name} {idx_name_hint}) */ "
+                        + stmt_str[idx + len("SELECT ") :]
+                    )

                 if partition_names is None:
                     return conn.execute(text(stmt_str))
@@ -430,7 +428,9 @@ class ObVecClient(ObClient):

         columns = []
         if output_column_names is not None:
-            columns.extend(
+            columns.extend(
+                [table.c[column_name] for column_name in output_column_names]
+            )
         else:
             columns.extend([table.c[column.name] for column in table.columns])
         if extra_output_cols is not None:
@@ -459,16 +459,20 @@ class ObVecClient(ObClient):
             if partition_names is None:
                 if str_list is not None:
                     str_list.append(
-                        str(
-
-
-
+                        str(
+                            stmt.compile(
+                                dialect=self.engine.dialect,
+                                compile_kwargs={"literal_binds": True},
+                            )
+                        )
                     )
                 return conn.execute(stmt)
-            stmt_str = str(
-
-
-
+            stmt_str = str(
+                stmt.compile(
+                    dialect=self.engine.dialect,
+                    compile_kwargs={"literal_binds": True},
+                )
+            )
             stmt_str = self._insert_partition_hint_for_query_sql(
                 stmt_str, f"PARTITION({', '.join(partition_names)})"
             )