pixeltable 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the registry.

Potentially problematic release: this version of pixeltable has been flagged for review.

Files changed (140)
  1. pixeltable/__init__.py +21 -4
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +520 -31
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +373 -48
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +113 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +187 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +61 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +88 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +27 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +413 -182
  88. pixeltable/tests/conftest.py +143 -86
  89. pixeltable/tests/test_audio.py +65 -0
  90. pixeltable/tests/test_catalog.py +27 -0
  91. pixeltable/tests/test_client.py +14 -14
  92. pixeltable/tests/test_component_view.py +372 -0
  93. pixeltable/tests/test_dataframe.py +433 -0
  94. pixeltable/tests/test_dirs.py +78 -62
  95. pixeltable/tests/test_document.py +117 -0
  96. pixeltable/tests/test_exprs.py +591 -135
  97. pixeltable/tests/test_function.py +297 -67
  98. pixeltable/tests/test_functions.py +283 -1
  99. pixeltable/tests/test_migration.py +43 -0
  100. pixeltable/tests/test_nos.py +54 -0
  101. pixeltable/tests/test_snapshot.py +208 -0
  102. pixeltable/tests/test_table.py +1086 -258
  103. pixeltable/tests/test_transactional_directory.py +42 -0
  104. pixeltable/tests/test_types.py +5 -11
  105. pixeltable/tests/test_video.py +149 -34
  106. pixeltable/tests/test_view.py +530 -0
  107. pixeltable/tests/utils.py +186 -45
  108. pixeltable/tool/create_test_db_dump.py +149 -0
  109. pixeltable/type_system.py +490 -133
  110. pixeltable/utils/__init__.py +17 -46
  111. pixeltable/utils/clip.py +12 -15
  112. pixeltable/utils/coco.py +136 -0
  113. pixeltable/utils/documents.py +39 -0
  114. pixeltable/utils/filecache.py +195 -0
  115. pixeltable/utils/help.py +11 -0
  116. pixeltable/utils/media_store.py +76 -0
  117. pixeltable/utils/parquet.py +126 -0
  118. pixeltable/utils/pytorch.py +172 -0
  119. pixeltable/utils/s3.py +13 -0
  120. pixeltable/utils/sql.py +17 -0
  121. pixeltable/utils/transactional_directory.py +35 -0
  122. pixeltable-0.2.1.dist-info/LICENSE +18 -0
  123. pixeltable-0.2.1.dist-info/METADATA +119 -0
  124. pixeltable-0.2.1.dist-info/RECORD +125 -0
  125. {pixeltable-0.1.2.dist-info → pixeltable-0.2.1.dist-info}/WHEEL +1 -1
  126. pixeltable/catalog.py +0 -1421
  127. pixeltable/exprs.py +0 -1745
  128. pixeltable/function.py +0 -269
  129. pixeltable/functions/clip.py +0 -10
  130. pixeltable/functions/pil/__init__.py +0 -23
  131. pixeltable/functions/tf.py +0 -21
  132. pixeltable/index.py +0 -57
  133. pixeltable/tests/test_dict.py +0 -24
  134. pixeltable/tests/test_tf.py +0 -69
  135. pixeltable/tf.py +0 -33
  136. pixeltable/utils/tf.py +0 -33
  137. pixeltable/utils/video.py +0 -32
  138. pixeltable-0.1.2.dist-info/LICENSE +0 -201
  139. pixeltable-0.1.2.dist-info/METADATA +0 -89
  140. pixeltable-0.1.2.dist-info/RECORD +0 -37
pixeltable/store.py CHANGED
@@ -1,191 +1,422 @@
-import enum
-import platform
-
+from __future__ import annotations
+
+import os
+import sys
+import warnings
+from typing import Optional, Dict, Any, List, Tuple, Set
+import logging
+import urllib
 import sqlalchemy as sql
-from sqlalchemy import Integer, String, Enum, Boolean, TIMESTAMP, BigInteger, LargeBinary, JSON
-from sqlalchemy import ForeignKey, UniqueConstraint, ForeignKeyConstraint
-from sqlalchemy.orm import declarative_base
-
-from pixeltable import type_system as pt_types
-
-Base = declarative_base()
-
-
-class Db(Base):
-    __tablename__ = 'dbs'
-
-    id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-    name = sql.Column(String, nullable=False)
-
-
-class Dir(Base):
-    __tablename__ = 'dirs'
-
-    id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-    db_id = sql.Column(Integer, ForeignKey('dbs.id'), nullable=False)
-    path = sql.Column(String, nullable=False)
-    is_snapshot = sql.Column(Boolean, nullable=False)
-
-
-class Table(Base):
-    __tablename__ = 'tables'
-
-    MAX_VERSION = 9223372036854775807  # 2^63 - 1
-
-    id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-    db_id = sql.Column(Integer, ForeignKey('dbs.id'), nullable=False)
-    dir_id = sql.Column(Integer, ForeignKey('dirs.id'), nullable=False)
-    name = sql.Column(String, nullable=False)
-    parameters = sql.Column(JSON, nullable=False)
+from tqdm import tqdm, TqdmWarning
+import abc

-    # monotonically increasing w/in Table for both data and schema changes, starting at 0
-    current_version = sql.Column(BigInteger, nullable=False)
-    # each version has a corresponding schema version (current_version >= current_schema_version)
-    current_schema_version = sql.Column(BigInteger, nullable=False)
+import pixeltable.catalog as catalog
+from pixeltable.metadata import schema
+from pixeltable.type_system import StringType
+from pixeltable.exec import ExecNode
+from pixeltable import exprs
+from pixeltable.utils.sql import log_stmt, log_explain
+import pixeltable.env as env
+from pixeltable.utils.media_store import MediaStore

-    # if False, can't apply schema or data changes to this table
-    # (table got dropped, but we need to keep a record of it for snapshots)
-    is_mutable = sql.Column(Boolean, nullable=False)

-    next_col_id = sql.Column(Integer, nullable=False)  # used to assign Column.id
+_logger = logging.getLogger('pixeltable')

-    # - used to assign the rowid column in the storage table
-    # - every row is assigned a unique and immutable rowid on insertion
-    next_row_id = sql.Column(BigInteger, nullable=False)

-    __table_args__ = (
-        #ForeignKeyConstraint(
-        #    ['id', 'current_schema_version'], ['tableschemaversions.tbl_id', 'tableschemaversions.schema_version']),
-    )
+class StoreBase:
+    """Base class for stored tables

-    def storage_name(self) -> str:
-        return f'tbl_{self.db_id}_{self.id}'
-
-
-# versioning: each table schema change results in a new record
-class TableSchemaVersion(Base):
-    __tablename__ = 'tableschemaversions'
-
-    tbl_id = sql.Column(Integer, ForeignKey('tables.id'), primary_key=True, nullable=False)
-    schema_version = sql.Column(BigInteger, primary_key=True, nullable=False)
-    preceding_schema_version = sql.Column(BigInteger, nullable=False)
-
-
-# - records the physical schema of a table, ie, the columns that are actually stored
-# - versioning information is needed to GC unreachable storage columns
-# - one record per column (across all schema versions)
-class StorageColumn(Base):
-    __tablename__ = 'storagecolumns'
-
-    tbl_id = sql.Column(Integer, ForeignKey('tables.id'), primary_key=True, nullable=False)
-    # immutable and monotonically increasing from 0 w/in Table
-    col_id = sql.Column(Integer, primary_key=True, nullable=False)
-    # table schema version when col was added
-    schema_version_add = sql.Column(BigInteger, nullable=False)
-    # table schema version when col was dropped
-    schema_version_drop = sql.Column(BigInteger, nullable=True)
-
-    __table_args__ = (
-        ForeignKeyConstraint(
-            ['tbl_id', 'schema_version_add'], ['tableschemaversions.tbl_id', 'tableschemaversions.schema_version']),
-        ForeignKeyConstraint(
-            ['tbl_id', 'schema_version_drop'], ['tableschemaversions.tbl_id', 'tableschemaversions.schema_version'])
-    )
-
-
-# - records the logical (user-visible) schema of a table
-# - contains the full set of columns for each new schema version: one record per column x schema version
-class SchemaColumn(Base):
-    __tablename__ = 'schemacolumns'
-
-    tbl_id = sql.Column(Integer, ForeignKey('tables.id'), primary_key=True, nullable=False)
-    schema_version = sql.Column(BigInteger, primary_key=True, nullable=False)
-    # immutable and monotonically increasing from 0 w/in Table
-    col_id = sql.Column(Integer, primary_key=True, nullable=False)
-    pos = sql.Column(Integer, nullable=False)  # position in table, starting at 0
-    name = sql.Column(String, nullable=False)
-    col_type = sql.Column(String, nullable=False)  # json
-    is_nullable = sql.Column(Boolean, nullable=False)
-    is_pk = sql.Column(Boolean, nullable=False)
-    value_expr = sql.Column(String, nullable=True)  # json
-    # if True, creates vector index for this column
-    is_indexed = sql.Column(Boolean, nullable=False)
-
-    __table_args__ = (
-        UniqueConstraint('tbl_id', 'schema_version', 'pos'),
-        UniqueConstraint('tbl_id', 'schema_version', 'name'),
-        ForeignKeyConstraint(['tbl_id', 'col_id'], ['storagecolumns.tbl_id', 'storagecolumns.col_id']),
-        ForeignKeyConstraint(
-            ['tbl_id', 'schema_version'], ['tableschemaversions.tbl_id', 'tableschemaversions.schema_version'])
-    )
-
-
-class TableSnapshot(Base):
-    __tablename__ = 'tablesnapshots'
-
-    id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-    db_id = sql.Column(Integer, ForeignKey('dbs.id'), nullable=False)
-    dir_id = sql.Column(Integer, ForeignKey('dirs.id'), nullable=False)
-    name = sql.Column(String, nullable=False)
-    tbl_id = sql.Column(Integer, nullable=False)
-    tbl_version = sql.Column(BigInteger, nullable=False)
-    tbl_schema_version = sql.Column(BigInteger, nullable=False)
-
-    __table_args__ = (
-        ForeignKeyConstraint(['tbl_id'], ['tables.id']),
-        ForeignKeyConstraint(
-            ['tbl_id', 'tbl_schema_version'], ['tableschemaversions.tbl_id', 'tableschemaversions.schema_version']),
-    )
-
-
-class Function(Base):
+    Each row has the following system columns:
+    - rowid columns: one or more columns that identify a user-visible row across all versions
+    - v_min: version at which the row was created
+    - v_max: version at which the row was deleted (or MAX_VERSION if it's still live)
     """
-    User-defined functions that are not library functions (ie, aren't available at runtime as a symbol in a known
-    module).
-    Functions without a name are anonymous functions used in the definition of a computed column.
-    Functions that have names are also assigned to a database and directory.
-    We store the Python version under which a Function was created (and the callable pickled) in order to warn
-    against version mismatches.
+
+    def __init__(self, tbl_version: catalog.TableVersion):
+        self.tbl_version = tbl_version
+        self.sa_md = sql.MetaData()
+        self.sa_tbl: Optional[sql.Table] = None
+        self._create_sa_tbl()
+
+    def pk_columns(self) -> List[sql.Column]:
+        return self._pk_columns
+
+    def rowid_columns(self) -> List[sql.Column]:
+        return self._pk_columns[:-1]
+
+    @abc.abstractmethod
+    def _create_rowid_columns(self) -> List[sql.Column]:
+        """Create and return rowid columns"""
+        pass
+
+    @abc.abstractmethod
+    def _create_system_columns(self) -> List[sql.Column]:
+        """Create and return system columns"""
+        rowid_cols = self._create_rowid_columns()
+        self.v_min_col = sql.Column('v_min', sql.BigInteger, nullable=False)
+        self.v_max_col = \
+            sql.Column('v_max', sql.BigInteger, nullable=False, server_default=str(schema.Table.MAX_VERSION))
+        self._pk_columns = [*rowid_cols, self.v_min_col]
+        return [*rowid_cols, self.v_min_col, self.v_max_col]
+
+
+    def _create_sa_tbl(self) -> None:
+        """Create self.sa_tbl from self.tbl_version."""
+        system_cols = self._create_system_columns()
+        all_cols = system_cols.copy()
+        idxs: List[sql.Index] = []
+        for col in [c for c in self.tbl_version.cols if c.is_stored]:
+            # re-create sql.Column for each column, regardless of whether it already has sa_col set: it was bound
+            # to the last sql.Table version we created and cannot be reused
+            col.create_sa_cols()
+            all_cols.append(col.sa_col)
+            if col.records_errors:
+                all_cols.append(col.sa_errormsg_col)
+                all_cols.append(col.sa_errortype_col)
+
+            if col.is_indexed:
+                all_cols.append(col.sa_idx_col)
+
+            # we create an index for:
+            # - scalar columns (except for strings, because long strings can't be used for B-tree indices)
+            # - non-computed video and image columns (they will contain external paths/urls that users might want to
+            #   filter on)
+            if (col.col_type.is_scalar_type() and not col.col_type.is_string_type()) \
+                    or (col.col_type.is_media_type() and not col.is_computed):
+                # index names need to be unique within the Postgres instance
+                idx_name = f'idx_{col.id}_{self.tbl_version.id.hex}'
+                idxs.append(sql.Index(idx_name, col.sa_col))
+
+        if self.sa_tbl is not None:
+            # if we're called in response to a schema change, we need to remove the old table first
+            self.sa_md.remove(self.sa_tbl)
+
+        # index for all system columns:
+        # - base x view joins can be executed as merge joins
+        # - speeds up ORDER BY rowid DESC
+        # - allows filtering for a particular table version in index scan
+        idx_name = f'sys_cols_idx_{self.tbl_version.id.hex}'
+        idxs.append(sql.Index(idx_name, *system_cols))
+        # v_min/v_max indices: speeds up base table scans needed to propagate a base table insert or delete
+        idx_name = f'vmin_idx_{self.tbl_version.id.hex}'
+        idxs.append(sql.Index(idx_name, self.v_min_col, postgresql_using='brin'))
+        idx_name = f'vmax_idx_{self.tbl_version.id.hex}'
+        idxs.append(sql.Index(idx_name, self.v_max_col, postgresql_using='brin'))
+
+        self.sa_tbl = sql.Table(self._storage_name(), self.sa_md, *all_cols, *idxs)
+
+    @abc.abstractmethod
+    def _rowid_join_predicate(self) -> sql.ClauseElement:
+        """Return predicate for rowid joins to all bases"""
+        pass
+
+    @abc.abstractmethod
+    def _storage_name(self) -> str:
+        """Return the name of the data store table"""
+        pass
+
+    def _move_tmp_media_file(self, file_url: Optional[str], col: catalog.Column, v_min: int) -> str:
+        """Move tmp media file with given url to Env.media_dir and return new url, or given url if not a tmp_dir file"""
+        pxt_tmp_dir = str(env.Env.get().tmp_dir)
+        if file_url is None:
+            return None
+        parsed = urllib.parse.urlparse(file_url)
+        if parsed.scheme != '' and parsed.scheme != 'file':
+            # remote url
+            return file_url
+        file_path = urllib.parse.unquote(parsed.path)
+        if not file_path.startswith(pxt_tmp_dir):
+            # not a tmp file
+            return file_url
+        _, ext = os.path.splitext(file_path)
+        new_path = str(MediaStore.prepare_media_path(self.tbl_version.id, col.id, v_min, ext=ext))
+        os.rename(file_path, new_path)
+        new_file_url = urllib.parse.urljoin('file:', urllib.request.pathname2url(new_path))
+        return new_file_url
+
+    def _move_tmp_media_files(
+            self, table_rows: List[Dict[str, Any]], media_cols: List[catalog.Column], v_min: int
+    ) -> None:
+        """Move tmp media files that we generated to a permanent location"""
+        for c in media_cols:
+            for table_row in table_rows:
+                file_url = table_row[c.storage_name()]
+                table_row[c.storage_name()] = self._move_tmp_media_file(file_url, c, v_min)
+
+    def _create_table_row(
+            self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder, media_cols: List[catalog.Column],
+            exc_col_ids: Set[int], v_min: int
+    ) -> Tuple[Dict[str, Any], int]:
+        """Return Tuple[complete table row, # of exceptions] for insert()
+        Creates a row that includes the PK columns, with the values from input_row.pk.
+        Returns:
+            Tuple[complete table row, # of exceptions]
+        """
+        table_row, num_excs = row_builder.create_table_row(input_row, exc_col_ids)
+
+        assert input_row.pk is not None and len(input_row.pk) == len(self._pk_columns)
+        for pk_col, pk_val in zip(self._pk_columns, input_row.pk):
+            if pk_col == self.v_min_col:
+                table_row[pk_col.name] = v_min
+            else:
+                table_row[pk_col.name] = pk_val
+
+        return table_row, num_excs
+
+    def count(self) -> None:
+        """Return the number of rows visible in self.tbl_version"""
+        stmt = sql.select(sql.func.count('*'))\
+            .select_from(self.sa_tbl)\
+            .where(self.v_min_col <= self.tbl_version.version)\
+            .where(self.v_max_col > self.tbl_version.version)
+        with env.Env.get().engine.begin() as conn:
+            result = conn.execute(stmt).scalar_one()
+            assert isinstance(result, int)
+            return result
+
+    def create(self, conn: sql.engine.Connection) -> None:
+        self.sa_md.create_all(bind=conn)
+
+    def drop(self, conn: sql.engine.Connection) -> None:
+        """Drop store table"""
+        self.sa_md.drop_all(bind=conn)
+
+    def add_column(self, col: catalog.Column, conn: sql.engine.Connection) -> None:
+        """Add column(s) to the store-resident table based on a catalog column
+
+        Note that a computed catalog column will require two extra columns (for the computed value and for the error
+        message).
+        """
+        assert col.is_stored
+        stmt = sql.text(f'ALTER TABLE {self._storage_name()} ADD COLUMN {col.storage_name()} {col.col_type.to_sql()}')
+        log_stmt(_logger, stmt)
+        conn.execute(stmt)
+        added_storage_cols = [col.storage_name()]
+        if col.records_errors:
+            # we also need to create the errormsg and errortype storage cols
+            stmt = (f'ALTER TABLE {self._storage_name()} '
+                    f'ADD COLUMN {col.errormsg_storage_name()} {StringType().to_sql()} DEFAULT NULL')
+            conn.execute(sql.text(stmt))
+            stmt = (f'ALTER TABLE {self._storage_name()} '
+                    f'ADD COLUMN {col.errortype_storage_name()} {StringType().to_sql()} DEFAULT NULL')
+            conn.execute(sql.text(stmt))
+            added_storage_cols.extend([col.errormsg_storage_name(), col.errortype_storage_name()])
+        self._create_sa_tbl()
+        _logger.info(f'Added columns {added_storage_cols} to storage table {self._storage_name()}')
+
+    def drop_column(self, col: Optional[catalog.Column] = None, conn: Optional[sql.engine.Connection] = None) -> None:
+        """Re-create self.sa_tbl and drop column, if one is given"""
+        if col is not None:
+            assert conn is not None
+            stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.storage_name()}'
+            conn.execute(sql.text(stmt))
+            if col.records_errors:
+                stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errormsg_storage_name()}'
+                conn.execute(sql.text(stmt))
+                stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errortype_storage_name()}'
+                conn.execute(sql.text(stmt))
+        self._create_sa_tbl()
+
+    def load_column(
+            self, col: catalog.Column, exec_plan: ExecNode, value_expr_slot_idx: int, embedding_slot_idx: int,
+            conn: sql.engine.Connection
+    ) -> int:
+        """Update store column of a computed column with values produced by an execution plan
+
+        Returns:
+            number of rows with exceptions
+        Raises:
+            sql.exc.DBAPIError if there was an error during SQL execution
+        """
+        num_excs = 0
+        num_rows = 0
+        for row_batch in exec_plan:
+            num_rows += len(row_batch)
+            for result_row in row_batch:
+                values_dict: Dict[sql.Column, Any] = {}
+
+                if col.is_computed:
+                    if result_row.has_exc(value_expr_slot_idx):
+                        num_excs += 1
+                        value_exc = result_row.get_exc(value_expr_slot_idx)
+                        # we store a NULL value and record the exception/exc type
+                        error_type = type(value_exc).__name__
+                        error_msg = str(value_exc)
+                        values_dict = {
+                            col.sa_col: None,
+                            col.sa_errortype_col: error_type,
+                            col.sa_errormsg_col: error_msg
+                        }
+                    else:
+                        val = result_row.get_stored_val(value_expr_slot_idx)
+                        if col.col_type.is_media_type():
+                            val = self._move_tmp_media_file(val, col, result_row.pk[-1])
+                        values_dict = {col.sa_col: val}
+
+                if col.is_indexed:
+                    # TODO: deal with exceptions
+                    assert not result_row.has_exc(embedding_slot_idx)
+                    # don't use get_stored_val() here, we need to pass the ndarray
+                    embedding = result_row[embedding_slot_idx]
+                    values_dict[col.sa_index_col] = embedding
+
+                update_stmt = sql.update(self.sa_tbl).values(values_dict)
+                for pk_col, pk_val in zip(self.pk_columns(), result_row.pk):
+                    update_stmt = update_stmt.where(pk_col == pk_val)
+                log_stmt(_logger, update_stmt)
+                conn.execute(update_stmt)
+
+        return num_excs
+
+    def insert_rows(
+            self, exec_plan: ExecNode, conn: sql.engine.Connection, v_min: Optional[int] = None
+    ) -> Tuple[int, int, Set[int]]:
+        """Insert rows into the store table and update the catalog table's md
+        Returns:
+            number of inserted rows, number of exceptions, set of column ids that have exceptions
+        """
+        assert v_min is not None
+        exec_plan.ctx.conn = conn
+        batch_size = 16  # TODO: is this a good batch size?
+        # TODO: total?
+        num_excs = 0
+        num_rows = 0
+        cols_with_excs: Set[int] = set()
+        progress_bar: Optional[tqdm] = None  # create this only after we started executing
+        row_builder = exec_plan.row_builder
+        media_cols = [info.col for info in row_builder.table_columns if info.col.col_type.is_media_type()]
+        try:
+            exec_plan.open()
+            for row_batch in exec_plan:
+                num_rows += len(row_batch)
+                for batch_start_idx in range(0, len(row_batch), batch_size):
+                    # compute batch of rows and convert them into table rows
+                    table_rows: List[Dict[str, Any]] = []
+                    for row_idx in range(batch_start_idx, min(batch_start_idx + batch_size, len(row_batch))):
+                        row = row_batch[row_idx]
+                        table_row, num_row_exc = \
+                            self._create_table_row(row, row_builder, media_cols, cols_with_excs, v_min=v_min)
+                        num_excs += num_row_exc
+                        table_rows.append(table_row)
+                        if progress_bar is None:
+                            warnings.simplefilter("ignore", category=TqdmWarning)
+                            progress_bar = tqdm(
+                                desc=f'Inserting rows into `{self.tbl_version.name}`',
+                                unit=' rows',
+                                ncols=100,
+                                file=sys.stdout
+                            )
+                        progress_bar.update(1)
+                    self._move_tmp_media_files(table_rows, media_cols, v_min)
+                    conn.execute(sql.insert(self.sa_tbl), table_rows)
+            if progress_bar is not None:
+                progress_bar.close()
+            return num_rows, num_excs, cols_with_excs
+        finally:
+            exec_plan.close()
+
+    def _versions_clause(self, versions: List[Optional[int]], match_on_vmin: bool) -> sql.ClauseElement:
+        """Return filter for base versions"""
+        v = versions[0]
+        if v is None:
+            # we're looking at live rows
+            clause = sql.and_(self.v_min_col <= self.tbl_version.version, self.v_max_col == schema.Table.MAX_VERSION)
+        else:
+            # we're looking at a specific version
+            clause = self.v_min_col == v if match_on_vmin else self.v_max_col == v
+        if len(versions) == 1:
+            return clause
+        return sql.and_(clause, self.base._versions_clause(versions[1:], match_on_vmin))
+
+    def delete_rows(
+            self, current_version: int, base_versions: List[Optional[int]], match_on_vmin: bool,
+            where_clause: Optional[sql.ClauseElement], conn: sql.engine.Connection) -> int:
+        """Mark rows as deleted that are live and were created prior to current_version.
+        Args:
+            base_versions: if non-None, join only to base rows that were created at that version,
+                otherwise join to rows that are live in the base's current version (which is distinct from the
+                current_version parameter)
+            match_on_vmin: if True, match exact versions on v_min; if False, match on v_max
+            where_clause: if not None, also apply where_clause
+        Returns:
+            number of deleted rows
+        """
+        where_clause = sql.true() if where_clause is None else where_clause
+        where_clause = sql.and_(
+            self.v_min_col < current_version,
+            self.v_max_col == schema.Table.MAX_VERSION,
+            where_clause)
+        rowid_join_clause = self._rowid_join_predicate()
+        base_versions_clause = sql.true() if len(base_versions) == 0 \
+            else self.base._versions_clause(base_versions, match_on_vmin)
+        stmt = sql.update(self.sa_tbl) \
+            .values({self.v_max_col: current_version}) \
+            .where(where_clause) \
+            .where(rowid_join_clause) \
+            .where(base_versions_clause)
+        log_explain(_logger, stmt, conn)
+        status = conn.execute(stmt)
+        return status.rowcount
+
+
+class StoreTable(StoreBase):
+    def __init__(self, tbl_version: catalog.TableVersion):
+        assert not tbl_version.is_view()
+        super().__init__(tbl_version)
+
+    def _create_rowid_columns(self) -> List[sql.Column]:
+        self.rowid_col = sql.Column('rowid', sql.BigInteger, nullable=False)
+        return [self.rowid_col]
+
+    def _storage_name(self) -> str:
+        return f'tbl_{self.tbl_version.id.hex}'
+
+    def _rowid_join_predicate(self) -> sql.ClauseElement:
+        return sql.true()
+
+
+class StoreView(StoreBase):
+    def __init__(self, catalog_view: catalog.TableVersion):
+        assert catalog_view.is_view()
+        self.base = catalog_view.base.store_tbl
+        super().__init__(catalog_view)
+
+    def _create_rowid_columns(self) -> List[sql.Column]:
+        # a view row corresponds directly to a single base row, which means it needs to duplicate its rowid columns
+        self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
+        return self.rowid_cols
+
+    def _storage_name(self) -> str:
+        return f'view_{self.tbl_version.id.hex}'
+
+    def _rowid_join_predicate(self) -> sql.ClauseElement:
+        return sql.and_(
+            self.base._rowid_join_predicate(),
+            *[c1 == c2 for c1, c2 in zip(self.rowid_columns(), self.base.rowid_columns())])
+
+class StoreComponentView(StoreView):
+    """A view that stores components of its base, as produced by a ComponentIterator
+
+    PK: now also includes pos, the position returned by the ComponentIterator for the base row identified by base_rowid
     """
-    __tablename__ = 'functions'
-
-    id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-    db_id = sql.Column(Integer, ForeignKey('dbs.id'), nullable=True)
-    dir_id = sql.Column(Integer, ForeignKey('dirs.id'), nullable=True)
-    name = sql.Column(String, nullable=True)
-    return_type = sql.Column(String, nullable=False)  # json
-    param_types = sql.Column(String, nullable=False)  # json
-    eval_obj = sql.Column(LargeBinary, nullable=True)  # Function.eval_fn
-    init_obj = sql.Column(LargeBinary, nullable=True)  # AggregateFunction.init_fn
-    update_obj = sql.Column(LargeBinary, nullable=True)  # AggregateFunction.update_fn
-    value_obj = sql.Column(LargeBinary, nullable=True)  # AggregateFunction.value_fn
-    python_version = sql.Column(
-        String, nullable=False, default=platform.python_version, onupdate=platform.python_version)
-
-
-class OpCodes(enum.Enum):
-    CREATE_DB = 1
-    RENAME_DB = 1
-    DROP_DB = 1
-    CREATE_TABLE = 1
-    RENAME_TABLE = 1
-    DROP_TABLE = 1
-    ADD_COLUMN = 1
-    RENAME_COLUMN = 1
-    DROP_COLUMN = 1
-    CREATE_DIR = 1
-    DROP_DIR = 1
-    CREATE_SNAPSHOT = 1
-    DROP_SNAPSHOT = 1
-
-
-class Operation(Base):
-    __tablename__ = 'oplog'
-
-    id = sql.Column(Integer, primary_key=True, autoincrement=True, nullable=False)
-    ts = sql.Column(TIMESTAMP, nullable=False)
-    # encodes db_id, schema object type (table, view, table/view snapshot), table/view/... id
-    schema_object = sql.Column(BigInteger, nullable=False)
-    opcode = sql.Column(Enum(OpCodes), nullable=False)
-    # operation-specific details; json
-    details = sql.Column(String, nullable=False)
+    def __init__(self, catalog_view: catalog.TableVersion):
+        super().__init__(catalog_view)
+
+    def _create_rowid_columns(self) -> List[sql.Column]:
+        # each base row is expanded into n view rows
+        self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
+        # name of pos column: avoid collisions with bases' pos columns
+        self.pos_col = sql.Column(f'pos_{len(self.rowid_cols) - 1}', sql.BigInteger, nullable=False)
+        self.pos_col_idx = len(self.rowid_cols)
+        self.rowid_cols.append(self.pos_col)
+        return self.rowid_cols
+
+    def _create_sa_tbl(self) -> None:
+        super()._create_sa_tbl()
+        # we need to fix up the 'pos' column in TableVersion
+        self.tbl_version.cols_by_name['pos'].sa_col = self.pos_col
+
+    def _rowid_join_predicate(self) -> sql.ClauseElement:
+        return sql.and_(
+            self.base._rowid_join_predicate(),
+            *[c1 == c2 for c1, c2 in zip(self.rowid_columns()[:-1], self.base.rowid_columns())])
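
The heart of the new store.py is the [v_min, v_max) visibility window that StoreBase maintains for every row: inserts stamp v_min, and delete_rows() "deletes" by closing the window rather than removing data, which is what makes earlier table versions cheap to retain. The following is an illustrative sketch only, using invented table and column names rather than anything from the released code:

# Toy model of the v_min/v_max scheme, written against SQLAlchemy Core.
# 'demo_tbl' and its columns are hypothetical; only the predicate shapes
# mirror what StoreBase builds.
import sqlalchemy as sql

MAX_VERSION = 9223372036854775807  # 2^63 - 1, the same sentinel as schema.Table.MAX_VERSION

md = sql.MetaData()
demo = sql.Table(
    'demo_tbl', md,
    sql.Column('rowid', sql.BigInteger, nullable=False),
    sql.Column('v_min', sql.BigInteger, nullable=False),
    sql.Column('v_max', sql.BigInteger, nullable=False, server_default=str(MAX_VERSION)),
    sql.Column('val', sql.String),
)

def visible_at(version: int):
    # a row is visible at `version` iff it was created at or before `version`
    # and deleted after it: v_min <= version < v_max
    return sql.and_(demo.c.v_min <= version, demo.c.v_max > version)

# read the table as of version 5 (the filter StoreBase.count() applies
# at self.tbl_version.version)
count_stmt = sql.select(sql.func.count('*')).select_from(demo).where(visible_at(5))

# a "delete" at version 6 removes nothing; it closes the window of live rows,
# mirroring the update built by StoreBase.delete_rows()
delete_stmt = (
    sql.update(demo)
    .values({demo.c.v_max: 6})
    .where(demo.c.v_min < 6)
    .where(demo.c.v_max == MAX_VERSION)
)

Because every read and delete reduces to range predicates on v_min/v_max, the released code also builds BRIN indices on those two columns (see _create_sa_tbl() above) to keep these scans cheap on large tables.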