pixeltable 0.2.21__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (82)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/__init__.py +1 -1
  3. pixeltable/catalog/column.py +37 -11
  4. pixeltable/catalog/globals.py +18 -0
  5. pixeltable/catalog/insertable_table.py +6 -4
  6. pixeltable/catalog/table.py +19 -3
  7. pixeltable/catalog/table_version.py +34 -14
  8. pixeltable/catalog/view.py +16 -17
  9. pixeltable/dataframe.py +7 -8
  10. pixeltable/env.py +5 -0
  11. pixeltable/exec/__init__.py +0 -1
  12. pixeltable/exec/aggregation_node.py +6 -3
  13. pixeltable/exec/cache_prefetch_node.py +1 -1
  14. pixeltable/exec/data_row_batch.py +2 -19
  15. pixeltable/exec/exec_node.py +2 -1
  16. pixeltable/exec/expr_eval_node.py +17 -10
  17. pixeltable/exec/in_memory_data_node.py +6 -3
  18. pixeltable/exec/sql_node.py +24 -25
  19. pixeltable/exprs/arithmetic_expr.py +3 -1
  20. pixeltable/exprs/array_slice.py +7 -7
  21. pixeltable/exprs/column_property_ref.py +37 -10
  22. pixeltable/exprs/column_ref.py +93 -14
  23. pixeltable/exprs/comparison.py +5 -5
  24. pixeltable/exprs/compound_predicate.py +8 -7
  25. pixeltable/exprs/data_row.py +27 -18
  26. pixeltable/exprs/expr.py +53 -52
  27. pixeltable/exprs/expr_set.py +5 -0
  28. pixeltable/exprs/function_call.py +32 -16
  29. pixeltable/exprs/globals.py +4 -1
  30. pixeltable/exprs/in_predicate.py +8 -7
  31. pixeltable/exprs/inline_expr.py +4 -4
  32. pixeltable/exprs/is_null.py +4 -4
  33. pixeltable/exprs/json_mapper.py +11 -12
  34. pixeltable/exprs/json_path.py +5 -10
  35. pixeltable/exprs/literal.py +5 -5
  36. pixeltable/exprs/method_ref.py +5 -4
  37. pixeltable/exprs/object_ref.py +2 -1
  38. pixeltable/exprs/row_builder.py +88 -36
  39. pixeltable/exprs/rowid_ref.py +12 -11
  40. pixeltable/exprs/similarity_expr.py +12 -7
  41. pixeltable/exprs/sql_element_cache.py +7 -5
  42. pixeltable/exprs/type_cast.py +8 -6
  43. pixeltable/exprs/variable.py +5 -4
  44. pixeltable/func/aggregate_function.py +1 -1
  45. pixeltable/func/function.py +11 -10
  46. pixeltable/functions/__init__.py +2 -2
  47. pixeltable/functions/globals.py +5 -7
  48. pixeltable/functions/huggingface.py +19 -20
  49. pixeltable/functions/llama_cpp.py +106 -0
  50. pixeltable/functions/ollama.py +147 -0
  51. pixeltable/functions/replicate.py +72 -0
  52. pixeltable/functions/string.py +9 -0
  53. pixeltable/globals.py +12 -20
  54. pixeltable/index/btree.py +16 -3
  55. pixeltable/index/embedding_index.py +4 -4
  56. pixeltable/io/__init__.py +1 -2
  57. pixeltable/io/fiftyone.py +178 -0
  58. pixeltable/io/globals.py +96 -2
  59. pixeltable/iterators/base.py +3 -2
  60. pixeltable/iterators/document.py +1 -1
  61. pixeltable/iterators/video.py +120 -63
  62. pixeltable/metadata/__init__.py +1 -1
  63. pixeltable/metadata/converters/convert_21.py +34 -0
  64. pixeltable/metadata/converters/util.py +45 -4
  65. pixeltable/metadata/notes.py +1 -0
  66. pixeltable/metadata/schema.py +8 -0
  67. pixeltable/plan.py +16 -14
  68. pixeltable/py.typed +0 -0
  69. pixeltable/store.py +7 -2
  70. pixeltable/tool/create_test_video.py +1 -1
  71. pixeltable/tool/embed_udf.py +1 -1
  72. pixeltable/tool/mypy_plugin.py +28 -5
  73. pixeltable/type_system.py +17 -1
  74. pixeltable/utils/documents.py +15 -1
  75. pixeltable/utils/formatter.py +9 -10
  76. {pixeltable-0.2.21.dist-info → pixeltable-0.2.22.dist-info}/METADATA +46 -10
  77. pixeltable-0.2.22.dist-info/RECORD +153 -0
  78. pixeltable/exec/media_validation_node.py +0 -43
  79. pixeltable-0.2.21.dist-info/RECORD +0 -148
  80. {pixeltable-0.2.21.dist-info → pixeltable-0.2.22.dist-info}/LICENSE +0 -0
  81. {pixeltable-0.2.21.dist-info → pixeltable-0.2.22.dist-info}/WHEEL +0 -0
  82. {pixeltable-0.2.21.dist-info → pixeltable-0.2.22.dist-info}/entry_points.txt +0 -0
pixeltable/iterators/video.py CHANGED
@@ -1,13 +1,15 @@
 import logging
 import math
+from fractions import Fraction
 from pathlib import Path
 from typing import Any, Optional, Sequence
 
-import cv2
+import av # type: ignore[import-untyped]
+import pandas as pd
 import PIL.Image
 
-from pixeltable.exceptions import Error
-from pixeltable.type_system import ColumnType, FloatType, ImageType, IntType, VideoType
+import pixeltable.exceptions as excs
+import pixeltable.type_system as ts
 
 from .base import ComponentIterator
 
@@ -30,108 +32,163 @@ class FrameIterator(ComponentIterator):
     `num_frames` is greater than the number of frames in the video, all frames will be extracted.
     """
 
+    # Input parameters
     video_path: Path
-    video_reader: cv2.VideoCapture
     fps: Optional[float]
     num_frames: Optional[int]
-    frames_to_extract: Sequence[int]
-    frames_set: set[int]
-    next_frame_idx: int
+
+    # Video info
+    container: av.container.input.InputContainer
+    video_framerate: Fraction
+    video_time_base: Fraction
+    video_frame_count: int
+    video_start_time: int
+
+    # List of frame indices to be extracted, or None to extract all frames
+    frames_to_extract: Optional[list[int]]
+
+    # Next frame to extract, as an iterator `pos` index. If `frames_to_extract` is None, this is the same as the
+    # frame index in the video. Otherwise, the corresponding video index is `frames_to_extract[next_pos]`.
+    next_pos: int
 
     def __init__(self, video: str, *, fps: Optional[float] = None, num_frames: Optional[int] = None):
         if fps is not None and num_frames is not None:
-            raise Error('At most one of `fps` or `num_frames` may be specified')
+            raise excs.Error('At most one of `fps` or `num_frames` may be specified')
 
         video_path = Path(video)
         assert video_path.exists() and video_path.is_file()
         self.video_path = video_path
-        self.video_reader = cv2.VideoCapture(str(video_path))
+        self.container = av.open(str(video_path))
         self.fps = fps
         self.num_frames = num_frames
-        if not self.video_reader.isOpened():
-            raise Error(f'Failed to open video: {video}')
 
-        video_fps = int(self.video_reader.get(cv2.CAP_PROP_FPS))
-        if fps is not None and fps > video_fps:
-            raise Error(f'Video {video}: requested fps ({fps}) exceeds that of the video ({video_fps})')
-        num_video_frames = int(self.video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
-        if num_video_frames == 0:
-            raise Error(f'Video {video}: failed to get number of frames')
+        self.video_framerate = self.container.streams.video[0].average_rate
+        self.video_time_base = self.container.streams.video[0].time_base
+        self.video_start_time = self.container.streams.video[0].start_time or 0
+
+        # Determine the number of frames in the video
+        self.video_frame_count = self.container.streams.video[0].frames
+        if self.video_frame_count == 0:
+            # The video codec does not provide a frame count in the standard `frames` field. Try some other methods.
+            metadata: dict = self.container.streams.video[0].metadata
+            if 'NUMBER_OF_FRAMES' in metadata:
+                self.video_frame_count = int(metadata['NUMBER_OF_FRAMES'])
+            elif 'DURATION' in metadata:
+                # As a last resort, calculate the frame count from the stream duration.
+                duration = metadata['DURATION']
+                assert isinstance(duration, str)
+                seconds = pd.to_timedelta(duration).total_seconds()
+                # Usually the duration and framerate are precise enough for this calculation to be accurate, but if
+                # we encounter a case where it's off by one due to a rounding error, that's ok; we only use this
+                # to determine the positions of the sampled frames when `fps` or `num_frames` is specified.
+                self.video_frame_count = round(seconds * self.video_framerate)
+            else:
+                raise excs.Error(f'Video {video}: failed to get number of frames')
 
         if num_frames is not None:
             # specific number of frames
-            if num_frames > num_video_frames:
+            if num_frames > self.video_frame_count:
                 # Extract all frames
-                self.frames_to_extract = range(num_video_frames)
+                self.frames_to_extract = None
             else:
-                spacing = float(num_video_frames) / float(num_frames)
+                spacing = float(self.video_frame_count) / float(num_frames)
                 self.frames_to_extract = list(round(i * spacing) for i in range(num_frames))
                 assert len(self.frames_to_extract) == num_frames
         else:
             if fps is None or fps == 0.0:
                 # Extract all frames
-                self.frames_to_extract = range(num_video_frames)
+                self.frames_to_extract = None
+            elif fps > float(self.video_framerate):
+                raise excs.Error(
+                    f'Video {video}: requested fps ({fps}) exceeds that of the video ({float(self.video_framerate)})'
+                )
             else:
                 # Extract frames at the implied frequency
-                freq = fps / video_fps
-                n = math.ceil(num_video_frames * freq) # number of frames to extract
+                freq = fps / float(self.video_framerate)
+                n = math.ceil(self.video_frame_count * freq) # number of frames to extract
                 self.frames_to_extract = list(round(i / freq) for i in range(n))
 
-        # We need the list of frames as both a list (for set_pos) and a set (for fast lookups when
-        # there are lots of frames)
-        self.frames_set = set(self.frames_to_extract)
         _logger.debug(f'FrameIterator: path={self.video_path} fps={self.fps} num_frames={self.num_frames}')
-        self.next_frame_idx = 0
+        self.next_pos = 0
 
     @classmethod
-    def input_schema(cls) -> dict[str, ColumnType]:
+    def input_schema(cls) -> dict[str, ts.ColumnType]:
         return {
-            'video': VideoType(nullable=False),
-            'fps': FloatType(nullable=True),
-            'num_frames': IntType(nullable=True),
+            'video': ts.VideoType(nullable=False),
+            'fps': ts.FloatType(nullable=True),
+            'num_frames': ts.IntType(nullable=True),
         }
 
     @classmethod
-    def output_schema(cls, *args: Any, **kwargs: Any) -> tuple[dict[str, ColumnType], list[str]]:
+    def output_schema(cls, *args: Any, **kwargs: Any) -> tuple[dict[str, ts.ColumnType], list[str]]:
         return {
-            'frame_idx': IntType(),
-            'pos_msec': FloatType(),
-            'pos_frame': FloatType(),
-            'frame': ImageType(),
+            'frame_idx': ts.IntType(),
+            'pos_msec': ts.FloatType(),
+            'pos_frame': ts.IntType(),
+            'frame': ts.ImageType(),
         }, ['frame']
 
     def __next__(self) -> dict[str, Any]:
-        # jumping to the target frame here with video_reader.set() is far slower than just
-        # skipping the unwanted frames
+        # Determine the frame index in the video corresponding to the iterator index `next_pos`;
+        # the frame at this index is the one we want to extract next
+        if self.frames_to_extract is None:
+            next_video_idx = self.next_pos # we're extracting all frames
+        elif self.next_pos >= len(self.frames_to_extract):
+            raise StopIteration
+        else:
+            next_video_idx = self.frames_to_extract[self.next_pos]
+
+        # We are searching for the frame at the index implied by `next_pos`. Step through the video until we
+        # find it. There are two reasons why it might not be the immediate next frame in the video:
+        # (1) `fps` or `num_frames` was specified as an iterator argument; or
+        # (2) we just did a seek, and the desired frame is not a keyframe.
+        # TODO: In case (1) it will usually be fastest to step through the frames until we find the one we're
+        #     looking for. But in some cases it may be faster to do a seek; for example, when `fps` is very
+        #     low and there are multiple keyframes in between each frame we want to extract (imagine extracting
+        #     10 frames from an hourlong video).
         while True:
-            pos_msec = self.video_reader.get(cv2.CAP_PROP_POS_MSEC)
-            pos_frame = self.video_reader.get(cv2.CAP_PROP_POS_FRAMES)
-            status, img = self.video_reader.read()
-            if not status:
-                _logger.debug(f'releasing video reader for {self.video_path}')
-                self.video_reader.release()
-                self.video_reader = None
+            try:
+                frame = next(self.container.decode(video=0))
+            except EOFError:
                 raise StopIteration
-            if pos_frame in self.frames_set:
-                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-                result = {
-                    'frame_idx': self.next_frame_idx,
-                    'pos_msec': pos_msec,
-                    'pos_frame': pos_frame,
-                    'frame': PIL.Image.fromarray(img),
-                }
-                self.next_frame_idx += 1
-                return result
+            # Compute the index of the current frame in the video based on the presentation timestamp (pts);
+            # this ensures we have a canonical understanding of frame index, regardless of how we got here
+            # (seek or iteration)
+            pts = frame.pts - self.video_start_time
+            video_idx = round(pts * self.video_time_base * self.video_framerate)
+            assert isinstance(video_idx, int)
+            if video_idx < next_video_idx:
+                # We haven't reached the desired frame yet
+                continue
+
+            # Sanity check that we're at the right frame.
+            if video_idx != next_video_idx:
+                raise excs.Error(f'Frame {next_video_idx} is missing from the video (video file is corrupt)')
+            img = frame.to_image()
+            assert isinstance(img, PIL.Image.Image)
+            pos_msec = float(pts * self.video_time_base * 1000)
+            result = {
+                'frame_idx': self.next_pos,
+                'pos_msec': pos_msec,
+                'pos_frame': video_idx,
+                'frame': img,
+            }
+            self.next_pos += 1
+            return result
 
     def close(self) -> None:
-        if self.video_reader is not None:
-            self.video_reader.release()
-            self.video_reader = None
+        self.container.close()
 
     def set_pos(self, pos: int) -> None:
         """Seek to frame idx"""
-        if pos == self.next_frame_idx:
-            return
-        _logger.debug(f'seeking to frame {pos}')
-        self.video_reader.set(cv2.CAP_PROP_POS_FRAMES, self.frames_to_extract[pos])
-        self.next_frame_idx = pos
+        if pos == self.next_pos:
+            return # already there
+
+        video_idx = pos if self.frames_to_extract is None else self.frames_to_extract[pos]
+        _logger.debug(f'seeking to frame number {video_idx} (at iterator index {pos})')
+        # compute the frame position in time_base units
+        seek_pos = int(video_idx / self.video_framerate / self.video_time_base + self.video_start_time)
+        # This will seek to the nearest keyframe before the desired frame. If the frame being sought is not a keyframe,
+        # then the iterator will step forward to the desired frame on the subsequent call to next().
+        self.container.seek(seek_pos, backward=True, stream=self.container.streams.video[0])
+        self.next_pos = pos
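The pts-to-frame-index conversion used in `__next__` and `set_pos` above is plain unit arithmetic: a frame's presentation timestamp (minus the stream's start time) is measured in `time_base` ticks, so multiplying by `time_base` yields seconds and multiplying by the average frame rate yields a frame number; the seek target is the same expression inverted. A minimal standalone sketch, with stream parameters that are purely illustrative (not taken from this diff):

from fractions import Fraction

framerate = Fraction(30000, 1001)  # ~29.97 fps, an assumed example value
time_base = Fraction(1, 90000)     # an assumed example tick size
start_time = 0

# forward direction, as in __next__: pts -> frame index
pts = 3003  # pts of the second frame in this hypothetical stream
video_idx = round((pts - start_time) * time_base * framerate)
assert video_idx == 1

# inverse direction, as in set_pos: frame index -> pts to seek to
seek_pos = int(video_idx / framerate / time_base + start_time)
assert seek_pos == 3003

Because `average_rate` is usually a non-integer Fraction (29.97 above), the forward conversion rounds rather than truncates, which is what lets the iterator treat the result as a canonical frame index whether it reached a frame by seeking or by sequential decoding.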
pixeltable/metadata/__init__.py CHANGED
@@ -10,7 +10,7 @@ import sqlalchemy.orm as orm
 from .schema import SystemInfo, SystemInfoMd
 
 # current version of the metadata; this is incremented whenever the metadata schema changes
-VERSION = 21
+VERSION = 22
 
 
 def create_system_info(engine: sql.engine.Engine) -> None:
pixeltable/metadata/converters/convert_21.py ADDED
@@ -0,0 +1,34 @@
+from typing import Any, Optional
+import sqlalchemy as sql
+
+from pixeltable.metadata import register_converter
+from pixeltable.metadata.converters.util import convert_table_schema_version_md, convert_table_md
+
+
+@register_converter(version=21)
+def _(engine: sql.engine.Engine) -> None:
+    convert_table_schema_version_md(
+        engine,
+        table_schema_version_md_updater=__update_table_schema_version,
+        schema_column_updater=__update_schema_column
+    )
+    convert_table_md(
+        engine,
+        substitution_fn=__substitute_md
+    )
+
+
+def __update_table_schema_version(table_schema_version_md: dict) -> None:
+    table_schema_version_md['media_validation'] = 'on_write' # MediaValidation.ON_WRITE
+
+
+def __update_schema_column(schema_column: dict) -> None:
+    schema_column['media_validation'] = None
+
+
+def __substitute_md(k: Optional[str], v: Any) -> Optional[tuple[Optional[str], Any]]:
+    if isinstance(v, dict) and '_classname' in v and v['_classname'] == 'ColumnRef':
+        if 'perform_validation' not in v:
+            v['perform_validation'] = False
+        return k, v
+    return None
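For reference, the only effect of `__substitute_md` is to backfill the new `perform_validation` field on serialized `ColumnRef` exprs and to leave everything else untouched. A small sketch of that behavior on a hypothetical metadata fragment (the field values below are invented for illustration):

# hypothetical v21 expr metadata; only '_classname' matters to the converter
column_ref_md = {'_classname': 'ColumnRef', 'col_id': 7}
assert __substitute_md(None, column_ref_md) == (None, {'_classname': 'ColumnRef', 'col_id': 7, 'perform_validation': False})

# non-ColumnRef entries are not substituted: the converter returns None for them
assert __substitute_md('select_list', {'_classname': 'Literal'}) is None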
pixeltable/metadata/converters/util.py CHANGED
@@ -4,7 +4,7 @@ from typing import Any, Callable, Optional
 
 import sqlalchemy as sql
 
-from pixeltable.metadata.schema import Table
+from pixeltable.metadata.schema import Table, TableSchemaVersion
 
 __logger = logging.getLogger('pixeltable')
 
@@ -17,12 +17,12 @@ def convert_table_md(
     substitution_fn: Optional[Callable[[Optional[str], Any], Optional[tuple[Optional[str], Any]]]] = None
 ) -> None:
     """
-    Converts table metadata based on the specified conversion functions.
+    Converts schema.TableMd dicts based on the specified conversion functions.
 
     Args:
         engine: The SQLAlchemy engine.
-        table_md_updater: A function that updates the table metadata in place.
-        column_md_updater: A function that updates the column metadata in place.
+        table_md_updater: A function that updates schema.TableMd dicts in place.
+        column_md_updater: A function that updates schema.ColumnMd dicts in place.
         external_store_md_updater: A function that updates the external store metadata in place.
         substitution_fn: A function that substitutes metadata values. If specified, all metadata will be traversed
            recursively, and `substitution_fn` will be called once for each metadata entry. If the entry appears in
@@ -90,3 +90,44 @@ def __substitute_md_rec(
         return updated_list
     else:
         return md
+
+
+def convert_table_schema_version_md(
+    engine: sql.engine.Engine,
+    table_schema_version_md_updater: Optional[Callable[[dict], None]] = None,
+    schema_column_updater: Optional[Callable[[dict], None]] = None
+) -> None:
+    """
+    Converts schema.TableSchemaVersionMd dicts based on the specified conversion functions.
+
+    Args:
+        engine: The SQLAlchemy engine.
+        table_schema_version_md_updater: A function that updates schema.TableSchemaVersionMd dicts in place.
+        schema_column_updater: A function that updates schema.SchemaColumn dicts in place.
+    """
+    with engine.begin() as conn:
+        stmt = sql.select(TableSchemaVersion.tbl_id, TableSchemaVersion.schema_version, TableSchemaVersion.md)
+        for row in conn.execute(stmt):
+            tbl_id, schema_version, md = row[0], row[1], row[2]
+            assert isinstance(md, dict)
+            updated_md = copy.deepcopy(md)
+            if table_schema_version_md_updater is not None:
+                table_schema_version_md_updater(updated_md)
+            if schema_column_updater is not None:
+                __update_schema_column(updated_md, schema_column_updater)
+            if updated_md != md:
+                __logger.info(f'Updating TableSchemaVersion(tbl_id={tbl_id}, schema_version={schema_version})')
+                update_stmt = (
+                    sql.update(TableSchemaVersion)
+                    .where(TableSchemaVersion.tbl_id == tbl_id)
+                    .where(TableSchemaVersion.schema_version == schema_version)
+                    .values(md=updated_md)
+                )
+                conn.execute(update_stmt)
+
+
+def __update_schema_column(table_schema_version_md: dict, schema_column_updater: Callable[[dict], None]) -> None:
+    cols = table_schema_version_md['columns']
+    assert isinstance(cols, dict)
+    for schema_col in cols.values():
+        schema_column_updater(schema_col)
pixeltable/metadata/notes.py CHANGED
@@ -2,6 +2,7 @@
 # rather than as a comment, so that the existence of a description can be enforced by
 # the unit tests when new versions are added.
 VERSION_NOTES = {
+    22: 'TableMd/ColumnMd.media_validation',
     21: 'Separate InlineArray and InlineList',
     20: 'Store DB timestamps in UTC',
     19: 'UDF renames; ImageMemberAccess removal',
pixeltable/metadata/schema.py CHANGED
@@ -202,6 +202,10 @@ class SchemaColumn:
     pos: int
     name: str
 
+    # media validation strategy of this particular media column; if not set, TableMd.media_validation applies
+    # stores column.MediaValiation.name.lower()
+    media_validation: Optional[str]
+
 
 @dataclasses.dataclass
 class TableSchemaVersionMd:
@@ -214,6 +218,10 @@ class TableSchemaVersionMd:
     num_retained_versions: int
     comment: str
 
+    # default validation strategy for any media column of this table
+    # stores column.MediaValiation.name.lower()
+    media_validation: str
+
 
 # versioning: each table schema change results in a new record
 class TableSchemaVersion(Base):
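Taken together with the converter above, these fields mean that after the v22 migration each stored TableSchemaVersionMd dict carries a table-wide default while every SchemaColumn dict carries a per-column override that starts out unset. A rough sketch of the resulting shape (column ids and names are invented; unrelated fields are elided):

# hypothetical TableSchemaVersionMd dict after convert_21 has run
table_schema_version_md = {
    'media_validation': 'on_write',  # table-wide default, i.e. MediaValidation.ON_WRITE
    'columns': {
        '1': {'pos': 0, 'name': 'video', 'media_validation': None},  # None: fall back to the table default
    },
    # ... remaining TableSchemaVersionMd fields elided
}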
pixeltable/plan.py CHANGED
@@ -225,27 +225,28 @@ class Planner:
         assert not tbl.is_view()
         # stored_cols: all cols we need to store, incl computed cols (and indices)
         stored_cols = [c for c in tbl.cols if c.is_stored]
-        assert len(stored_cols) > 0
-
+        assert len(stored_cols) > 0 # there needs to be something to store
         row_builder = exprs.RowBuilder([], stored_cols, [])
 
         # create InMemoryDataNode for 'rows'
-        stored_col_info = row_builder.output_slot_idxs()
-        stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
-        input_col_info = [info for info in stored_col_info if not info.col.is_computed]
         plan: exec.ExecNode = exec.InMemoryDataNode(tbl, rows, row_builder, tbl.next_rowid)
 
-        media_input_cols = [info for info in input_col_info if info.col.col_type.is_media_type()]
-        if len(media_input_cols) > 0:
-            # prefetch external files for all input column refs for validation
-            plan = exec.CachePrefetchNode(tbl.id, media_input_cols, input=plan)
-            plan = exec.MediaValidationNode(row_builder, media_input_cols, input=plan)
+        media_input_col_info = [
+            exprs.ColumnSlotIdx(col_ref.col, col_ref.slot_idx)
+            for col_ref in row_builder.input_exprs
+            if isinstance(col_ref, exprs.ColumnRef) and col_ref.col_type.is_media_type()
+        ]
+        if len(media_input_col_info) > 0:
+            # prefetch external files for all input column refs
+            plan = exec.CachePrefetchNode(tbl.id, media_input_col_info, input=plan)
 
-        computed_exprs = [e for e in row_builder.default_eval_ctx.target_exprs if not isinstance(e, exprs.ColumnRef)]
+        computed_exprs = row_builder.output_exprs - row_builder.input_exprs
         if len(computed_exprs) > 0:
             # add an ExprEvalNode when there are exprs to compute
             plan = exec.ExprEvalNode(row_builder, computed_exprs, plan.output_exprs, input=plan)
 
+        stored_col_info = row_builder.output_slot_idxs()
+        stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
         plan.set_stored_img_cols(stored_img_col_info)
         plan.set_ctx(
             exec.ExecContext(
@@ -621,7 +622,8 @@
         assert isinstance(tbl, catalog.TableVersionPath)
         sql_elements = analyzer.sql_elements
         is_python_agg = (
-            not sql_elements.contains(analyzer.agg_fn_calls) or not sql_elements.contains(analyzer.window_fn_calls)
+            not sql_elements.contains_all(analyzer.agg_fn_calls)
+            or not sql_elements.contains_all(analyzer.window_fn_calls)
         )
         ctx = exec.ExecContext(row_builder)
         cls._verify_ordering(analyzer, verify_agg=is_python_agg)
@@ -671,8 +673,8 @@
             ctx.batch_size = 16
 
         # do aggregation in SQL if all agg exprs can be translated
-        if (sql_elements.contains(analyzer.select_list)
-                and sql_elements.contains(analyzer.grouping_exprs)
+        if (sql_elements.contains_all(analyzer.select_list)
+                and sql_elements.contains_all(analyzer.grouping_exprs)
                 and isinstance(plan, exec.SqlNode)
                 and plan.to_cte() is not None):
            plan = exec.SqlAggregationNode(
pixeltable/py.typed ADDED
File without changes
pixeltable/store.py CHANGED
@@ -303,7 +303,7 @@ class StoreBase:
 
    def insert_rows(
        self, exec_plan: ExecNode, conn: sql.engine.Connection, v_min: Optional[int] = None,
-        show_progress: bool = True, rowids: Optional[Iterator[int]] = None
+        show_progress: bool = True, rowids: Optional[Iterator[int]] = None, abort_on_exc: bool = False
    ) -> tuple[int, int, set[int]]:
        """Insert rows into the store table and update the catalog table's md
        Returns:
@@ -325,8 +325,13 @@
        for batch_start_idx in range(0, len(row_batch), self.__INSERT_BATCH_SIZE):
            # compute batch of rows and convert them into table rows
            table_rows: list[dict[str, Any]] = []
-            for row_idx in range(batch_start_idx, min(batch_start_idx + self.__INSERT_BATCH_SIZE, len(row_batch))):
+            batch_stop_idx = min(batch_start_idx + self.__INSERT_BATCH_SIZE, len(row_batch))
+            for row_idx in range(batch_start_idx, batch_stop_idx):
                row = row_batch[row_idx]
+                # if abort_on_exc == True, we need to check for media validation exceptions
+                if abort_on_exc and row.has_exc():
+                    exc = row.get_first_exc()
+                    raise exc
 
                rowid = (next(rowids),) if rowids is not None else row.pk[:-1]
                pk = rowid + (v_min,)
pixeltable/tool/create_test_video.py CHANGED
@@ -1,4 +1,4 @@
-import av
+import av # type: ignore[import-untyped]
 import PIL.Image
 import PIL.ImageDraw
 import PIL.ImageFont
pixeltable/tool/embed_udf.py CHANGED
@@ -6,4 +6,4 @@ import pixeltable as pxt
 # TODO This can go away once we have the ability to inline expr_udf's
 @pxt.expr_udf
 def clip_text_embed(txt: str) -> np.ndarray:
-    return pxt.functions.huggingface.clip_text(txt, model_id='openai/clip-vit-base-patch32')
+    return pxt.functions.huggingface.clip_text(txt, model_id='openai/clip-vit-base-patch32') # type: ignore[return-value]
pixeltable/tool/mypy_plugin.py CHANGED
@@ -1,12 +1,15 @@
 from typing import Callable, Optional
 
-from mypy.plugin import AnalyzeTypeContext, Plugin
-from mypy.types import Type
+from mypy import nodes
+from mypy.plugin import AnalyzeTypeContext, ClassDefContext, Plugin
+from mypy.plugins.common import add_method_to_class
+from mypy.types import AnyType, Type, TypeOfAny
 
 import pixeltable as pxt
 
 
 class PxtPlugin(Plugin):
+    __UDA_FULLNAME = f'{pxt.uda.__module__}.{pxt.uda.__name__}'
     __TYPE_MAP = {
         pxt.Json: 'typing.Any',
         pxt.Array: 'numpy.ndarray',
@@ -20,13 +23,33 @@ class PxtPlugin(Plugin):
         for k, v in __TYPE_MAP.items()
     }
 
-    def get_type_analyze_hook(self, fullname: str) -> Optional[Callable[[AnalyzeTypeContext], type]]:
+    def get_type_analyze_hook(self, fullname: str) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
         if fullname in self.__FULLNAME_MAP:
             subst_name = self.__FULLNAME_MAP[fullname]
             return lambda ctx: pxt_hook(ctx, subst_name)
+        return None
 
-def plugin(version: str):
+    def get_class_decorator_hook_2(self, fullname: str) -> Optional[Callable[[ClassDefContext], bool]]:
+        if fullname == self.__UDA_FULLNAME:
+            return pxt_decorator_hook
+        return None
+
+def plugin(version: str) -> type:
     return PxtPlugin
 
 def pxt_hook(ctx: AnalyzeTypeContext, subst_name: str) -> Type:
-    return ctx.api.named_type(subst_name)
+    if subst_name == 'typing.Any':
+        return AnyType(TypeOfAny.special_form)
+    return ctx.api.named_type(subst_name, [])
+
+def pxt_decorator_hook(ctx: ClassDefContext) -> bool:
+    arg = nodes.Argument(nodes.Var('fn'), AnyType(TypeOfAny.special_form), None, nodes.ARG_POS)
+    add_method_to_class(
+        ctx.api,
+        ctx.cls,
+        "to_sql",
+        args=[arg],
+        return_type=AnyType(TypeOfAny.special_form),
+        is_staticmethod=True,
+    )
+    return True
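The net effect of the two hooks: annotations written with Pixeltable's type aliases resolve to ordinary Python types (per `__TYPE_MAP`, e.g. `pxt.Json` becomes `Any`), and classes decorated with `@pxt.uda` gain a static `to_sql` member in mypy's view, matching what exists at runtime. A hedged sketch of a UDF signature that should now type-check cleanly, assuming the plugin is enabled through mypy's standard `plugins` setting pointing at `pixeltable.tool.mypy_plugin` (the UDF itself is invented for illustration):

import pixeltable as pxt

@pxt.udf
def title_of(metadata: pxt.Json) -> str:
    # pxt.Json is treated as typing.Any by the plugin, so dict-style access is accepted
    return str(metadata.get('title', ''))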
pixeltable/type_system.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 import abc
 import datetime
 import enum
+import io
 import json
 import typing
 import urllib.parse
@@ -10,9 +11,9 @@ import urllib.request
 from pathlib import Path
 from typing import Any, Iterable, Mapping, Optional, Sequence, Union
 
+import PIL.Image
 import av # type: ignore
 import numpy as np
-import PIL.Image
 import sqlalchemy as sql
 from typing import _GenericAlias # type: ignore[attr-defined]
 from typing_extensions import _AnnotatedAlias
@@ -798,6 +799,20 @@ class ImageType(ColumnType):
     def to_sa_type(self) -> sql.types.TypeEngine:
         return sql.String()
 
+    def _create_literal(self, val: Any) -> Any:
+        if isinstance(val, str) and val.startswith('data:'):
+            # try parsing this as a `data:` URL, and if successful, decode the image immediately
+            try:
+                with urllib.request.urlopen(val) as response:
+                    b = response.read()
+                img = PIL.Image.open(io.BytesIO(b))
+                img.load()
+                return img
+            except Exception as exc:
+                errormsg_val = val if len(val) < 50 else val[:50] + '...'
+                raise excs.Error(f'data URL could not be decoded into a valid image: {errormsg_val}') from exc
+        return val
+
     def _validate_literal(self, val: Any) -> None:
         if isinstance(val, PIL.Image.Image):
             return
@@ -876,6 +891,7 @@ class DocumentType(ColumnType):
         HTML = 0
         MD = 1
         PDF = 2
+        XML = 3
 
     def __init__(self, nullable: bool = False, doc_formats: Optional[str] = None):
         super().__init__(self.Type.DOCUMENT, nullable=nullable)
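With `_create_literal` in place, a base64 `data:` URL can stand in for an image value; the decoding it performs is ordinary urllib plus PIL work, since `urllib.request.urlopen` handles the `data:` scheme natively. A self-contained sketch of the same round trip (the tiny PNG is generated on the spot rather than taken from real data):

import base64
import io
import urllib.request

import PIL.Image

# build a data: URL for a 4x4 red PNG, standing in for a user-supplied literal
buf = io.BytesIO()
PIL.Image.new('RGB', (4, 4), color='red').save(buf, format='png')
data_url = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode()

# this mirrors what ImageType._create_literal does with a 'data:' string
with urllib.request.urlopen(data_url) as response:
    img = PIL.Image.open(io.BytesIO(response.read()))
    img.load()
assert img.size == (4, 4)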
pixeltable/utils/documents.py CHANGED
@@ -35,6 +35,11 @@ def get_document_handle(path: str) -> Optional[DocumentHandle]:
         if md_ast is not None:
             return DocumentHandle(format=ts.DocumentType.DocumentFormat.MD, md_ast=md_ast)
 
+    if doc_format == '.xml':
+        bs_doc = get_xml_handle(path)
+        if bs_doc is not None:
+            return DocumentHandle(format=ts.DocumentType.DocumentFormat.XML, bs_doc=bs_doc)
+
     return None
 
 
@@ -54,7 +59,16 @@ def get_pdf_handle(path: str) -> Optional[fitz.Document]:
 def get_html_handle(path: str) -> Optional[bs4.BeautifulSoup]:
     try:
         with open(path, 'r', encoding='utf8') as fp:
-            doc = bs4.BeautifulSoup(fp, 'html.parser')
+            doc = bs4.BeautifulSoup(fp, 'lxml')
+        return doc if doc.find() is not None else None
+    except Exception:
+        return None
+
+
+def get_xml_handle(path: str) -> Optional[bs4.BeautifulSoup]:
+    try:
+        with open(path, 'r', encoding='utf8') as fp:
+            doc = bs4.BeautifulSoup(fp, 'xml')
         return doc if doc.find() is not None else None
     except Exception:
         return None
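`get_xml_handle` leans on BeautifulSoup's `'xml'` parser (which requires `lxml` to be installed, the same backend the HTML path now uses), and the `doc.find()` check is what rejects content that parses to an empty tree. A quick standalone illustration with made-up inputs:

import bs4

# well-formed XML yields at least one element, so the handle is accepted
doc = bs4.BeautifulSoup('<catalog><item id="1">hello</item></catalog>', 'xml')
assert doc.find() is not None

# plain text produces no elements, so doc.find() is None and the handle is rejected
assert bs4.BeautifulSoup('not xml at all', 'xml').find() is None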