pyobvector 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. pyobvector/__init__.py +6 -5
  2. pyobvector/client/__init__.py +5 -4
  3. pyobvector/client/collection_schema.py +5 -1
  4. pyobvector/client/enum.py +1 -1
  5. pyobvector/client/exceptions.py +9 -7
  6. pyobvector/client/fts_index_param.py +8 -4
  7. pyobvector/client/hybrid_search.py +14 -4
  8. pyobvector/client/index_param.py +56 -41
  9. pyobvector/client/milvus_like_client.py +71 -54
  10. pyobvector/client/ob_client.py +20 -16
  11. pyobvector/client/ob_vec_client.py +45 -41
  12. pyobvector/client/ob_vec_json_table_client.py +366 -274
  13. pyobvector/client/partitions.py +81 -39
  14. pyobvector/client/schema_type.py +3 -1
  15. pyobvector/json_table/__init__.py +4 -3
  16. pyobvector/json_table/json_value_returning_func.py +12 -10
  17. pyobvector/json_table/oceanbase_dialect.py +15 -8
  18. pyobvector/json_table/virtual_data_type.py +47 -28
  19. pyobvector/schema/__init__.py +7 -1
  20. pyobvector/schema/array.py +6 -2
  21. pyobvector/schema/dialect.py +4 -0
  22. pyobvector/schema/full_text_index.py +8 -3
  23. pyobvector/schema/geo_srid_point.py +5 -2
  24. pyobvector/schema/gis_func.py +23 -11
  25. pyobvector/schema/match_against_func.py +10 -5
  26. pyobvector/schema/ob_table.py +2 -0
  27. pyobvector/schema/reflection.py +25 -8
  28. pyobvector/schema/replace_stmt.py +4 -0
  29. pyobvector/schema/sparse_vector.py +7 -4
  30. pyobvector/schema/vec_dist_func.py +22 -9
  31. pyobvector/schema/vector.py +3 -1
  32. pyobvector/schema/vector_index.py +7 -3
  33. pyobvector/util/__init__.py +1 -0
  34. pyobvector/util/ob_version.py +2 -0
  35. pyobvector/util/sparse_vector.py +9 -6
  36. pyobvector/util/vector.py +2 -0
  37. {pyobvector-0.2.21.dist-info → pyobvector-0.2.23.dist-info}/METADATA +13 -14
  38. pyobvector-0.2.23.dist-info/RECORD +40 -0
  39. {pyobvector-0.2.21.dist-info → pyobvector-0.2.23.dist-info}/licenses/LICENSE +1 -1
  40. pyobvector-0.2.21.dist-info/RECORD +0 -40
  41. {pyobvector-0.2.21.dist-info → pyobvector-0.2.23.dist-info}/WHEEL +0 -0
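Most of the hunks below touch the Milvus-compatible surface in milvus_like_client.py (create_collection, search, insert/upsert/delete) and the SQLAlchemy-based clients in ob_client.py and ob_vec_client.py. For orientation, here is a minimal usage sketch of that surface; the import path, connection defaults, and field names are assumptions drawn from the signatures and docstrings visible in this diff, not verified against 0.2.23.

# Hypothetical sketch; constructor defaults and field names are assumed.
from pyobvector import MilvusLikeClient  # top-level export assumed

client = MilvusLikeClient(uri="127.0.0.1:2881", user="root@test", db_name="test")

client.create_collection(collection_name="docs", dimension=3)
client.insert(
    collection_name="docs",
    data=[{"id": 1, "embedding": [0.1, 0.2, 0.3]}],
)

# search() lower-cases metric_type and only accepts "l2", "neg_ip", "cosine"
# or "ip" (see the check in the search hunks below); the compiled SQL gets an
# "APPROXIMATE limit <limit>" suffix.
hits = client.search(
    collection_name="docs",
    data=[0.1, 0.2, 0.3],
    anns_field="embedding",  # parameter name assumed from table.c[anns_field] below
    limit=5,
    output_fields=["id"],
    search_params={"metric_type": "l2"},
)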
pyobvector/client/milvus_like_client.py
@@ -1,4 +1,5 @@
 """Milvus Like Client."""
+
 import logging
 import json
 from typing import Optional, Union
@@ -62,11 +63,11 @@ class MilvusLikeClient(Client):
         index_params: Optional[IndexParams] = None,  # Used for custom setup
         max_length: int = 16384,
         **kwargs,
-    ):  # pylint: disable=unused-argument
-        """Create a collection.
+    ):  # pylint: disable=unused-argument
+        """Create a collection.
         If `schema` is not None, `dimension`, `primary_field_name`, `id_type`, `vector_field_name`,
         `metric_type`, `auto_id` will be ignored.
-
+
         Args:
             collection_name (string): collection name
             dimension (Optional[int]): vector data dimension
@@ -146,10 +147,12 @@ class MilvusLikeClient(Client):
         )
 
     def get_collection_stats(
-        self, collection_name: str, timeout: Optional[float] = None  # pylint: disable=unused-argument
+        self,
+        collection_name: str,
+        timeout: Optional[float] = None,  # pylint: disable=unused-argument
     ) -> dict:
         """Get collection row count.
-
+
         Args:
             collection_name (string): collection name
             timeout (Optional[float]): not used in OceanBase
@@ -166,8 +169,10 @@ class MilvusLikeClient(Client):
         return {"row_count": cnt}
 
     def has_collection(
-        self, collection_name: str, timeout: Optional[float] = None  # pylint: disable=unused-argument
-    ) -> bool:  # pylint: disable=unused-argument
+        self,
+        collection_name: str,
+        timeout: Optional[float] = None,  # pylint: disable=unused-argument
+    ) -> bool:  # pylint: disable=unused-argument
         """Check if collection exists.
 
         Args:
@@ -181,17 +186,20 @@ class MilvusLikeClient(Client):
 
     def drop_collection(self, collection_name: str) -> None:
         """Drop collection if exists.
-
+
         Args:
             collection_name (string): collection name
         """
         self.drop_table_if_exist(collection_name)
 
     def rename_collection(
-        self, old_name: str, new_name: str, timeout: Optional[float] = None  # pylint: disable=unused-argument
+        self,
+        old_name: str,
+        new_name: str,
+        timeout: Optional[float] = None,  # pylint: disable=unused-argument
     ) -> None:
         """Rename collection.
-
+
         Args:
             old_name (string): old collection name
             new_name (string): new collection name
@@ -206,7 +214,7 @@ class MilvusLikeClient(Client):
         collection_name: str,
     ):
         """Load table into SQLAlchemy metadata.
-
+
         Args:
             collection_name (string): which collection to load
 
@@ -230,9 +238,9 @@ class MilvusLikeClient(Client):
         index_params: IndexParams,
         timeout: Optional[float] = None,
         **kwargs,
-    ):  # pylint: disable=unused-argument
+    ):  # pylint: disable=unused-argument
         """Create vector index with index params.
-
+
         Args:
             collection_name (string): which collection to create vector index
             index_params (IndexParams): the vector index parameters
@@ -263,9 +271,9 @@ class MilvusLikeClient(Client):
         index_name: str,
         timeout: Optional[float] = None,
         **kwargs,
-    ):  # pylint: disable=unused-argument
+    ):  # pylint: disable=unused-argument
         """Drop index on specified collection.
-
+
         If the index not exists, SQL ERROR 1091 will raise.
 
         Args:
@@ -283,7 +291,7 @@ class MilvusLikeClient(Client):
         trigger_threshold: int = 10000,
     ):
         """Refresh vector index for performance.
-
+
         Args:
             collection_name (string): collection name
             index_name (string): vector index name
@@ -303,7 +311,7 @@ class MilvusLikeClient(Client):
         trigger_threshold: float = 0.2,
    ):
         """Rebuild vector index for performance.
-
+
         Args:
             collection_name (string): collection name
             index_name (string): vector index name
@@ -356,13 +364,13 @@ class MilvusLikeClient(Client):
         limit: int = 10,
         output_fields: Optional[list[str]] = None,
         search_params: Optional[dict] = None,
-        timeout: Optional[float] = None,  # pylint: disable=unused-argument
+        timeout: Optional[float] = None,  # pylint: disable=unused-argument
         partition_names: Optional[list[str]] = None,
-        **kwargs,  # pylint: disable=unused-argument
+        **kwargs,  # pylint: disable=unused-argument
     ) -> list[dict]:
         """Perform ann search.
         Note: OceanBase does not support batch search now. `data` & the return value is not a batch.
-
+
         Args:
             collection_name (string): collection name
             data (list): the vector/sparse_vector data to search
@@ -392,9 +400,7 @@ class MilvusLikeClient(Client):
                 message=ExceptionsMessage.MetricTypeParamTypeInvalid,
             )
         lower_metric_type_str = search_params["metric_type"].lower()
-        if lower_metric_type_str not in (
-            "l2", "neg_ip", "cosine", "ip"
-        ):
+        if lower_metric_type_str not in ("l2", "neg_ip", "cosine", "ip"):
             raise VectorMetricTypeException(
                 code=ErrorCode.INVALID_ARGUMENT,
                 message=ExceptionsMessage.MetricTypeValueInvalid,
@@ -416,25 +422,34 @@ class MilvusLikeClient(Client):
 
         if with_dist:
             if isinstance(data, list):
-                columns.append(distance_func(table.c[anns_field],
-                    "[" + ",".join([str(np.float32(v)) for v in data]) + "]"))
+                columns.append(
+                    distance_func(
+                        table.c[anns_field],
+                        "[" + ",".join([str(np.float32(v)) for v in data]) + "]",
+                    )
+                )
             else:
                 columns.append(distance_func(table.c[anns_field], f"{data}"))
         stmt = select(*columns)
 
         if flter is not None:
             stmt = stmt.where(*flter)
-
+
         if isinstance(data, list):
-            stmt = stmt.order_by(distance_func(table.c[anns_field],
-                "[" + ",".join([str(np.float32(v)) for v in data]) + "]"))
+            stmt = stmt.order_by(
+                distance_func(
+                    table.c[anns_field],
+                    "[" + ",".join([str(np.float32(v)) for v in data]) + "]",
+                )
+            )
         else:
             stmt = stmt.order_by(distance_func(table.c[anns_field], f"{data}"))
         stmt_str = (
-            str(stmt.compile(
-                dialect=self.engine.dialect,
-                compile_kwargs={"literal_binds": True}
-            ))
+            str(
+                stmt.compile(
+                    dialect=self.engine.dialect, compile_kwargs={"literal_binds": True}
+                )
+            )
             + f" APPROXIMATE limit {limit}"
         )
 
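Many of the hunks in this release only re-wrap the same stmt.compile(...) call. The pattern itself is standard SQLAlchemy: render the statement with literal_binds so the resulting string can be post-processed as raw SQL (here, appending the APPROXIMATE limit suffix or splicing in a PARTITION(...) clause). A standalone illustration, independent of pyobvector; the table and column names are placeholders:

import sqlalchemy as sa
from sqlalchemy.dialects import mysql

# Toy statement standing in for the SELECT built by the client.
docs = sa.table("docs", sa.column("id"), sa.column("embedding"))
stmt = sa.select(docs.c.id).where(docs.c.id > 10)

# literal_binds inlines bound parameters so the rendered SQL is executable as-is.
sql = str(stmt.compile(dialect=mysql.dialect(), compile_kwargs={"literal_binds": True}))
print(sql + " APPROXIMATE limit 5")
# roughly: SELECT docs.id FROM docs WHERE docs.id > 10 APPROXIMATE limit 5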
@@ -468,12 +483,12 @@ class MilvusLikeClient(Client):
         collection_name: str,
         flter=None,
         output_fields: Optional[list[str]] = None,
-        timeout: Optional[float] = None,  # pylint: disable=unused-argument
+        timeout: Optional[float] = None,  # pylint: disable=unused-argument
         partition_names: Optional[list[str]] = None,
-        **kwargs,  # pylint: disable=unused-argument
+        **kwargs,  # pylint: disable=unused-argument
     ) -> list[dict]:
         """Query records.
-
+
         Args:
             collection_name (string): collection name
             flter: do ann search with filter (note: parameter name is intentionally 'flter' to distinguish it from the built-in function)
@@ -508,10 +523,12 @@ class MilvusLikeClient(Client):
                 if partition_names is None:
                     execute_res = conn.execute(stmt)
                 else:
-                    stmt_str = str(stmt.compile(
-                        dialect=self.engine.dialect,
-                        compile_kwargs={"literal_binds": True}
-                    ))
+                    stmt_str = str(
+                        stmt.compile(
+                            dialect=self.engine.dialect,
+                            compile_kwargs={"literal_binds": True},
+                        )
+                    )
                     stmt_str = self._insert_partition_hint_for_query_sql(
                         stmt_str, f"PARTITION({', '.join(partition_names)})"
                     )
@@ -534,12 +551,12 @@ class MilvusLikeClient(Client):
         collection_name: str,
         ids: Union[list, str, int] = None,
         output_fields: Optional[list[str]] = None,
-        timeout: Optional[float] = None,  # pylint: disable=unused-argument
+        timeout: Optional[float] = None,  # pylint: disable=unused-argument
         partition_names: Optional[list[str]] = None,
-        **kwargs,  # pylint: disable=unused-argument
+        **kwargs,  # pylint: disable=unused-argument
     ) -> list[dict]:
         """Get records with specified primary field `ids`.
-
+
         Args:
             collection_name (string): collection name
             ids (Union[list, str, int]): specified primary field values
@@ -586,10 +603,12 @@ class MilvusLikeClient(Client):
                 if partition_names is None:
                     execute_res = conn.execute(stmt)
                 else:
-                    stmt_str = str(stmt.compile(
-                        dialect=self.engine.dialect,
-                        compile_kwargs={"literal_binds": True}
-                    ))
+                    stmt_str = str(
+                        stmt.compile(
+                            dialect=self.engine.dialect,
+                            compile_kwargs={"literal_binds": True},
+                        )
+                    )
                     stmt_str = self._insert_partition_hint_for_query_sql(
                         stmt_str, f"PARTITION({', '.join(partition_names)})"
                     )
@@ -611,10 +630,10 @@ class MilvusLikeClient(Client):
         self,
         collection_name: str,
         ids: Optional[Union[list, str, int]] = None,
-        timeout: Optional[float] = None,  # pylint: disable=unused-argument
+        timeout: Optional[float] = None,  # pylint: disable=unused-argument
         flter=None,
         partition_name: Optional[str] = "",
-        **kwargs,  # pylint: disable=unused-argument
+        **kwargs,  # pylint: disable=unused-argument
     ) -> dict:
         """Delete data in collection.
 
@@ -675,11 +694,9 @@ class MilvusLikeClient(Client):
         data: Union[dict, list[dict]],
         timeout: Optional[float] = None,
         partition_name: Optional[str] = "",
-    ) -> (
-        None
-    ):  # pylint: disable=unused-argument
+    ) -> None:  # pylint: disable=unused-argument
         """Insert data into collection.
-
+
         Args:
             collection_name (string): collection name
             data (Union[Dict, List[Dict]]): data that will be inserted
@@ -701,11 +718,11 @@ class MilvusLikeClient(Client):
         self,
         collection_name: str,
         data: Union[dict, list[dict]],
-        timeout: Optional[float] = None,  # pylint: disable=unused-argument
+        timeout: Optional[float] = None,  # pylint: disable=unused-argument
         partition_name: Optional[str] = "",
     ) -> list[Union[str, int]]:
         """Update data in table. If primary key is duplicated, replace it.
-
+
         Args:
             collection_name (string): collection name
             data (Union[Dict, List[Dict]]): data that will be upserted
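The insert/upsert/delete hunks above only reflow signatures. For completeness, a hedged sketch of that part of the API; the data layout and primary-key field are invented for illustration:

from pyobvector import MilvusLikeClient  # top-level export assumed

client = MilvusLikeClient(uri="127.0.0.1:2881", user="root@test")  # defaults assumed

# `data` may be a single dict or a list of dicts (see the docstrings above);
# upsert replaces rows whose primary key already exists.
client.upsert(collection_name="docs", data={"id": 1, "embedding": [0.2, 0.2, 0.3]})
client.delete(collection_name="docs", ids=[1])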
pyobvector/client/ob_client.py
@@ -51,7 +51,9 @@ class ObClient:
         db_name: str = "test",
         **kwargs,
     ):
-        registry.register("mysql.oceanbase", "pyobvector.schema.dialect", "OceanBaseDialect")
+        registry.register(
+            "mysql.oceanbase", "pyobvector.schema.dialect", "OceanBaseDialect"
+        )
 
         setattr(func_mod, "l2_distance", l2_distance)
         setattr(func_mod, "cosine_distance", cosine_distance)
@@ -88,27 +90,31 @@ class ObClient:
             for table_name in tables:
                 if table_name in self.metadata_obj.tables:
                     self.metadata_obj.remove(Table(table_name, self.metadata_obj))
-            self.metadata_obj.reflect(bind=self.engine, only=tables, extend_existing=True)
+            self.metadata_obj.reflect(
+                bind=self.engine, only=tables, extend_existing=True
+            )
         else:
             self.metadata_obj.clear()
             self.metadata_obj.reflect(bind=self.engine, extend_existing=True)
 
     def _is_seekdb(self) -> bool:
         """Check if the database is SeekDB by querying version.
-
+
         Returns:
             bool: True if database is SeekDB, False otherwise
         """
         is_seekdb = False
         try:
-            if hasattr(self, '_is_seekdb_cached'):
+            if hasattr(self, "_is_seekdb_cached"):
                 return self._is_seekdb_cached
             with self.engine.connect() as conn:
                 result = conn.execute(text("SELECT VERSION()"))
                 version_str = [r[0] for r in result][0]
-                is_seekdb = "SeekDB" in version_str
+                is_seekdb = "seekdb" in version_str.lower()
                 self._is_seekdb_cached = is_seekdb
-                logger.debug(f"Version query result: {version_str}, is_seekdb: {is_seekdb}")
+                logger.debug(
+                    f"Version query result: {version_str}, is_seekdb: {is_seekdb}"
+                )
         except Exception as e:
             logger.warning(f"Failed to query version: {e}")
         return is_seekdb
@@ -414,10 +420,12 @@ class ObClient:
                 if partition_names is None:
                     execute_res = conn.execute(stmt)
                 else:
-                    stmt_str = str(stmt.compile(
-                        dialect=self.engine.dialect,
-                        compile_kwargs={"literal_binds": True}
-                    ))
+                    stmt_str = str(
+                        stmt.compile(
+                            dialect=self.engine.dialect,
+                            compile_kwargs={"literal_binds": True},
+                        )
+                    )
                     stmt_str = self._insert_partition_hint_for_query_sql(
                         stmt_str, f"PARTITION({', '.join(partition_names)})"
                     )
@@ -451,9 +459,7 @@ class ObClient:
 
         with self.engine.connect() as conn:
             with conn.begin():
-                conn.execute(
-                    text(f"ALTER TABLE `{table_name}` {columns_ddl}")
-                )
+                conn.execute(text(f"ALTER TABLE `{table_name}` {columns_ddl}"))
 
         self.refresh_metadata([table_name])
 
@@ -472,8 +478,6 @@ class ObClient:
 
         with self.engine.connect() as conn:
            with conn.begin():
-                conn.execute(
-                    text(f"ALTER TABLE `{table_name}` {columns_ddl}")
-                )
+                conn.execute(text(f"ALTER TABLE `{table_name}` {columns_ddl}"))
 
         self.refresh_metadata([table_name])
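The refresh_metadata and ALTER TABLE hunks above are formatting-only, but the reflect-with-extend_existing idiom they wrap is worth isolating. A standalone SQLAlchemy sketch; the DSN and table name are placeholders:

import sqlalchemy as sa

engine = sa.create_engine("mysql+pymysql://user:pass@127.0.0.1:2881/test")  # placeholder DSN
metadata = sa.MetaData()

# Reflect only the named tables; extend_existing=True lets a Table object that
# is already registered in this MetaData be refreshed in place.
metadata.reflect(bind=engine, only=["docs"], extend_existing=True)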
pyobvector/client/ob_vec_client.py
@@ -1,4 +1,5 @@
 """OceanBase Vector Store Client."""
+
 import logging
 from typing import Optional, Union
 
@@ -44,18 +45,14 @@ class ObVecClient(ObClient):
         if self.ob_version < ObVersion.from_db_version_nums(4, 3, 3, 0):
             raise ClusterVersionException(
                 code=ErrorCode.NOT_SUPPORTED,
-                message=ExceptionsMessage.ClusterVersionIsLow % ("Vector Store", "4.3.3.0"),
+                message=ExceptionsMessage.ClusterVersionIsLow
+                % ("Vector Store", "4.3.3.0"),
             )
 
-    def _get_sparse_vector_index_params(
-        self, vidxs: Optional[IndexParams]
-    ):
+    def _get_sparse_vector_index_params(self, vidxs: Optional[IndexParams]):
         if vidxs is None:
             return None
-        return [
-            vidx for vidx in vidxs
-            if vidx.is_index_type_sparse_vector()
-        ]
+        return [vidx for vidx in vidxs if vidx.is_index_type_sparse_vector()]
 
     def create_table_with_index_params(
         self,
@@ -65,6 +62,7 @@ class ObVecClient(ObClient):
         vidxs: Optional[IndexParams] = None,
         fts_idxs: Optional[list[FtsIndexParam]] = None,
         partitions: Optional[ObPartition] = None,
+        **kwargs,
     ):
         """Create table with optional index_params.
 
@@ -75,8 +73,10 @@ class ObVecClient(ObClient):
             vidxs (Optional[IndexParams]): optional vector index schema
             fts_idxs (Optional[List[FtsIndexParam]]): optional full-text search index schema
             partitions (Optional[ObPartition]): optional partition strategy
+            **kwargs: additional keyword arguments (e.g., mysql_organization='heap')
         """
         sparse_vidxs = self._get_sparse_vector_index_params(vidxs)
+        kwargs.setdefault("extend_existing", True)
         with self.engine.connect() as conn:
             with conn.begin():
                 # create table with common index
@@ -86,21 +86,21 @@ class ObVecClient(ObClient):
                         self.metadata_obj,
                         *columns,
                         *indexes,
-                        extend_existing=True,
+                        **kwargs,
                     )
                 else:
                     table = ObTable(
                         table_name,
                         self.metadata_obj,
                         *columns,
-                        extend_existing=True,
+                        **kwargs,
                     )
                 if sparse_vidxs is not None and len(sparse_vidxs) > 0:
                     create_table_sql = str(CreateTable(table).compile(self.engine))
-                    new_sql = create_table_sql[:create_table_sql.rfind(')')]
+                    new_sql = create_table_sql[: create_table_sql.rfind(")")]
                     for sparse_vidx in sparse_vidxs:
                         sparse_params = sparse_vidx._parse_kwargs()
-                        if 'type' in sparse_params:
+                        if "type" in sparse_params:
                             new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (type={sparse_params['type']}, distance=inner_product)"
                         else:
                             new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)"
@@ -127,7 +127,9 @@ class ObVecClient(ObClient):
                 # create fts indexes
                 if fts_idxs is not None:
                     for fts_idx in fts_idxs:
-                        idx_cols = [table.c[field_name] for field_name in fts_idx.field_names]
+                        idx_cols = [
+                            table.c[field_name] for field_name in fts_idx.field_names
+                        ]
                         fts_idx = FtsIndex(
                             fts_idx.index_name,
                             fts_idx.param_str(),
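The new **kwargs on create_table_with_index_params (with extend_existing=True applied via setdefault) forwards table-level options to the underlying ObTable(...) constructor; the docstring above cites mysql_organization='heap' as the motivating example. A hedged sketch of such a call; the column list, parameter names, and the VECTOR import are assumptions:

from sqlalchemy import Column, Integer

from pyobvector import ObVecClient, VECTOR  # top-level exports assumed

client = ObVecClient(uri="127.0.0.1:2881", user="root@test")  # defaults assumed
client.create_table_with_index_params(
    table_name="docs",
    columns=[
        Column("id", Integer, primary_key=True),
        Column("embedding", VECTOR(3)),
    ],
    mysql_organization="heap",  # forwarded to ObTable() via **kwargs
)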
@@ -192,7 +194,7 @@ class ObVecClient(ObClient):
         fts_idx_param: FtsIndexParam,
     ):
         """Create fts index with fts index parameter.
-
+
         Args:
             table_name (string): table name
             fts_idx_param (FtsIndexParam): fts index parameter
@@ -200,7 +202,9 @@ class ObVecClient(ObClient):
         table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
         with self.engine.connect() as conn:
             with conn.begin():
-                idx_cols = [table.c[field_name] for field_name in fts_idx_param.field_names]
+                idx_cols = [
+                    table.c[field_name] for field_name in fts_idx_param.field_names
+                ]
                 fts_idx = FtsIndex(
                     fts_idx_param.index_name,
                     fts_idx_param.param_str(),
@@ -332,11 +336,7 @@ class ObVecClient(ObClient):
                     )
                 )
             else:
-                columns.append(
-                    distance_func(
-                        table.c[vec_column_name], f"{vec_data}"
-                    )
-                )
+                columns.append(distance_func(table.c[vec_column_name], f"{vec_data}"))
         # if idx_name_hint is not None:
         #     stmt = select(*columns).with_hint(
         #         table,
@@ -357,9 +357,7 @@ class ObVecClient(ObClient):
                     "[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]",
                 )
             else:
-                dist_expr = distance_func(
-                    table.c[vec_column_name], f"{vec_data}"
-                )
+                dist_expr = distance_func(table.c[vec_column_name], f"{vec_data}")
             stmt = stmt.where(dist_expr <= distance_threshold)
 
         if isinstance(vec_data, list):
@@ -370,23 +368,23 @@ class ObVecClient(ObClient):
                 )
             )
         else:
-            stmt = stmt.order_by(
-                distance_func(
-                    table.c[vec_column_name], f"{vec_data}"
+            stmt = stmt.order_by(distance_func(table.c[vec_column_name], f"{vec_data}"))
+        stmt_str = (
+            str(
+                stmt.compile(
+                    dialect=self.engine.dialect, compile_kwargs={"literal_binds": True}
                 )
             )
-        stmt_str = (
-            str(stmt.compile(
-                dialect=self.engine.dialect,
-                compile_kwargs={"literal_binds": True}
-            ))
             + f" APPROXIMATE limit {topk}"
         )
         with self.engine.connect() as conn:
             with conn.begin():
                 if idx_name_hint is not None:
                     idx = stmt_str.find("SELECT ")
-                    stmt_str = f"SELECT /*+ index({table_name} {idx_name_hint}) */ " + stmt_str[idx + len("SELECT "):]
+                    stmt_str = (
+                        f"SELECT /*+ index({table_name} {idx_name_hint}) */ "
+                        + stmt_str[idx + len("SELECT ") :]
+                    )
 
                 if partition_names is None:
                     return conn.execute(text(stmt_str))
@@ -430,7 +428,9 @@ class ObVecClient(ObClient):
 
         columns = []
         if output_column_names is not None:
-            columns.extend([table.c[column_name] for column_name in output_column_names])
+            columns.extend(
+                [table.c[column_name] for column_name in output_column_names]
+            )
         else:
             columns.extend([table.c[column.name] for column in table.columns])
         if extra_output_cols is not None:
@@ -459,16 +459,20 @@ class ObVecClient(ObClient):
                 if partition_names is None:
                     if str_list is not None:
                         str_list.append(
-                            str(stmt.compile(
-                                dialect=self.engine.dialect,
-                                compile_kwargs={"literal_binds": True}
-                            ))
+                            str(
+                                stmt.compile(
+                                    dialect=self.engine.dialect,
+                                    compile_kwargs={"literal_binds": True},
+                                )
+                            )
                         )
                     return conn.execute(stmt)
-                stmt_str = str(stmt.compile(
-                    dialect=self.engine.dialect,
-                    compile_kwargs={"literal_binds": True}
-                ))
+                stmt_str = str(
+                    stmt.compile(
+                        dialect=self.engine.dialect,
+                        compile_kwargs={"literal_binds": True},
+                    )
+                )
                 stmt_str = self._insert_partition_hint_for_query_sql(
                     stmt_str, f"PARTITION({', '.join(partition_names)})"
                 )
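Several query paths above compile the statement to a string precisely so a PARTITION(...) clause or an optimizer index hint can be spliced into the raw SQL. A hedged sketch of how those options surface on the vector search call; the method name ann_search and the func.l2_distance registration are assumptions based on the parameter names (vec_data, vec_column_name, distance_func, topk, partition_names, idx_name_hint) visible in the hunks:

from sqlalchemy import func

from pyobvector import ObVecClient  # top-level export assumed

client = ObVecClient(uri="127.0.0.1:2881", user="root@test")  # defaults assumed

# partition_names triggers the _insert_partition_hint_for_query_sql() path, and
# idx_name_hint prepends "/*+ index(docs <name>) */" right after SELECT.
res = client.ann_search(
    table_name="docs",
    vec_data=[0.1, 0.2, 0.3],
    vec_column_name="embedding",
    distance_func=func.l2_distance,  # name registered via setattr(func_mod, ...) above
    topk=5,
    partition_names=["p0"],
    idx_name_hint="vidx_docs_embedding",  # placeholder index name
)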