pixeltable 0.4.13__py3-none-any.whl → 0.4.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (55)
  1. pixeltable/__init__.py +2 -1
  2. pixeltable/catalog/catalog.py +187 -63
  3. pixeltable/catalog/column.py +24 -20
  4. pixeltable/catalog/table.py +24 -8
  5. pixeltable/catalog/table_metadata.py +1 -0
  6. pixeltable/catalog/table_version.py +16 -34
  7. pixeltable/catalog/update_status.py +12 -0
  8. pixeltable/catalog/view.py +22 -22
  9. pixeltable/config.py +2 -0
  10. pixeltable/dataframe.py +4 -2
  11. pixeltable/env.py +46 -21
  12. pixeltable/exec/__init__.py +1 -0
  13. pixeltable/exec/aggregation_node.py +0 -1
  14. pixeltable/exec/cache_prefetch_node.py +74 -98
  15. pixeltable/exec/data_row_batch.py +2 -18
  16. pixeltable/exec/expr_eval/expr_eval_node.py +11 -0
  17. pixeltable/exec/in_memory_data_node.py +1 -1
  18. pixeltable/exec/object_store_save_node.py +299 -0
  19. pixeltable/exec/sql_node.py +28 -33
  20. pixeltable/exprs/data_row.py +31 -25
  21. pixeltable/exprs/json_path.py +6 -5
  22. pixeltable/exprs/row_builder.py +6 -12
  23. pixeltable/functions/gemini.py +1 -1
  24. pixeltable/functions/openai.py +1 -1
  25. pixeltable/functions/video.py +128 -15
  26. pixeltable/functions/whisperx.py +2 -0
  27. pixeltable/functions/yolox.py +2 -0
  28. pixeltable/globals.py +49 -30
  29. pixeltable/index/embedding_index.py +5 -8
  30. pixeltable/io/__init__.py +1 -0
  31. pixeltable/io/fiftyone.py +1 -1
  32. pixeltable/io/label_studio.py +4 -5
  33. pixeltable/iterators/__init__.py +1 -0
  34. pixeltable/iterators/audio.py +1 -1
  35. pixeltable/iterators/document.py +10 -12
  36. pixeltable/iterators/video.py +1 -1
  37. pixeltable/metadata/schema.py +7 -0
  38. pixeltable/plan.py +26 -1
  39. pixeltable/share/packager.py +8 -2
  40. pixeltable/share/publish.py +3 -10
  41. pixeltable/store.py +1 -1
  42. pixeltable/type_system.py +1 -3
  43. pixeltable/utils/dbms.py +31 -5
  44. pixeltable/utils/gcs_store.py +283 -0
  45. pixeltable/utils/local_store.py +316 -0
  46. pixeltable/utils/object_stores.py +497 -0
  47. pixeltable/utils/pytorch.py +5 -6
  48. pixeltable/utils/s3_store.py +354 -0
  49. {pixeltable-0.4.13.dist-info → pixeltable-0.4.15.dist-info}/METADATA +1 -1
  50. {pixeltable-0.4.13.dist-info → pixeltable-0.4.15.dist-info}/RECORD +53 -50
  51. pixeltable/utils/media_store.py +0 -248
  52. pixeltable/utils/s3.py +0 -17
  53. {pixeltable-0.4.13.dist-info → pixeltable-0.4.15.dist-info}/WHEEL +0 -0
  54. {pixeltable-0.4.13.dist-info → pixeltable-0.4.15.dist-info}/entry_points.txt +0 -0
  55. {pixeltable-0.4.13.dist-info → pixeltable-0.4.15.dist-info}/licenses/LICENSE +0 -0
pixeltable/index/embedding_index.py CHANGED
@@ -138,15 +138,12 @@ class EmbeddingIndex(IndexBase):
 
     def create_index(self, index_name: str, index_value_col: catalog.Column) -> None:
         """Create the index on the index value column"""
-        idx = sql.Index(
-            index_name,
-            index_value_col.sa_col,
-            postgresql_using='hnsw',
-            postgresql_with={'m': 16, 'ef_construction': 64},
-            postgresql_ops={index_value_col.sa_col.name: self.PGVECTOR_OPS[self.metric]},
+        Env.get().dbms.create_vector_index(
+            index_name=index_name,
+            index_value_sa_col=index_value_col.sa_col,
+            conn=Env.get().conn,
+            metric=self.PGVECTOR_OPS[self.metric],
         )
-        conn = Env.get().conn
-        idx.create(bind=conn)
 
     def drop_index(self, index_name: str, index_value_col: catalog.Column) -> None:
         """Drop the index on the index value column"""
pixeltable/io/__init__.py CHANGED
@@ -1,3 +1,4 @@
+"""Functions for importing and exporting Pixeltable data."""
 # ruff: noqa: F401
 
 from .datarows import import_json, import_rows
pixeltable/io/fiftyone.py CHANGED
@@ -9,7 +9,7 @@ import puremagic
 import pixeltable as pxt
 import pixeltable.exceptions as excs
 from pixeltable import exprs
-from pixeltable.utils.media_store import TempStore
+from pixeltable.utils.local_store import TempStore
 
 
 class PxtImageDatasetImporter(foud.LabeledImageDatasetImporter):
pixeltable/io/label_studio.py CHANGED
@@ -19,7 +19,7 @@ from pixeltable.config import Config
 from pixeltable.exprs import ColumnRef, DataRow, Expr
 from pixeltable.io.external_store import Project
 from pixeltable.utils import coco
-from pixeltable.utils.media_store import TempStore
+from pixeltable.utils.local_store import TempStore
 
 # label_studio_sdk>=1 and label_studio_sdk<1 are not compatible, so we need to try
 # the import two different ways to insure intercompatibility
@@ -46,6 +46,9 @@ class LabelStudioProject(Project):
     """
     An [`ExternalStore`][pixeltable.io.ExternalStore] that represents a Label Studio project, providing functionality
     for synchronizing between a Pixeltable table and a Label Studio project.
+
+    The constructor will NOT create a new Label Studio project; it is also used when loading
+    metadata for existing projects.
     """
 
     project_id: int  # Label Studio project ID
@@ -60,10 +63,6 @@ class LabelStudioProject(Project):
         col_mapping: dict[ColumnHandle, str],
         stored_proxies: Optional[dict[ColumnHandle, ColumnHandle]] = None,
     ):
-        """
-        The constructor will NOT create a new Label Studio project; it is also used when loading
-        metadata for existing projects.
-        """
         self.project_id = project_id
         self.media_import_method = media_import_method
         self._project = None
pixeltable/iterators/__init__.py CHANGED
@@ -1,3 +1,4 @@
+"""Iterators for splitting media and documents into components."""
 # ruff: noqa: F401
 
 from .audio import AudioSplitter
pixeltable/iterators/audio.py CHANGED
@@ -6,7 +6,7 @@ from typing import Any, ClassVar, Optional
 import av
 
 from pixeltable import exceptions as excs, type_system as ts
-from pixeltable.utils.media_store import TempStore
+from pixeltable.utils.local_store import TempStore
 
 from .base import ComponentIterator
 
pixeltable/iterators/document.py CHANGED
@@ -94,6 +94,16 @@ class DocumentSplitter(ComponentIterator):
     include additional metadata fields if specified in the `metadata` parameter, as explained below.
 
     Chunked text will be cleaned with `ftfy.fix_text` to fix up common problems with unicode sequences.
+
+    Args:
+        separators: separators to use to chunk the document. Options are:
+            `'heading'`, `'paragraph'`, `'sentence'`, `'token_limit'`, `'char_limit'`, `'page'`.
+            This may be a comma-separated string, e.g., `'heading,token_limit'`.
+        limit: the maximum number of tokens or characters in each chunk, if `'token_limit'`
+            or `'char_limit'` is specified.
+        metadata: additional metadata fields to include in the output. Options are:
+            `'title'`, `'heading'` (HTML and Markdown), `'sourceline'` (HTML), `'page'` (PDF), `'bounding_box'`
+            (PDF). The input may be a comma-separated string, e.g., `'title,heading,sourceline'`.
     """
 
     METADATA_COLUMN_TYPES: ClassVar[dict[ChunkMetadata, ColumnType]] = {
@@ -116,18 +126,6 @@ class DocumentSplitter(ComponentIterator):
         tiktoken_encoding: Optional[str] = 'cl100k_base',
         tiktoken_target_model: Optional[str] = None,
     ):
-        """Init method for `DocumentSplitter` class.
-
-        Args:
-            separators: separators to use to chunk the document. Options are:
-                `'heading'`, `'paragraph'`, `'sentence'`, `'token_limit'`, `'char_limit'`, `'page'`.
-                This may be a comma-separated string, e.g., `'heading,token_limit'`.
-            limit: the maximum number of tokens or characters in each chunk, if `'token_limit'`
-                or `'char_limit'` is specified.
-            metadata: additional metadata fields to include in the output. Options are:
-                `'title'`, `'heading'` (HTML and Markdown), `'sourceline'` (HTML), `'page'` (PDF), `'bounding_box'`
-                (PDF). The input may be a comma-separated string, e.g., `'title,heading,sourceline'`.
-        """
         if html_skip_tags is None:
            html_skip_tags = ['nav']
         self._doc_handle = get_document_handle(document)
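
The relocated Args documentation above (moved from the __init__ docstring to the class docstring) describes the chunking parameters. As a minimal usage sketch, assuming Pixeltable's usual iterator-view pattern; the table and column names here are hypothetical, not part of this diff:

    import pixeltable as pxt
    from pixeltable.iterators import DocumentSplitter

    docs = pxt.create_table('docs', {'doc': pxt.Document})
    # Each output row is one chunk; requested metadata fields become extra view columns.
    chunks = pxt.create_view(
        'doc_chunks',
        docs,
        iterator=DocumentSplitter.create(
            document=docs.doc,
            separators='heading,token_limit',  # comma-separated, as documented above
            limit=512,  # max tokens per chunk; applies because 'token_limit' is requested
            metadata='title,heading',
        ),
    )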
pixeltable/iterators/video.py CHANGED
@@ -14,7 +14,7 @@ import pixeltable as pxt
 import pixeltable.exceptions as excs
 import pixeltable.type_system as ts
 import pixeltable.utils.av as av_utils
-from pixeltable.utils.media_store import TempStore
+from pixeltable.utils.local_store import TempStore
 
 from .base import ComponentIterator
 
pixeltable/metadata/schema.py CHANGED
@@ -115,6 +115,9 @@ class ColumnMd:
     # if True, the column is present in the stored table
     stored: Optional[bool]
 
+    # If present, the URI for the destination for column values
+    destination: Optional[str] = None
+
 
 @dataclasses.dataclass
 class IndexMd:
@@ -244,6 +247,9 @@ class TableVersionMd:
     schema_version: int
     user: Optional[str] = None  # User that created this version
     update_status: Optional[UpdateStatus] = None  # UpdateStatus of the change that created this version
+    # A version fragment cannot be queried or instantiated via get_table(). A fragment represents a version of a
+    # replica table that has incomplete data, and exists only to provide base table support for a dependent view.
+    is_fragment: bool = False
     additional_md: dict[str, Any] = dataclasses.field(default_factory=dict)
 
 
@@ -353,6 +359,7 @@ class FullTableMd(NamedTuple):
     def is_pure_snapshot(self) -> bool:
         return (
             self.tbl_md.view_md is not None
+            and self.tbl_md.view_md.is_snapshot
             and self.tbl_md.view_md.predicate is None
             and len(self.schema_version_md.columns) == 0
         )
pixeltable/plan.py CHANGED
@@ -403,6 +403,8 @@ class Planner:
                 ignore_errors=ignore_errors,
             )
         )
+        plan = cls._insert_save_node(tbl.id, row_builder.stored_media_cols, input_node=plan)
+
         return plan
 
     @classmethod
@@ -499,6 +501,9 @@ class Planner:
         for i, col in enumerate(all_base_cols):
             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
         plan.ctx.num_computed_exprs = len(recomputed_exprs)
+
+        plan = cls._insert_save_node(tbl.tbl_version.id, plan.row_builder.stored_media_cols, input_node=plan)
+
         recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
         return plan, [f'{c.tbl.name}.{c.name}' for c in updated_cols + recomputed_user_cols], recomputed_user_cols
 
@@ -597,6 +602,7 @@ class Planner:
         # we're returning everything to the user, so we might as well do it in a single batch
         ctx.batch_size = 0
         plan.set_ctx(ctx)
+        plan = cls._insert_save_node(tbl.tbl_version.id, plan.row_builder.stored_media_cols, input_node=plan)
        recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
         return (
             plan,
@@ -650,6 +656,8 @@ class Planner:
         for i, col in enumerate(copied_cols + list(recomputed_cols)):  # same order as select_list
             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
         # TODO: avoid duplication with view_load_plan() logic (where does this belong?)
+        plan = cls._insert_save_node(view.tbl_version.id, plan.row_builder.stored_media_cols, input_node=plan)
+
         return plan
 
     @classmethod
@@ -718,6 +726,8 @@ class Planner:
 
         exec_ctx.ignore_errors = True
         plan.set_ctx(exec_ctx)
+        plan = cls._insert_save_node(view.tbl_version.id, plan.row_builder.stored_media_cols, input_node=plan)
+
         return plan, len(row_builder.default_eval_ctx.target_exprs)
 
     @classmethod
@@ -762,6 +772,17 @@ class Planner:
             combined_ordering = combined
         return combined_ordering
 
+    @classmethod
+    def _insert_save_node(
+        cls, tbl_id: UUID, stored_media_cols: list[exprs.ColumnSlotIdx], input_node: exec.ExecNode
+    ) -> exec.ExecNode:
+        """Return an ObjectStoreSaveNode if stored media columns are present, otherwise return input"""
+        if len(stored_media_cols) == 0:
+            return input_node
+        save_node = exec.ObjectStoreSaveNode(tbl_id, stored_media_cols, input_node)
+        save_node.set_ctx(input_node.ctx)
+        return save_node
+
     @classmethod
     def _is_contained_in(cls, l1: Iterable[exprs.Expr], l2: Iterable[exprs.Expr]) -> bool:
         """Returns True if l1 is contained in l2"""
@@ -771,7 +792,7 @@ class Planner:
     def _insert_prefetch_node(
         cls, tbl_id: UUID, expressions: Iterable[exprs.Expr], input_node: exec.ExecNode
     ) -> exec.ExecNode:
-        """Return a CachePrefetchNode if needed, otherwise return input"""
+        """Return a node to prefetch data if needed, otherwise return input"""
         # we prefetch external files for all media ColumnRefs, even those that aren't part of the dependencies
         # of output_exprs: if unstored iterator columns are present, we might need to materialize ColumnRefs that
         # aren't explicitly captured as dependencies
@@ -989,6 +1010,7 @@ class Planner:
             if not agg_output.issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
                 # we need an ExprEvalNode to evaluate the remaining output exprs
                 plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, agg_output, input=plan)
+            plan = cls._insert_save_node(tbl.tbl_version.id, row_builder.stored_media_cols, input_node=plan)
         else:
             if not exprs.ExprSet(sql_exprs).issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
                 # we need an ExprEvalNode to evaluate the remaining output exprs
@@ -1034,10 +1056,13 @@ class Planner:
         plan = cls._create_query_plan(
             row_builder=row_builder, analyzer=analyzer, eval_ctx=row_builder.default_eval_ctx, with_pk=True
         )
+
         plan.ctx.batch_size = 16
         plan.ctx.show_pbar = True
         plan.ctx.ignore_errors = True
         computed_exprs = row_builder.output_exprs - row_builder.input_exprs
         plan.ctx.num_computed_exprs = len(computed_exprs)  # we are adding a computed column, so we need to evaluate it
 
+        plan = cls._insert_save_node(tbl.tbl_version.id, row_builder.stored_media_cols, input_node=plan)
+
         return plan
pixeltable/share/packager.py CHANGED
@@ -24,7 +24,8 @@ from pixeltable.env import Env
 from pixeltable.metadata import schema
 from pixeltable.utils import sha256sum
 from pixeltable.utils.formatter import Formatter
-from pixeltable.utils.media_store import MediaStore, TempStore
+from pixeltable.utils.local_store import TempStore
+from pixeltable.utils.object_stores import ObjectOps
 
 _logger = logging.getLogger('pixeltable')
 
@@ -362,6 +363,8 @@ class TableRestorer:
         for md in tbl_md:
             md.tbl_md.is_replica = True
 
+        assert not tbl_md[0].version_md.is_fragment  # Top-level table cannot be a version fragment
+
         cat = catalog.Catalog.get()
 
         with cat.begin_xact(for_write=True):
@@ -369,6 +372,9 @@ class TableRestorer:
             # versions that have not been seen before.
             cat.create_replica(catalog.Path.parse(self.tbl_path), tbl_md)
 
+            _logger.debug(f'Now will import data for {len(tbl_md)} table(s):')
+            _logger.debug(repr([md.tbl_md.tbl_id for md in tbl_md[::-1]]))
+
             # Now we need to load data for replica_tbl and its ancestors, except that we skip
             # replica_tbl itself if it's a pure snapshot.
             for md in tbl_md[::-1]:  # Base table first
@@ -619,7 +625,7 @@ class TableRestorer:
             # in self.media_files.
             src_path = self.tmp_dir / 'media' / parsed_url.netloc
             # Move the file to the media store and update the URL.
-            self.media_files[url] = MediaStore.get().relocate_local_media_file(src_path, media_col)
+            self.media_files[url] = ObjectOps.put_file(media_col, src_path, relocate_or_delete=True)
             return self.media_files[url]
         # For any type of URL other than a local file, just return the URL as-is.
         return url
pixeltable/share/publish.py CHANGED
@@ -14,7 +14,7 @@ import pixeltable as pxt
 from pixeltable import exceptions as excs
 from pixeltable.env import Env
 from pixeltable.utils import sha256sum
-from pixeltable.utils.media_store import TempStore
+from pixeltable.utils.local_store import TempStore
 
 from .packager import TablePackager, TableRestorer
 
@@ -79,16 +79,13 @@ def push_replica(
 
 
 def _upload_bundle_to_s3(bundle: Path, parsed_location: urllib.parse.ParseResult) -> None:
-    from pixeltable.utils.s3 import get_client
-
     bucket = parsed_location.netloc
     remote_dir = Path(urllib.parse.unquote(urllib.request.url2pathname(parsed_location.path)))
     remote_path = str(remote_dir / bundle.name)[1:]  # Remove initial /
 
     Env.get().console_logger.info(f'Uploading snapshot to: {bucket}:{remote_path}')
 
-    boto_config = {'max_pool_connections': 5, 'connect_timeout': 15, 'retries': {'max_attempts': 3, 'mode': 'adaptive'}}
-    s3_client = get_client(**boto_config)
+    s3_client = Env.get().get_client('s3')
 
     upload_args = {'ChecksumAlgorithm': 'SHA256'}
 
@@ -135,16 +132,13 @@ def pull_replica(dest_path: str, src_tbl_uri: str) -> pxt.Table:
 
 
 def _download_bundle_from_s3(parsed_location: urllib.parse.ParseResult, bundle_filename: str) -> Path:
-    from pixeltable.utils.s3 import get_client
-
     bucket = parsed_location.netloc
     remote_dir = Path(urllib.parse.unquote(urllib.request.url2pathname(parsed_location.path)))
     remote_path = str(remote_dir / bundle_filename)[1:]  # Remove initial /
 
     Env.get().console_logger.info(f'Downloading snapshot from: {bucket}:{remote_path}')
 
-    boto_config = {'max_pool_connections': 5, 'connect_timeout': 15, 'retries': {'max_attempts': 3, 'mode': 'adaptive'}}
-    s3_client = get_client(**boto_config)
+    s3_client = Env.get().get_client('s3')
 
     obj = s3_client.head_object(Bucket=bucket, Key=remote_path)  # Check if the object exists
     bundle_size = obj['ContentLength']
@@ -260,7 +254,6 @@ def _download_from_presigned_url(
     session.close()
 
 
-# TODO: This will be replaced by drop_table with cloud table uri
 def delete_replica(dest_path: str) -> None:
    """Delete cloud replica"""
     delete_request_json = {'operation_type': 'delete_snapshot', 'table_uri': dest_path}
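
Both S3 helpers above now obtain their client from `Env.get().get_client('s3')` rather than the removed `pixeltable.utils.s3.get_client` helper, so connection-pool and retry configuration live in `Env`. For reference, a standalone sketch of an upload carrying the `ChecksumAlgorithm` setting seen in `upload_args`, using only public boto3 APIs (bucket, key, and file names are hypothetical):

    import boto3

    s3_client = boto3.client('s3')  # stand-in for Env.get().get_client('s3')
    # The checksum algorithm is passed per object, matching upload_args above.
    s3_client.upload_file(
        'bundle.tar.gz',
        'my-bucket',
        'snapshots/bundle.tar.gz',
        ExtraArgs={'ChecksumAlgorithm': 'SHA256'},
    )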
pixeltable/store.py CHANGED
@@ -274,7 +274,7 @@ class StoreBase:
             self.sa_md.remove(tmp_tbl)
             tmp_tbl.drop(bind=conn)
 
-        run_cleanup(remove_tmp_tbl, raise_error=True)
+        run_cleanup(remove_tmp_tbl, raise_error=False)
 
         return num_excs
 
pixeltable/type_system.py CHANGED
@@ -1081,9 +1081,7 @@ class ImageType(ColumnType):
         mode: Optional[str] = None,
         nullable: bool = False,
     ):
-        """
-        TODO: does it make sense to specify only width or height?
-        """
+        # TODO: does it make sense to specify only width or height?
         super().__init__(self.Type.IMAGE, nullable=nullable)
         assert not (width is not None and size is not None)
         assert not (height is not None and size is not None)
pixeltable/utils/dbms.py CHANGED
@@ -1,6 +1,6 @@
 import abc
 
-from sqlalchemy import URL
+import sqlalchemy as sql
 
 
 class Dbms(abc.ABC):
@@ -11,9 +11,9 @@ class Dbms(abc.ABC):
     name: str
     transaction_isolation_level: str
     version_index_type: str
-    db_url: URL
+    db_url: sql.URL
 
-    def __init__(self, name: str, transaction_isolation_level: str, version_index_type: str, db_url: URL) -> None:
+    def __init__(self, name: str, transaction_isolation_level: str, version_index_type: str, db_url: sql.URL) -> None:
         self.name = name
         self.transaction_isolation_level = transaction_isolation_level
         self.version_index_type = version_index_type
@@ -28,13 +28,18 @@ class Dbms(abc.ABC):
     @abc.abstractmethod
     def default_system_db_url(self) -> str: ...
 
+    @abc.abstractmethod
+    def create_vector_index(
+        self, index_name: str, index_value_sa_col: sql.schema.Column, conn: sql.Connection, metric: str
+    ) -> None: ...
+
 
 class PostgresqlDbms(Dbms):
     """
     Implements utilities to interact with Postgres database.
     """
 
-    def __init__(self, db_url: URL):
+    def __init__(self, db_url: sql.URL):
         super().__init__('postgresql', 'SERIALIZABLE', 'brin', db_url)
 
     def drop_db_stmt(self, database: str) -> str:
@@ -47,13 +52,25 @@ class PostgresqlDbms(Dbms):
         a = self.db_url.set(database='postgres').render_as_string(hide_password=False)
         return a
 
+    def create_vector_index(
+        self, index_name: str, index_value_sa_col: sql.schema.Column, conn: sql.Connection, metric: str
+    ) -> None:
+        idx = sql.Index(
+            index_name,
+            index_value_sa_col,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 16, 'ef_construction': 64},
+            postgresql_ops={index_value_sa_col.name: metric},
+        )
+        idx.create(bind=conn)
+
 
 class CockroachDbms(Dbms):
     """
     Implements utilities to interact with CockroachDb database.
     """
 
-    def __init__(self, db_url: URL):
+    def __init__(self, db_url: sql.URL):
         super().__init__('cockroachdb', 'SERIALIZABLE', 'btree', db_url)
 
     def drop_db_stmt(self, database: str) -> str:
@@ -64,3 +81,12 @@ class CockroachDbms(Dbms):
 
     def default_system_db_url(self) -> str:
         return self.db_url.set(database='defaultdb').render_as_string(hide_password=False)
+
+    def create_vector_index(
+        self, index_name: str, index_value_sa_col: sql.schema.Column, conn: sql.Connection, metric: str
+    ) -> None:
+        create_index_sql = sql.text(
+            f"""CREATE VECTOR INDEX {index_name} ON {index_value_sa_col.table.name}
+            ({index_value_sa_col.name} {metric})"""
+        )
+        conn.execute(create_index_sql)
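
The two implementations reach the same goal by different paths: `PostgresqlDbms` builds a pgvector HNSW index through SQLAlchemy's `Index`, while `CockroachDbms` issues raw `CREATE VECTOR INDEX` DDL. A standalone sketch of the DDL the Postgres branch compiles to (table, column, and operator-class names are illustrative; the `PGVECTOR_OPS` mapping is not shown in this diff):

    import sqlalchemy as sql
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.schema import CreateIndex

    md = sql.MetaData()
    # Stand-in column type; in pixeltable the column holds pgvector embeddings.
    tbl = sql.Table('tbl_0', md, sql.Column('idx_val', sql.LargeBinary))
    idx = sql.Index(
        'idx_0',
        tbl.c.idx_val,
        postgresql_using='hnsw',
        postgresql_with={'m': 16, 'ef_construction': 64},
        postgresql_ops={'idx_val': 'vector_cosine_ops'},  # illustrative operator class
    )
    print(CreateIndex(idx).compile(dialect=postgresql.dialect()))
    # Roughly: CREATE INDEX idx_0 ON tbl_0 USING hnsw (idx_val vector_cosine_ops)
    #          WITH (m = 16, ef_construction = 64)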