pixeltable 0.1.0__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.


Files changed (147)
  1. pixeltable/__init__.py +34 -6
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +590 -30
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +359 -45
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +195 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +34 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +256 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +122 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +418 -182
  88. pixeltable/tests/conftest.py +146 -88
  89. pixeltable/tests/functions/test_fireworks.py +42 -0
  90. pixeltable/tests/functions/test_functions.py +60 -0
  91. pixeltable/tests/functions/test_huggingface.py +158 -0
  92. pixeltable/tests/functions/test_openai.py +152 -0
  93. pixeltable/tests/functions/test_together.py +111 -0
  94. pixeltable/tests/test_audio.py +65 -0
  95. pixeltable/tests/test_catalog.py +27 -0
  96. pixeltable/tests/test_client.py +14 -14
  97. pixeltable/tests/test_component_view.py +370 -0
  98. pixeltable/tests/test_dataframe.py +439 -0
  99. pixeltable/tests/test_dirs.py +78 -62
  100. pixeltable/tests/test_document.py +120 -0
  101. pixeltable/tests/test_exprs.py +592 -135
  102. pixeltable/tests/test_function.py +297 -67
  103. pixeltable/tests/test_migration.py +43 -0
  104. pixeltable/tests/test_nos.py +54 -0
  105. pixeltable/tests/test_snapshot.py +208 -0
  106. pixeltable/tests/test_table.py +1195 -263
  107. pixeltable/tests/test_transactional_directory.py +42 -0
  108. pixeltable/tests/test_types.py +5 -11
  109. pixeltable/tests/test_video.py +151 -34
  110. pixeltable/tests/test_view.py +530 -0
  111. pixeltable/tests/utils.py +320 -45
  112. pixeltable/tool/create_test_db_dump.py +149 -0
  113. pixeltable/tool/create_test_video.py +81 -0
  114. pixeltable/type_system.py +445 -124
  115. pixeltable/utils/__init__.py +17 -46
  116. pixeltable/utils/arrow.py +98 -0
  117. pixeltable/utils/clip.py +12 -15
  118. pixeltable/utils/coco.py +136 -0
  119. pixeltable/utils/documents.py +39 -0
  120. pixeltable/utils/filecache.py +195 -0
  121. pixeltable/utils/help.py +11 -0
  122. pixeltable/utils/hf_datasets.py +157 -0
  123. pixeltable/utils/media_store.py +76 -0
  124. pixeltable/utils/parquet.py +167 -0
  125. pixeltable/utils/pytorch.py +91 -0
  126. pixeltable/utils/s3.py +13 -0
  127. pixeltable/utils/sql.py +17 -0
  128. pixeltable/utils/transactional_directory.py +35 -0
  129. pixeltable-0.2.4.dist-info/LICENSE +18 -0
  130. pixeltable-0.2.4.dist-info/METADATA +127 -0
  131. pixeltable-0.2.4.dist-info/RECORD +132 -0
  132. {pixeltable-0.1.0.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +1 -1
  133. pixeltable/catalog.py +0 -1421
  134. pixeltable/exprs.py +0 -1745
  135. pixeltable/function.py +0 -269
  136. pixeltable/functions/clip.py +0 -10
  137. pixeltable/functions/pil/__init__.py +0 -23
  138. pixeltable/functions/tf.py +0 -21
  139. pixeltable/index.py +0 -57
  140. pixeltable/tests/test_dict.py +0 -24
  141. pixeltable/tests/test_functions.py +0 -11
  142. pixeltable/tests/test_tf.py +0 -69
  143. pixeltable/tf.py +0 -33
  144. pixeltable/utils/tf.py +0 -33
  145. pixeltable/utils/video.py +0 -32
  146. pixeltable-0.1.0.dist-info/METADATA +0 -34
  147. pixeltable-0.1.0.dist-info/RECORD +0 -36
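
Most of the churn is structural: the monolithic 0.1.0 modules (catalog.py, exprs.py, function.py) are deleted and re-emerge as the catalog/, exprs/, and func/ packages, joined by new exec/ (query execution), iterators/, and metadata/ packages and a set of model-provider integrations (fireworks.py, huggingface.py, openai.py, together.py). As a rough orientation, the sketch below shows where names used by the rewritten store.py now resolve from; the import targets are taken from the diff itself, but treat the snippet as illustrative rather than a supported public API.

    # Illustrative only: these imports appear in the new pixeltable/store.py below
    # and resolve against the 0.2.4 package layout shown in the file list above.
    import pixeltable.catalog as catalog    # package; replaces the deleted pixeltable/catalog.py (-1421 lines)
    from pixeltable import exprs            # package; replaces the deleted pixeltable/exprs.py (-1745 lines)
    from pixeltable.exec import ExecNode    # new execution-engine package (pixeltable/exec/)
    from pixeltable.metadata import schema  # persisted metadata schema, split out of the old store.py
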
pixeltable/store.py CHANGED
@@ -1,191 +1,427 @@
- import enum
- import platform
+ from __future__ import annotations

- import sqlalchemy as sql
- from sqlalchemy import Integer, String, Enum, Boolean, TIMESTAMP, BigInteger, LargeBinary, JSON
- from sqlalchemy import ForeignKey, UniqueConstraint, ForeignKeyConstraint
- from sqlalchemy.orm import declarative_base
-
- from pixeltable import type_system as pt_types
-
- Base = declarative_base()
-
-
- class Db(Base):
-     __tablename__ = 'dbs'
-
-     id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-     name = sql.Column(String, nullable=False)
-
-
- class Dir(Base):
-     __tablename__ = 'dirs'
-
-     id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-     db_id = sql.Column(Integer, ForeignKey('dbs.id'), nullable=False)
-     path = sql.Column(String, nullable=False)
-     is_snapshot = sql.Column(Boolean, nullable=False)
-
-
- class Table(Base):
-     __tablename__ = 'tables'
-
-     MAX_VERSION = 9223372036854775807  # 2^63 - 1
-
-     id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-     db_id = sql.Column(Integer, ForeignKey('dbs.id'), nullable=False)
-     dir_id = sql.Column(Integer, ForeignKey('dirs.id'), nullable=False)
-     name = sql.Column(String, nullable=False)
-     parameters = sql.Column(JSON, nullable=False)
+ import abc
+ import logging
+ import os
+ import sys
+ import urllib.parse
+ import urllib.request
+ import warnings
+ from typing import Optional, Dict, Any, List, Tuple, Set

-     # monotonically increasing w/in Table for both data and schema changes, starting at 0
-     current_version = sql.Column(BigInteger, nullable=False)
-     # each version has a corresponding schema version (current_version >= current_schema_version)
-     current_schema_version = sql.Column(BigInteger, nullable=False)
-
-     # if False, can't apply schema or data changes to this table
-     # (table got dropped, but we need to keep a record of it for snapshots)
-     is_mutable = sql.Column(Boolean, nullable=False)
-
-     next_col_id = sql.Column(Integer, nullable=False)  # used to assign Column.id
-
-     # - used to assign the rowid column in the storage table
-     # - every row is assigned a unique and immutable rowid on insertion
-     next_row_id = sql.Column(BigInteger, nullable=False)
-
-     __table_args__ = (
-         #ForeignKeyConstraint(
-         #    ['id', 'current_schema_version'], ['tableschemaversions.tbl_id', 'tableschemaversions.schema_version']),
-     )
-
-     def storage_name(self) -> str:
-         return f'tbl_{self.db_id}_{self.id}'
-
-
- # versioning: each table schema change results in a new record
- class TableSchemaVersion(Base):
-     __tablename__ = 'tableschemaversions'
-
-     tbl_id = sql.Column(Integer, ForeignKey('tables.id'), primary_key=True, nullable=False)
-     schema_version = sql.Column(BigInteger, primary_key=True, nullable=False)
-     preceding_schema_version = sql.Column(BigInteger, nullable=False)
-
-
- # - records the physical schema of a table, ie, the columns that are actually stored
- # - versioning information is needed to GC unreachable storage columns
- # - one record per column (across all schema versions)
- class StorageColumn(Base):
-     __tablename__ = 'storagecolumns'
-
-     tbl_id = sql.Column(Integer, ForeignKey('tables.id'), primary_key=True, nullable=False)
-     # immutable and monotonically increasing from 0 w/in Table
-     col_id = sql.Column(Integer, primary_key=True, nullable=False)
-     # table schema version when col was added
-     schema_version_add = sql.Column(BigInteger, nullable=False)
-     # table schema version when col was dropped
-     schema_version_drop = sql.Column(BigInteger, nullable=True)
+ import sqlalchemy as sql
+ from tqdm import tqdm, TqdmWarning

-     __table_args__ = (
-         ForeignKeyConstraint(
-             ['tbl_id', 'schema_version_add'], ['tableschemaversions.tbl_id', 'tableschemaversions.schema_version']),
-         ForeignKeyConstraint(
-             ['tbl_id', 'schema_version_drop'], ['tableschemaversions.tbl_id', 'tableschemaversions.schema_version'])
-     )
+ import pixeltable.catalog as catalog
+ import pixeltable.env as env
+ from pixeltable import exprs
+ import pixeltable.exceptions as excs
+ from pixeltable.exec import ExecNode
+ from pixeltable.metadata import schema
+ from pixeltable.type_system import StringType
+ from pixeltable.utils.media_store import MediaStore
+ from pixeltable.utils.sql import log_stmt, log_explain

+ _logger = logging.getLogger('pixeltable')

- # - records the logical (user-visible) schema of a table
- # - contains the full set of columns for each new schema version: one record per column x schema version
- class SchemaColumn(Base):
-     __tablename__ = 'schemacolumns'

-     tbl_id = sql.Column(Integer, ForeignKey('tables.id'), primary_key=True, nullable=False)
-     schema_version = sql.Column(BigInteger, primary_key=True, nullable=False)
-     # immutable and monotonically increasing from 0 w/in Table
-     col_id = sql.Column(Integer, primary_key=True, nullable=False)
-     pos = sql.Column(Integer, nullable=False)  # position in table, starting at 0
-     name = sql.Column(String, nullable=False)
-     col_type = sql.Column(String, nullable=False)  # json
-     is_nullable = sql.Column(Boolean, nullable=False)
-     is_pk = sql.Column(Boolean, nullable=False)
-     value_expr = sql.Column(String, nullable=True)  # json
-     # if True, creates vector index for this column
-     is_indexed = sql.Column(Boolean, nullable=False)
+ class StoreBase:
+     """Base class for stored tables

-     __table_args__ = (
-         UniqueConstraint('tbl_id', 'schema_version', 'pos'),
-         UniqueConstraint('tbl_id', 'schema_version', 'name'),
-         ForeignKeyConstraint(['tbl_id', 'col_id'], ['storagecolumns.tbl_id', 'storagecolumns.col_id']),
-         ForeignKeyConstraint(
-             ['tbl_id', 'schema_version'], ['tableschemaversions.tbl_id', 'tableschemaversions.schema_version'])
-     )
-
-
- class TableSnapshot(Base):
-     __tablename__ = 'tablesnapshots'
-
-     id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-     db_id = sql.Column(Integer, ForeignKey('dbs.id'), nullable=False)
-     dir_id = sql.Column(Integer, ForeignKey('dirs.id'), nullable=False)
-     name = sql.Column(String, nullable=False)
-     tbl_id = sql.Column(Integer, nullable=False)
-     tbl_version = sql.Column(BigInteger, nullable=False)
-     tbl_schema_version = sql.Column(BigInteger, nullable=False)
-
-     __table_args__ = (
-         ForeignKeyConstraint(['tbl_id'], ['tables.id']),
-         ForeignKeyConstraint(
-             ['tbl_id', 'tbl_schema_version'], ['tableschemaversions.tbl_id', 'tableschemaversions.schema_version']),
-     )
-
-
- class Function(Base):
+     Each row has the following system columns:
+     - rowid columns: one or more columns that identify a user-visible row across all versions
+     - v_min: version at which the row was created
+     - v_max: version at which the row was deleted (or MAX_VERSION if it's still live)
      """
-     User-defined functions that are not library functions (ie, aren't available at runtime as a symbol in a known
-     module).
-     Functions without a name are anonymous functions used in the definition of a computed column.
-     Functions that have names are also assigned to a database and directory.
-     We store the Python version under which a Function was created (and the callable pickled) in order to warn
-     against version mismatches.
+
+     def __init__(self, tbl_version: catalog.TableVersion):
+         self.tbl_version = tbl_version
+         self.sa_md = sql.MetaData()
+         self.sa_tbl: Optional[sql.Table] = None
+         self._create_sa_tbl()
+
+     def pk_columns(self) -> List[sql.Column]:
+         return self._pk_columns
+
+     def rowid_columns(self) -> List[sql.Column]:
+         return self._pk_columns[:-1]
+
+     @abc.abstractmethod
+     def _create_rowid_columns(self) -> List[sql.Column]:
+         """Create and return rowid columns"""
+         pass
+
+     def _create_system_columns(self) -> List[sql.Column]:
+         """Create and return system columns"""
+         rowid_cols = self._create_rowid_columns()
+         self.v_min_col = sql.Column('v_min', sql.BigInteger, nullable=False)
+         self.v_max_col = \
+             sql.Column('v_max', sql.BigInteger, nullable=False, server_default=str(schema.Table.MAX_VERSION))
+         self._pk_columns = [*rowid_cols, self.v_min_col]
+         return [*rowid_cols, self.v_min_col, self.v_max_col]
+
+     def _create_sa_tbl(self) -> None:
+         """Create self.sa_tbl from self.tbl_version."""
+         system_cols = self._create_system_columns()
+         all_cols = system_cols.copy()
+         idxs: List[sql.Index] = []
+         for col in [c for c in self.tbl_version.cols if c.is_stored]:
+             # re-create sql.Column for each column, regardless of whether it already has sa_col set: it was bound
+             # to the last sql.Table version we created and cannot be reused
+             col.create_sa_cols()
+             all_cols.append(col.sa_col)
+             if col.records_errors:
+                 all_cols.append(col.sa_errormsg_col)
+                 all_cols.append(col.sa_errortype_col)
+
+             if col.is_indexed:
+                 all_cols.append(col.sa_idx_col)
+
+             # we create an index for:
+             # - scalar columns (except for strings, because long strings can't be used for B-tree indices)
+             # - non-computed video and image columns (they will contain external paths/urls that users might want to
+             #   filter on)
+             if (col.col_type.is_scalar_type() and not col.col_type.is_string_type()) \
+                     or (col.col_type.is_media_type() and not col.is_computed):
+                 # index names need to be unique within the Postgres instance
+                 idx_name = f'idx_{col.id}_{self.tbl_version.id.hex}'
+                 idxs.append(sql.Index(idx_name, col.sa_col))
+
+         if self.sa_tbl is not None:
+             # if we're called in response to a schema change, we need to remove the old table first
+             self.sa_md.remove(self.sa_tbl)
+
+         # index for all system columns:
+         # - base x view joins can be executed as merge joins
+         # - speeds up ORDER BY rowid DESC
+         # - allows filtering for a particular table version in index scan
+         idx_name = f'sys_cols_idx_{self.tbl_version.id.hex}'
+         idxs.append(sql.Index(idx_name, *system_cols))
+         # v_min/v_max indices: speeds up base table scans needed to propagate a base table insert or delete
+         idx_name = f'vmin_idx_{self.tbl_version.id.hex}'
+         idxs.append(sql.Index(idx_name, self.v_min_col, postgresql_using='brin'))
+         idx_name = f'vmax_idx_{self.tbl_version.id.hex}'
+         idxs.append(sql.Index(idx_name, self.v_max_col, postgresql_using='brin'))
+
+         self.sa_tbl = sql.Table(self._storage_name(), self.sa_md, *all_cols, *idxs)
+
+     @abc.abstractmethod
+     def _rowid_join_predicate(self) -> sql.ClauseElement:
+         """Return predicate for rowid joins to all bases"""
+         pass
+
+     @abc.abstractmethod
+     def _storage_name(self) -> str:
+         """Return the name of the data store table"""
+         pass
+
+     def _move_tmp_media_file(self, file_url: Optional[str], col: catalog.Column, v_min: int) -> Optional[str]:
+         """Move tmp media file with given url to Env.media_dir and return new url, or given url if not a tmp_dir file"""
+         pxt_tmp_dir = str(env.Env.get().tmp_dir)
+         if file_url is None:
+             return None
+         parsed = urllib.parse.urlparse(file_url)
+         # We should never be passed a local file path here. The "len > 1" ensures that Windows
+         # file paths aren't mistaken for URLs with a single-character scheme.
+         assert len(parsed.scheme) > 1
+         if parsed.scheme != 'file':
+             # remote url
+             return file_url
+         file_path = urllib.parse.unquote(urllib.request.url2pathname(parsed.path))
+         if not file_path.startswith(pxt_tmp_dir):
+             # not a tmp file
+             return file_url
+         _, ext = os.path.splitext(file_path)
+         new_path = str(MediaStore.prepare_media_path(self.tbl_version.id, col.id, v_min, ext=ext))
+         os.rename(file_path, new_path)
+         new_file_url = urllib.parse.urljoin('file:', urllib.request.pathname2url(new_path))
+         return new_file_url
+
+     def _move_tmp_media_files(
+             self, table_rows: List[Dict[str, Any]], media_cols: List[catalog.Column], v_min: int
+     ) -> None:
+         """Move tmp media files that we generated to a permanent location"""
+         for c in media_cols:
+             for table_row in table_rows:
+                 file_url = table_row[c.storage_name()]
+                 table_row[c.storage_name()] = self._move_tmp_media_file(file_url, c, v_min)
+
+     def _create_table_row(
+             self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder, media_cols: List[catalog.Column],
+             exc_col_ids: Set[int], v_min: int
+     ) -> Tuple[Dict[str, Any], int]:
+         """Create a complete table row for insert(), including the PK columns with the values from input_row.pk.
+         Returns:
+             Tuple[complete table row, # of exceptions]
+         """
+         table_row, num_excs = row_builder.create_table_row(input_row, exc_col_ids)
+
+         assert input_row.pk is not None and len(input_row.pk) == len(self._pk_columns)
+         for pk_col, pk_val in zip(self._pk_columns, input_row.pk):
+             if pk_col == self.v_min_col:
+                 table_row[pk_col.name] = v_min
+             else:
+                 table_row[pk_col.name] = pk_val
+
+         return table_row, num_excs
+
+     def count(self) -> int:
+         """Return the number of rows visible in self.tbl_version"""
+         stmt = sql.select(sql.func.count('*')) \
+             .select_from(self.sa_tbl) \
+             .where(self.v_min_col <= self.tbl_version.version) \
+             .where(self.v_max_col > self.tbl_version.version)
+         with env.Env.get().engine.begin() as conn:
+             result = conn.execute(stmt).scalar_one()
+             assert isinstance(result, int)
+             return result
+
+     def create(self, conn: sql.engine.Connection) -> None:
+         self.sa_md.create_all(bind=conn)
+
+     def drop(self, conn: sql.engine.Connection) -> None:
+         """Drop store table"""
+         self.sa_md.drop_all(bind=conn)
+
+     def add_column(self, col: catalog.Column, conn: sql.engine.Connection) -> None:
+         """Add column(s) to the store-resident table based on a catalog column
+
+         Note that a computed catalog column will require two extra columns (for the computed value and for the error
+         message).
+         """
+         assert col.is_stored
+         stmt = sql.text(f'ALTER TABLE {self._storage_name()} ADD COLUMN {col.storage_name()} {col.col_type.to_sql()}')
+         log_stmt(_logger, stmt)
+         conn.execute(stmt)
+         added_storage_cols = [col.storage_name()]
+         if col.records_errors:
+             # we also need to create the errormsg and errortype storage cols
+             stmt = (f'ALTER TABLE {self._storage_name()} '
+                     f'ADD COLUMN {col.errormsg_storage_name()} {StringType().to_sql()} DEFAULT NULL')
+             conn.execute(sql.text(stmt))
+             stmt = (f'ALTER TABLE {self._storage_name()} '
+                     f'ADD COLUMN {col.errortype_storage_name()} {StringType().to_sql()} DEFAULT NULL')
+             conn.execute(sql.text(stmt))
+             added_storage_cols.extend([col.errormsg_storage_name(), col.errortype_storage_name()])
+         self._create_sa_tbl()
+         _logger.info(f'Added columns {added_storage_cols} to storage table {self._storage_name()}')
+
+     def drop_column(self, col: Optional[catalog.Column] = None, conn: Optional[sql.engine.Connection] = None) -> None:
+         """Re-create self.sa_tbl and drop column, if one is given"""
+         if col is not None:
+             assert conn is not None
+             stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.storage_name()}'
+             conn.execute(sql.text(stmt))
+             if col.records_errors:
+                 stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errormsg_storage_name()}'
+                 conn.execute(sql.text(stmt))
+                 stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errortype_storage_name()}'
+                 conn.execute(sql.text(stmt))
+         self._create_sa_tbl()
+
+     def load_column(
+             self, col: catalog.Column, exec_plan: ExecNode, value_expr_slot_idx: int, embedding_slot_idx: int,
+             conn: sql.engine.Connection
+     ) -> int:
+         """Update store column of a computed column with values produced by an execution plan
+
+         Returns:
+             number of rows with exceptions
+         Raises:
+             sql.exc.DBAPIError if there was an error during SQL execution
+         """
+         num_excs = 0
+         num_rows = 0
+         for row_batch in exec_plan:
+             num_rows += len(row_batch)
+             for result_row in row_batch:
+                 values_dict: Dict[sql.Column, Any] = {}
+
+                 if col.is_computed:
+                     if result_row.has_exc(value_expr_slot_idx):
+                         num_excs += 1
+                         value_exc = result_row.get_exc(value_expr_slot_idx)
+                         # we store a NULL value and record the exception/exc type
+                         error_type = type(value_exc).__name__
+                         error_msg = str(value_exc)
+                         values_dict = {
+                             col.sa_col: None,
+                             col.sa_errortype_col: error_type,
+                             col.sa_errormsg_col: error_msg
+                         }
+                     else:
+                         val = result_row.get_stored_val(value_expr_slot_idx)
+                         if col.col_type.is_media_type():
+                             val = self._move_tmp_media_file(val, col, result_row.pk[-1])
+                         values_dict = {col.sa_col: val}
+
+                 if col.is_indexed:
+                     # TODO: deal with exceptions
+                     assert not result_row.has_exc(embedding_slot_idx)
+                     # don't use get_stored_val() here, we need to pass the ndarray
+                     embedding = result_row[embedding_slot_idx]
+                     values_dict[col.sa_index_col] = embedding
+
+                 update_stmt = sql.update(self.sa_tbl).values(values_dict)
+                 for pk_col, pk_val in zip(self.pk_columns(), result_row.pk):
+                     update_stmt = update_stmt.where(pk_col == pk_val)
+                 log_stmt(_logger, update_stmt)
+                 conn.execute(update_stmt)
+
+         return num_excs
+
+     def insert_rows(
+             self, exec_plan: ExecNode, conn: sql.engine.Connection, v_min: Optional[int] = None
+     ) -> Tuple[int, int, Set[int]]:
+         """Insert rows into the store table and update the catalog table's md
+         Returns:
+             number of inserted rows, number of exceptions, set of column ids that have exceptions
+         """
+         assert v_min is not None
+         exec_plan.ctx.conn = conn
+         batch_size = 16  # TODO: is this a good batch size?
+         # TODO: total?
+         num_excs = 0
+         num_rows = 0
+         cols_with_excs: Set[int] = set()
+         progress_bar: Optional[tqdm] = None  # create this only after we started executing
+         row_builder = exec_plan.row_builder
+         media_cols = [info.col for info in row_builder.table_columns if info.col.col_type.is_media_type()]
+         try:
+             exec_plan.open()
+             for row_batch in exec_plan:
+                 num_rows += len(row_batch)
+                 for batch_start_idx in range(0, len(row_batch), batch_size):
+                     # compute batch of rows and convert them into table rows
+                     table_rows: List[Dict[str, Any]] = []
+                     for row_idx in range(batch_start_idx, min(batch_start_idx + batch_size, len(row_batch))):
+                         row = row_batch[row_idx]
+                         table_row, num_row_exc = \
+                             self._create_table_row(row, row_builder, media_cols, cols_with_excs, v_min=v_min)
+                         num_excs += num_row_exc
+                         table_rows.append(table_row)
+                         if progress_bar is None:
+                             warnings.simplefilter("ignore", category=TqdmWarning)
+                             progress_bar = tqdm(
+                                 desc=f'Inserting rows into `{self.tbl_version.name}`',
+                                 unit=' rows',
+                                 ncols=100,
+                                 file=sys.stdout
+                             )
+                         progress_bar.update(1)
+                     self._move_tmp_media_files(table_rows, media_cols, v_min)
+                     conn.execute(sql.insert(self.sa_tbl), table_rows)
+             if progress_bar is not None:
+                 progress_bar.close()
+             return num_rows, num_excs, cols_with_excs
+         finally:
+             exec_plan.close()
+
+     def _versions_clause(self, versions: List[Optional[int]], match_on_vmin: bool) -> sql.ClauseElement:
+         """Return filter for base versions"""
+         v = versions[0]
+         if v is None:
+             # we're looking at live rows
+             clause = sql.and_(self.v_min_col <= self.tbl_version.version, self.v_max_col == schema.Table.MAX_VERSION)
+         else:
+             # we're looking at a specific version
+             clause = self.v_min_col == v if match_on_vmin else self.v_max_col == v
+         if len(versions) == 1:
+             return clause
+         return sql.and_(clause, self.base._versions_clause(versions[1:], match_on_vmin))
+
+     def delete_rows(
+             self, current_version: int, base_versions: List[Optional[int]], match_on_vmin: bool,
+             where_clause: Optional[sql.ClauseElement], conn: sql.engine.Connection) -> int:
+         """Mark rows as deleted that are live and were created prior to current_version.
+         Args:
+             base_versions: if non-None, join only to base rows that were created at that version,
+                 otherwise join to rows that are live in the base's current version (which is distinct from the
+                 current_version parameter)
+             match_on_vmin: if True, match exact versions on v_min; if False, match on v_max
+             where_clause: if not None, also apply where_clause
+         Returns:
+             number of deleted rows
+         """
+         where_clause = sql.true() if where_clause is None else where_clause
+         where_clause = sql.and_(
+             self.v_min_col < current_version,
+             self.v_max_col == schema.Table.MAX_VERSION,
+             where_clause)
+         rowid_join_clause = self._rowid_join_predicate()
+         base_versions_clause = sql.true() if len(base_versions) == 0 \
+             else self.base._versions_clause(base_versions, match_on_vmin)
+         stmt = sql.update(self.sa_tbl) \
+             .values({self.v_max_col: current_version}) \
+             .where(where_clause) \
+             .where(rowid_join_clause) \
+             .where(base_versions_clause)
+         log_explain(_logger, stmt, conn)
+         status = conn.execute(stmt)
+         return status.rowcount
+
+
+ class StoreTable(StoreBase):
+     def __init__(self, tbl_version: catalog.TableVersion):
+         assert not tbl_version.is_view()
+         super().__init__(tbl_version)
+
+     def _create_rowid_columns(self) -> List[sql.Column]:
+         self.rowid_col = sql.Column('rowid', sql.BigInteger, nullable=False)
+         return [self.rowid_col]
+
+     def _storage_name(self) -> str:
+         return f'tbl_{self.tbl_version.id.hex}'
+
+     def _rowid_join_predicate(self) -> sql.ClauseElement:
+         return sql.true()
+
+
+ class StoreView(StoreBase):
+     def __init__(self, catalog_view: catalog.TableVersion):
+         assert catalog_view.is_view()
+         self.base = catalog_view.base.store_tbl
+         super().__init__(catalog_view)
+
+     def _create_rowid_columns(self) -> List[sql.Column]:
+         # a view row corresponds directly to a single base row, which means it needs to duplicate its rowid columns
+         self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
+         return self.rowid_cols
+
+     def _storage_name(self) -> str:
+         return f'view_{self.tbl_version.id.hex}'
+
+     def _rowid_join_predicate(self) -> sql.ClauseElement:
+         return sql.and_(
+             self.base._rowid_join_predicate(),
+             *[c1 == c2 for c1, c2 in zip(self.rowid_columns(), self.base.rowid_columns())])
+
+
+ class StoreComponentView(StoreView):
+     """A view that stores components of its base, as produced by a ComponentIterator
+
+     PK: now also includes pos, the position returned by the ComponentIterator for the base row identified by base_rowid
      """
-     __tablename__ = 'functions'
-
-     id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-     db_id = sql.Column(Integer, ForeignKey('dbs.id'), nullable=True)
-     dir_id = sql.Column(Integer, ForeignKey('dirs.id'), nullable=True)
-     name = sql.Column(String, nullable=True)
-     return_type = sql.Column(String, nullable=False)  # json
-     param_types = sql.Column(String, nullable=False)  # json
-     eval_obj = sql.Column(LargeBinary, nullable=True)  # Function.eval_fn
-     init_obj = sql.Column(LargeBinary, nullable=True)  # AggregateFunction.init_fn
-     update_obj = sql.Column(LargeBinary, nullable=True)  # AggregateFunction.update_fn
-     value_obj = sql.Column(LargeBinary, nullable=True)  # AggregateFunction.value_fn
-     python_version = sql.Column(
-         String, nullable=False, default=platform.python_version, onupdate=platform.python_version)
-
-
- class OpCodes(enum.Enum):
-     CREATE_DB = 1
-     RENAME_DB = 1
-     DROP_DB = 1
-     CREATE_TABLE = 1
-     RENAME_TABLE = 1
-     DROP_TABLE = 1
-     ADD_COLUMN = 1
-     RENAME_COLUMN = 1
-     DROP_COLUMN = 1
-     CREATE_DIR = 1
-     DROP_DIR = 1
-     CREATE_SNAPSHOT = 1
-     DROP_SNAPSHOT = 1
-
-
- class Operation(Base):
-     __tablename__ = 'oplog'
-
-     id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-     ts = sql.Column(TIMESTAMP, nullable=False)
-     # encodes db_id, schema object type (table, view, table/view snapshot), table/view/... id
-     schema_object = sql.Column(BigInteger, nullable=False)
-     opcode = sql.Column(Enum(OpCodes), nullable=False)
-     # operation-specific details; json
-     details = sql.Column(String, nullable=False)
+     def __init__(self, catalog_view: catalog.TableVersion):
+         super().__init__(catalog_view)
+
+     def _create_rowid_columns(self) -> List[sql.Column]:
+         # each base row is expanded into n view rows
+         self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
+         # name of pos column: avoid collisions with bases' pos columns
+         self.pos_col = sql.Column(f'pos_{len(self.rowid_cols) - 1}', sql.BigInteger, nullable=False)
+         self.pos_col_idx = len(self.rowid_cols)
+         self.rowid_cols.append(self.pos_col)
+         return self.rowid_cols
+
+     def _create_sa_tbl(self) -> None:
+         super()._create_sa_tbl()
+         # we need to fix up the 'pos' column in TableVersion
+         self.tbl_version.cols_by_name['pos'].sa_col = self.pos_col
+
+     def _rowid_join_predicate(self) -> sql.ClauseElement:
+         return sql.and_(
+             self.base._rowid_join_predicate(),
+             *[c1 == c2 for c1, c2 in zip(self.rowid_columns()[:-1], self.base.rowid_columns())])
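
The heart of the rewrite is the row-versioning scheme described in the StoreBase docstring: every store row carries v_min/v_max system columns, insert_rows() stamps v_min with the new table version, and delete_rows() never physically removes anything; it closes out live rows by setting v_max to the current version (count() shows the matching visibility predicate). Below is a minimal, self-contained sketch of that rule, using an in-memory SQLite database for brevity; the real store tables live in Postgres, and every name here is illustrative rather than pixeltable API.

    # Minimal sketch of the v_min/v_max versioning scheme (assumptions: table and
    # column names are made up; only the visibility rule mirrors the diff above).
    import sqlalchemy as sql

    MAX_VERSION = 9223372036854775807  # 2**63 - 1, same sentinel as schema.Table.MAX_VERSION

    md = sql.MetaData()
    tbl = sql.Table(
        'tbl_demo', md,
        sql.Column('rowid', sql.BigInteger, primary_key=True),
        sql.Column('v_min', sql.BigInteger, primary_key=True),  # version that created this row version
        sql.Column('v_max', sql.BigInteger, nullable=False, server_default=str(MAX_VERSION)),
        sql.Column('val', sql.String),
    )

    engine = sql.create_engine('sqlite://')
    md.create_all(engine)
    with engine.begin() as conn:
        # version 0: insert; version 1: an "update" closes out the live row and inserts a replacement
        conn.execute(sql.insert(tbl), [{'rowid': 1, 'v_min': 0, 'val': 'a'}])
        conn.execute(
            sql.update(tbl)
            .where(tbl.c.rowid == 1, tbl.c.v_max == MAX_VERSION)
            .values(v_max=1))
        conn.execute(sql.insert(tbl), [{'rowid': 1, 'v_min': 1, 'val': 'b'}])

        # a reader pinned at version v sees exactly the rows with v_min <= v < v_max
        for v in (0, 1):
            stmt = sql.select(tbl.c.val).where(tbl.c.v_min <= v, tbl.c.v_max > v)
            print(v, conn.execute(stmt).scalars().all())  # 0 -> ['a'], 1 -> ['b']

Because v_min and v_max only ever grow with the table version, their values correlate strongly with physical row order, which is presumably why the new _create_sa_tbl() indexes them with cheap BRIN indexes rather than B-trees.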