pyobvector 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their respective public registries. It is provided for informational purposes only.
@@ -1,52 +1,35 @@
  """OceanBase Vector Store Client."""
-
  import logging
- from typing import List, Optional, Dict, Union
+ from typing import List, Optional, Union
+
+ import numpy as np
  from sqlalchemy import (
- create_engine,
- MetaData,
  Table,
  Column,
  Index,
  select,
- delete,
- update,
- insert,
  text,
- inspect,
- and_,
  )
- from sqlalchemy.exc import NoSuchTableError
- from sqlalchemy.dialects import registry
- import sqlalchemy.sql.functions as func_mod
- import numpy as np
- from urllib.parse import quote
- from .index_param import IndexParams, IndexParam
+ from sqlalchemy.schema import CreateTable
+
+ from .exceptions import ClusterVersionException, ErrorCode, ExceptionsMessage
  from .fts_index_param import FtsIndexParam
+ from .index_param import IndexParams, IndexParam
+ from .ob_client import ObClient
+ from .partitions import ObPartition
  from ..schema import (
  ObTable,
  VectorIndex,
- l2_distance,
- cosine_distance,
- inner_product,
- negative_inner_product,
- ST_GeomFromText,
- st_distance,
- st_dwithin,
- st_astext,
- ReplaceStmt,
  FtsIndex,
  )
  from ..util import ObVersion
- from .partitions import *
- from .exceptions import *

  logger = logging.getLogger(__name__)
  logger.setLevel(logging.DEBUG)


- class ObVecClient:
- """The OceanBase Client"""
+ class ObVecClient(ObClient):
+ """The OceanBase Vector Client"""

  def __init__(
  self,
@@ -56,119 +39,23 @@ class ObVecClient:
  db_name: str = "test",
  **kwargs,
  ):
- registry.register("mysql.oceanbase", "pyobvector.schema.dialect", "OceanBaseDialect")
-
- # ischema_names["VECTOR"] = VECTOR
- setattr(func_mod, "l2_distance", l2_distance)
- setattr(func_mod, "cosine_distance", cosine_distance)
- setattr(func_mod, "inner_product", inner_product)
- setattr(func_mod, "negative_inner_product", negative_inner_product)
- setattr(func_mod, "ST_GeomFromText", ST_GeomFromText)
- setattr(func_mod, "st_distance", st_distance)
- setattr(func_mod, "st_dwithin", st_dwithin)
- setattr(func_mod, "st_astext", st_astext)
-
- user = quote(user, safe="")
- password = quote(password, safe="")
-
- connection_str = (
- f"mysql+oceanbase://{user}:{password}@{uri}/{db_name}?charset=utf8mb4"
- )
- self.engine = create_engine(connection_str, **kwargs)
- self.metadata_obj = MetaData()
- self.metadata_obj.reflect(bind=self.engine)
-
- with self.engine.connect() as conn:
- with conn.begin():
- res = conn.execute(text("SELECT OB_VERSION() FROM DUAL"))
- version = [r[0] for r in res][0]
- ob_version = ObVersion.from_db_version_string(version)
- if ob_version < ObVersion.from_db_version_nums(4, 3, 3, 0):
- raise ClusterVersionException(
- code=ErrorCode.NOT_SUPPORTED,
- message=ExceptionsMessage.ClusterVersionIsLow,
- )
-
- def refresh_metadata(self, tables: Optional[list[str]] = None):
- """Reload metadata from the database.
-
- Args:
- tables (Optional[list[str]]): names of the tables to refresh. If None, refresh all tables.
- """
- if tables is not None:
- for table_name in tables:
- if table_name in self.metadata_obj.tables:
- self.metadata_obj.remove(Table(table_name, self.metadata_obj))
- self.metadata_obj.reflect(bind=self.engine, only=tables, extend_existing=True)
- else:
- self.metadata_obj.clear()
- self.metadata_obj.reflect(bind=self.engine, extend_existing=True)
-
- def _insert_partition_hint_for_query_sql(self, sql: str, partition_hint: str):
- from_index = sql.find("FROM")
- assert from_index != -1
- first_space_after_from = sql.find(" ", from_index + len("FROM") + 1)
- if first_space_after_from == -1:
- return sql + " " + partition_hint
- return (
- sql[:first_space_after_from]
- + " "
- + partition_hint
- + sql[first_space_after_from:]
- )
-
- def check_table_exists(self, table_name: str):
- """check if table exists.
+ super().__init__(uri, user, password, db_name, **kwargs)

- Args:
- table_name (string) : table name
- """
- inspector = inspect(self.engine)
- return inspector.has_table(table_name)
+ if self.ob_version < ObVersion.from_db_version_nums(4, 3, 3, 0):
+ raise ClusterVersionException(
+ code=ErrorCode.NOT_SUPPORTED,
+ message=ExceptionsMessage.ClusterVersionIsLow,
+ )

- def create_table(
- self,
- table_name: str,
- columns: List[Column],
- indexes: Optional[List[Index]] = None,
- partitions: Optional[ObPartition] = None,
+ def _get_sparse_vector_index_params(
+ self, vidxs: Optional[IndexParams]
  ):
- """Create a table.
-
- Args:
- table_name (string) : table name
- columns (List[Column]) : column schema
- indexes (Optional[List[Index]]) : optional index schema
- partitions (Optional[ObPartition]) : optional partition strategy
- """
- with self.engine.connect() as conn:
- with conn.begin():
- if indexes is not None:
- table = ObTable(
- table_name,
- self.metadata_obj,
- *columns,
- *indexes,
- extend_existing=True,
- )
- else:
- table = ObTable(
- table_name,
- self.metadata_obj,
- *columns,
- extend_existing=True,
- )
- table.create(self.engine, checkfirst=True)
- # do partition
- if partitions is not None:
- conn.execute(
- text(f"ALTER TABLE `{table_name}` {partitions.do_compile()}")
- )
-
- @classmethod
- def prepare_index_params(cls):
- """Create `IndexParams` to hold index configuration."""
- return IndexParams()
+ if vidxs is None:
+ return None
+ return [
+ vidx for vidx in vidxs
+ if vidx.is_index_type_sparse_vector()
+ ]

  def create_table_with_index_params(
  self,
@@ -182,12 +69,14 @@ class ObVecClient:
  """Create table with optional index_params.

  Args:
- table_name (string) : table name
- columns (List[Column]) : column schema
- indexes (Optional[List[Index]]) : optional common index schema
- vids (Optional[IndexParams]) : optional vector index schema
- partitions (Optional[ObPartition]) : optional partition strategy
+ table_name (string): table name
+ columns (List[Column]): column schema
+ indexes (Optional[List[Index]]): optional common index schema
+ vidxs (Optional[IndexParams]): optional vector index schema
+ fts_idxs (Optional[List[FtsIndexParam]]): optional full-text search index schema
+ partitions (Optional[ObPartition]): optional partition strategy
  """
+ sparse_vidxs = self._get_sparse_vector_index_params(vidxs)
  with self.engine.connect() as conn:
  with conn.begin():
  # create table with common index
@@ -206,7 +95,15 @@ class ObVecClient:
  *columns,
  extend_existing=True,
  )
- table.create(self.engine, checkfirst=True)
+ if sparse_vidxs is not None and len(sparse_vidxs) > 0:
+ create_table_sql = str(CreateTable(table).compile(self.engine))
+ new_sql = create_table_sql[:create_table_sql.rfind(')')]
+ for sparse_vidx in sparse_vidxs:
+ new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)"
+ new_sql += "\n)"
+ conn.execute(text(new_sql))
+ else:
+ table.create(self.engine, checkfirst=True)
  # do partition
  if partitions is not None:
  conn.execute(
@@ -215,6 +112,8 @@ class ObVecClient:
  # create vector indexes
  if vidxs is not None:
  for vidx in vidxs:
+ if vidx.is_index_type_sparse_vector():
+ continue
  vidx = VectorIndex(
  vidx.index_name,
  table.c[vidx.field_name],
@@ -244,12 +143,12 @@ class ObVecClient:
  """Create common index or vector index.

  Args:
- table_name (string) : table name
- is_vec_index (bool) : common index or vector index
- index_name (string) : index name
- column_names (List[string]) : create index on which columns
- vidx_params (Optional[str]) :
- vector index params, for example 'distance=l2, type=hnsw, lib=vsag'
+ table_name (string): table name
+ is_vec_index (bool): common index or vector index
+ index_name (string): index name
+ column_names (List[string]): create index on which columns
+ vidx_params (Optional[str]): vector index params, for example 'distance=l2, type=hnsw, lib=vsag'
+ **kw: additional keyword arguments
  """
  table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
  columns = [table.c[column_name] for column_name in column_names]
@@ -270,8 +169,8 @@ class ObVecClient:
  """Create vector index with vector index parameter.

  Args:
- table_name (string) : table name
- vidx_param (IndexParam) : vector index parameter
+ table_name (string): table name
+ vidx_param (IndexParam): vector index parameter
  """
  table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
  with self.engine.connect() as conn:
@@ -291,8 +190,8 @@ class ObVecClient:
  """Create fts index with fts index parameter.

  Args:
- table_name (string) : table name
- fts_idx_param (FtsIndexParam) : fts index parameter
+ table_name (string): table name
+ fts_idx_param (FtsIndexParam): fts index parameter
  """
  table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
  with self.engine.connect() as conn:
@@ -305,26 +204,6 @@ class ObVecClient:
  )
  fts_idx.create(self.engine, checkfirst=True)

- def drop_table_if_exist(self, table_name: str):
- """Drop table if exists."""
- try:
- table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
- except NoSuchTableError:
- return
- with self.engine.connect() as conn:
- with conn.begin():
- table.drop(self.engine, checkfirst=True)
- self.metadata_obj.remove(table)
-
- def drop_index(self, table_name: str, index_name: str):
- """drop index on specified table.
-
- If the index not exists, SQL ERROR 1091 will raise.
- """
- with self.engine.connect() as conn:
- with conn.begin():
- conn.execute(text(f"DROP INDEX `{index_name}` ON `{table_name}`"))
-
  def refresh_index(
  self,
  table_name: str,
@@ -334,10 +213,9 @@ class ObVecClient:
  """Refresh vector index for performance.

  Args:
- :param table_name (string) : table name
- :param index_name (string) : vector index name
- :param trigger_threshold (int) :
- If delta_buffer_table row count is greater than `trigger_threshold`,
+ table_name (string): table name
+ index_name (string): vector index name
+ trigger_threshold (int): If delta_buffer_table row count is greater than `trigger_threshold`,
  refreshing is actually triggered.
  """
  with self.engine.connect() as conn:
@@ -358,9 +236,9 @@ class ObVecClient:
  """Rebuild vector index for performance.

  Args:
- :param table_name (string) : table name
- :param index_name (string) : vector index name
- :param trigger_threshold (float)
+ table_name (string): table name
+ index_name (string): vector index name
+ trigger_threshold (float): threshold value for rebuilding index
  """
  with self.engine.connect() as conn:
  with conn.begin():
@@ -371,221 +249,6 @@ class ObVecClient:
  )
  )

- def insert(
- self,
- table_name: str,
- data: Union[Dict, List[Dict]],
- partition_name: Optional[str] = "",
- ):
- """Insert data into table.
-
- Args:
- table_name (string) : table name
- data (Union[Dict, List[Dict]]) : data that will be inserted
- partition_names (Optional[str]) : limit the query to certain partition
- """
- if isinstance(data, Dict):
- data = [data]
-
- if len(data) == 0:
- return
-
- table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
-
- with self.engine.connect() as conn:
- with conn.begin():
- if partition_name is None or partition_name == "":
- conn.execute(insert(table).values(data))
- else:
- conn.execute(
- insert(table)
- .with_hint(f"PARTITION({partition_name})")
- .values(data)
- )
-
- def upsert(
- self,
- table_name: str,
- data: Union[Dict, List[Dict]],
- partition_name: Optional[str] = "",
- ):
- """Update data in table. If primary key is duplicated, replace it.
-
- Args:
- table_name (string) : table name
- data (Union[Dict, List[Dict]]) : data that will be upserted
- partition_names (Optional[str]) : limit the query to certain partition
- """
- if isinstance(data, Dict):
- data = [data]
-
- if len(data) == 0:
- return
-
- table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
-
- with self.engine.connect() as conn:
- with conn.begin():
- upsert_stmt = (
- ReplaceStmt(table).with_hint(f"PARTITION({partition_name})")
- if partition_name is not None and partition_name != ""
- else ReplaceStmt(table)
- )
- upsert_stmt = upsert_stmt.values(data)
- conn.execute(upsert_stmt)
-
- def update(
- self,
- table_name: str,
- values_clause,
- where_clause=None,
- partition_name: Optional[str] = "",
- ):
- """Update data in table.
-
- Args:
- table_name (string) : table name
- values_clause: update values clause
- where_clause: update with filter
- partition_name (Optional[str]) : limit the query to certain partition
-
- Example:
- .. code-block:: python
-
- data = [
- {"id": 112, "embedding": [1, 2, 3], "meta": {'doc':'hhh1'}},
- {"id": 190, "embedding": [0.13, 0.123, 1.213], "meta": {'doc':'hhh2'}},
- ]
- client.insert(collection_name=test_collection_name, data=data)
- client.update(
- table_name=test_collection_name,
- values_clause=[{'meta':{'doc':'HHH'}}],
- where_clause=[text("id=112")]
- )
- """
- table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
-
- with self.engine.connect() as conn:
- with conn.begin():
- update_stmt = (
- update(table).with_hint(f"PARTITION({partition_name})")
- if partition_name is not None and partition_name != ""
- else update(table)
- )
- if where_clause is not None:
- update_stmt = update_stmt.where(*where_clause).values(
- *values_clause
- )
- else:
- update_stmt = update_stmt.values(*values_clause)
- conn.execute(update_stmt)
-
- def delete(
- self,
- table_name: str,
- ids: Optional[Union[list, str, int]] = None,
- where_clause=None,
- partition_name: Optional[str] = "",
- ):
- """Delete data in table.
-
- Args:
- table_name (string) : table name
- where_clause : delete with filter
- partition_names (Optional[str]) : limit the query to certain partition
- """
- table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
- where_in_clause = None
- if ids is not None:
- primary_keys = table.primary_key
- pkey_names = [column.name for column in primary_keys]
- if len(pkey_names) == 1:
- if isinstance(ids, list):
- where_in_clause = table.c[pkey_names[0]].in_(ids)
- elif isinstance(ids, (str, int)):
- where_in_clause = table.c[pkey_names[0]].in_([ids])
- else:
- raise TypeError("'ids' is not a list/str/int")
-
- with self.engine.connect() as conn:
- with conn.begin():
- delete_stmt = (
- delete(table).with_hint(f"PARTITION({partition_name})")
- if partition_name is not None and partition_name != ""
- else delete(table)
- )
- if where_in_clause is None and where_clause is None:
- conn.execute(delete_stmt)
- elif where_in_clause is not None and where_clause is None:
- conn.execute(delete_stmt.where(where_in_clause))
- elif where_in_clause is None and where_clause is not None:
- conn.execute(delete_stmt.where(*where_clause))
- else:
- conn.execute(
- delete_stmt.where(and_(where_in_clause, *where_clause))
- )
-
- def get(
- self,
- table_name: str,
- ids: Optional[Union[list, str, int]],
- where_clause = None,
- output_column_name: Optional[List[str]] = None,
- partition_names: Optional[List[str]] = None,
- n_limits: Optional[int] = None,
- ):
- """get records with specified primary field `ids`.
-
- Args:
- :param table_name (string) : table name
- :param ids : specified primary field values
- :param where_clause : SQL filter
- :param output_column_name (Optional[List[str]]) : output fields name
- :param partition_names (List[str]) : limit the query to certain partitions
- """
- table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
- if output_column_name is not None:
- columns = [table.c[column_name] for column_name in output_column_name]
- stmt = select(*columns)
- else:
- stmt = select(table)
- primary_keys = table.primary_key
- pkey_names = [column.name for column in primary_keys]
- where_in_clause = None
- if ids is not None and len(pkey_names) == 1:
- if isinstance(ids, list):
- where_in_clause = table.c[pkey_names[0]].in_(ids)
- elif isinstance(ids, (str, int)):
- where_in_clause = table.c[pkey_names[0]].in_([ids])
- else:
- raise TypeError("'ids' is not a list/str/int")
-
- if where_in_clause is not None and where_clause is None:
- stmt = stmt.where(where_in_clause)
- elif where_in_clause is None and where_clause is not None:
- stmt = stmt.where(*where_clause)
- elif where_in_clause is not None and where_clause is not None:
- stmt = stmt.where(and_(where_in_clause, *where_clause))
-
- if n_limits is not None:
- stmt = stmt.limit(n_limits)
-
- with self.engine.connect() as conn:
- with conn.begin():
- if partition_names is None:
- execute_res = conn.execute(stmt)
- else:
- stmt_str = str(stmt.compile(
- dialect=self.engine.dialect,
- compile_kwargs={"literal_binds": True}
- ))
- stmt_str = self._insert_partition_hint_for_query_sql(
- stmt_str, f"PARTITION({', '.join(partition_names)})"
- )
- logging.debug(stmt_str)
- execute_res = conn.execute(text(stmt_str))
- return execute_res
-
  def set_ob_hnsw_ef_search(self, ob_hnsw_ef_search: int):
  """Set ob_hnsw_ef_search system variable."""
  with self.engine.connect() as conn:
@@ -602,49 +265,74 @@ class ObVecClient:
  def ann_search(
  self,
  table_name: str,
- vec_data: list,
+ vec_data: Union[list, dict],
  vec_column_name: str,
  distance_func,
  with_dist: bool = False,
  topk: int = 10,
  output_column_names: Optional[List[str]] = None,
+ output_columns: Optional[Union[List, tuple]] = None,
  extra_output_cols: Optional[List] = None,
  where_clause=None,
  partition_names: Optional[List[str]] = None,
  idx_name_hint: Optional[List[str]] = None,
+ distance_threshold: Optional[float] = None,
  **kwargs,
- ): # pylint: disable=unused-argument
- """perform ann search.
+ ): # pylint: disable=unused-argument
+ """Perform ann search.

  Args:
- table_name (string) : table name
- vec_data (list) : the vector data to search
- vec_column_name (string) : which vector field to search
- distance_func : function to calculate distance between vectors
- with_dist (bool) : return result with distance
- topk (int) : top K
- output_column_names (Optional[List[str]]) : output fields
- where_clause : do ann search with filter
- idx_name_hint : post-filtering enabled if vector index name is specified
- Or pre-filtering enabled
+ table_name (string): table name
+ vec_data (Union[list, dict]): the vector/sparse_vector data to search
+ vec_column_name (string): which vector field to search
+ distance_func: function to calculate distance between vectors
+ with_dist (bool): return result with distance
+ topk (int): top K
+ output_column_names (Optional[List[str]]): output fields
+ output_columns (Optional[Union[List, tuple]]): output columns as SQLAlchemy Column objects
+ or expressions. Similar to SQLAlchemy's select() function arguments.
+ If provided, takes precedence over output_column_names.
+ extra_output_cols (Optional[List]): additional output columns
+ where_clause: do ann search with filter
+ partition_names (Optional[List[str]]): limit the query to certain partitions
+ idx_name_hint (Optional[List[str]]): post-filtering enabled if vector index name is specified
+ Or pre-filtering enabled
+ distance_threshold (Optional[float]): filter results where distance <= threshold.
+ **kwargs: additional arguments
  """
+ if not (isinstance(vec_data, list) or isinstance(vec_data, dict)):
+ raise ValueError("'vec_data' type must be in 'list'/'dict'")
+
  table = Table(table_name, self.metadata_obj, autoload_with=self.engine)

- if output_column_names is not None:
+ columns = []
+ if output_columns:
+ if isinstance(output_columns, (list, tuple)):
+ columns = list(output_columns)
+ else:
+ columns = [output_columns]
+ elif output_column_names:
  columns = [table.c[column_name] for column_name in output_column_names]
  else:
  columns = [table.c[column.name] for column in table.columns]

- if extra_output_cols is not None:
+ if extra_output_cols:
  columns.extend(extra_output_cols)

  if with_dist:
- columns.append(
- distance_func(
- table.c[vec_column_name],
- "[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]",
+ if isinstance(vec_data, list):
+ columns.append(
+ distance_func(
+ table.c[vec_column_name],
+ "[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]",
+ )
+ )
+ else:
+ columns.append(
+ distance_func(
+ table.c[vec_column_name], f"{vec_data}"
+ )
  )
- )
  # if idx_name_hint is not None:
  # stmt = select(*columns).with_hint(
  # table,
@@ -657,12 +345,32 @@ class ObVecClient:
  if where_clause is not None:
  stmt = stmt.where(*where_clause)

- stmt = stmt.order_by(
- distance_func(
- table.c[vec_column_name],
- "[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]",
+ # Add distance threshold filter in SQL WHERE clause
+ if distance_threshold is not None:
+ if isinstance(vec_data, list):
+ dist_expr = distance_func(
+ table.c[vec_column_name],
+ "[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]",
+ )
+ else:
+ dist_expr = distance_func(
+ table.c[vec_column_name], f"{vec_data}"
+ )
+ stmt = stmt.where(dist_expr <= distance_threshold)
+
+ if isinstance(vec_data, list):
+ stmt = stmt.order_by(
+ distance_func(
+ table.c[vec_column_name],
+ "[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]",
+ )
+ )
+ else:
+ stmt = stmt.order_by(
+ distance_func(
+ table.c[vec_column_name], f"{vec_data}"
+ )
  )
- )
  stmt_str = (
  str(stmt.compile(
  dialect=self.engine.dialect,
@@ -697,18 +405,22 @@ class ObVecClient:
  partition_names: Optional[List[str]] = None,
  str_list: Optional[List[str]] = None,
  **kwargs,
- ): # pylint: disable=unused-argument
- """perform post ann search.
+ ): # pylint: disable=unused-argument
+ """Perform post ann search.

  Args:
- table_name (string) : table name
- vec_data (list) : the vector data to search
- vec_column_name (string) : which vector field to search
- distance_func : function to calculate distance between vectors
- with_dist (bool) : return result with distance
- topk (int) : top K
- output_column_names (Optional[List[str]]) : output fields
- where_clause : do ann search with filter
+ table_name (string): table name
+ vec_data (list): the vector data to search
+ vec_column_name (string): which vector field to search
+ distance_func: function to calculate distance between vectors
+ with_dist (bool): return result with distance
+ topk (int): top K
+ output_column_names (Optional[List[str]]): output fields
+ extra_output_cols (Optional[List]): additional output columns
+ where_clause: do ann search with filter
+ partition_names (Optional[List[str]]): limit the query to certain partitions
+ str_list (Optional[List[str]]): list to append SQL string to
+ **kwargs: additional arguments
  """
  table = Table(table_name, self.metadata_obj, autoload_with=self.engine)

@@ -770,17 +482,18 @@ class ObVecClient:
  output_column_names: Optional[List[str]] = None,
  where_clause=None,
  **kwargs,
- ): # pylint: disable=unused-argument
- """perform precise vector search.
+ ): # pylint: disable=unused-argument
+ """Perform precise vector search.

  Args:
- table_name (string) : table name
- vec_data (list) : the vector data to search
- vec_column_name (string) : which vector field to search
- distance_func : function to calculate distance between vectors
- topk (int) : top K
- output_column_names (Optional[List[str]]) : output column names
- where_clause : do ann search with filter
+ table_name (string): table name
+ vec_data (list): the vector data to search
+ vec_column_name (string): which vector field to search
+ distance_func: function to calculate distance between vectors
+ topk (int): top K
+ output_column_names (Optional[List[str]]): output column names
+ where_clause: do ann search with filter
+ **kwargs: additional arguments
  """
  table = Table(table_name, self.metadata_obj, autoload_with=self.engine)

@@ -807,56 +520,3 @@ class ObVecClient:
  with self.engine.connect() as conn:
  with conn.begin():
  return conn.execute(stmt)
-
- def perform_raw_text_sql(
- self,
- text_sql: str,
- ):
- """Execute raw text SQL."""
- with self.engine.connect() as conn:
- with conn.begin():
- return conn.execute(text(text_sql))
-
- def add_columns(
- self,
- table_name: str,
- columns: list[Column],
- ):
- """Add multiple columns to an existing table.
-
- Args:
- table_name (string): table name
- columns (list[Column]): list of SQLAlchemy Column objects representing the new columns
- """
- compiler = self.engine.dialect.ddl_compiler(self.engine.dialect, None)
- column_specs = [compiler.get_column_specification(column) for column in columns]
- columns_ddl = ", ".join(f"ADD COLUMN {spec}" for spec in column_specs)
-
- with self.engine.connect() as conn:
- with conn.begin():
- conn.execute(
- text(f"ALTER TABLE `{table_name}` {columns_ddl}")
- )
-
- self.refresh_metadata([table_name])
-
- def drop_columns(
- self,
- table_name: str,
- column_names: list[str],
- ):
- """Drop multiple columns from an existing table.
-
- Args:
- table_name (string): table name
- column_names (list[str]): names of the columns to drop
- """
- columns_ddl = ", ".join(f"DROP COLUMN `{name}`" for name in column_names)
-
- with self.engine.connect() as conn:
- with conn.begin():
- conn.execute(
- text(f"ALTER TABLE `{table_name}` {columns_ddl}")
- )
-
- self.refresh_metadata([table_name])
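
For orientation, below is a minimal usage sketch of the 0.2.17 surface visible in this diff. The connection details, table and column names are illustrative, `l2_distance` is assumed to still be exported by the package, and the dict form of `vec_data` (sparse vector) plus the version check in `__init__` assume an OceanBase cluster of at least 4.3.3.

# Hypothetical usage of the 0.2.17 client shown in this diff; names are illustrative.
from pyobvector import ObVecClient, l2_distance  # l2_distance assumed still exported

client = ObVecClient(uri="127.0.0.1:2881", user="root@test", password="", db_name="test")

# ann_search now also accepts a dict (sparse vector) for vec_data and an optional
# distance_threshold that is pushed into the SQL WHERE clause (distance <= threshold).
res = client.ann_search(
    table_name="items",                  # hypothetical table
    vec_data=[0.1, 0.2, 0.3],            # or {"1": 0.4, "7": 0.9} for a sparse-vector column
    vec_column_name="embedding",
    distance_func=l2_distance,
    with_dist=True,
    topk=5,
    distance_threshold=0.8,              # keep only rows whose distance is <= 0.8
)
for row in res.fetchall():
    print(row)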