pixeltable 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic.

Files changed (120)
  1. pixeltable/__init__.py +7 -19
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +7 -7
  4. pixeltable/catalog/column.py +37 -11
  5. pixeltable/catalog/globals.py +21 -0
  6. pixeltable/catalog/insertable_table.py +6 -4
  7. pixeltable/catalog/table.py +227 -148
  8. pixeltable/catalog/table_version.py +66 -28
  9. pixeltable/catalog/table_version_path.py +0 -8
  10. pixeltable/catalog/view.py +18 -19
  11. pixeltable/dataframe.py +16 -32
  12. pixeltable/env.py +6 -1
  13. pixeltable/exec/__init__.py +1 -2
  14. pixeltable/exec/aggregation_node.py +27 -17
  15. pixeltable/exec/cache_prefetch_node.py +1 -1
  16. pixeltable/exec/data_row_batch.py +9 -26
  17. pixeltable/exec/exec_node.py +36 -7
  18. pixeltable/exec/expr_eval_node.py +19 -11
  19. pixeltable/exec/in_memory_data_node.py +14 -11
  20. pixeltable/exec/sql_node.py +266 -138
  21. pixeltable/exprs/__init__.py +1 -0
  22. pixeltable/exprs/arithmetic_expr.py +3 -1
  23. pixeltable/exprs/array_slice.py +7 -7
  24. pixeltable/exprs/column_property_ref.py +37 -10
  25. pixeltable/exprs/column_ref.py +93 -14
  26. pixeltable/exprs/comparison.py +5 -5
  27. pixeltable/exprs/compound_predicate.py +8 -7
  28. pixeltable/exprs/data_row.py +56 -36
  29. pixeltable/exprs/expr.py +65 -63
  30. pixeltable/exprs/expr_dict.py +55 -0
  31. pixeltable/exprs/expr_set.py +26 -15
  32. pixeltable/exprs/function_call.py +53 -24
  33. pixeltable/exprs/globals.py +4 -1
  34. pixeltable/exprs/in_predicate.py +8 -7
  35. pixeltable/exprs/inline_expr.py +4 -4
  36. pixeltable/exprs/is_null.py +4 -4
  37. pixeltable/exprs/json_mapper.py +11 -12
  38. pixeltable/exprs/json_path.py +5 -10
  39. pixeltable/exprs/literal.py +5 -5
  40. pixeltable/exprs/method_ref.py +5 -4
  41. pixeltable/exprs/object_ref.py +2 -1
  42. pixeltable/exprs/row_builder.py +88 -36
  43. pixeltable/exprs/rowid_ref.py +14 -13
  44. pixeltable/exprs/similarity_expr.py +12 -7
  45. pixeltable/exprs/sql_element_cache.py +12 -6
  46. pixeltable/exprs/type_cast.py +8 -6
  47. pixeltable/exprs/variable.py +5 -4
  48. pixeltable/ext/functions/whisperx.py +7 -2
  49. pixeltable/func/aggregate_function.py +1 -1
  50. pixeltable/func/callable_function.py +2 -2
  51. pixeltable/func/function.py +11 -10
  52. pixeltable/func/function_registry.py +6 -7
  53. pixeltable/func/query_template_function.py +11 -12
  54. pixeltable/func/signature.py +17 -15
  55. pixeltable/func/udf.py +0 -4
  56. pixeltable/functions/__init__.py +2 -2
  57. pixeltable/functions/audio.py +4 -6
  58. pixeltable/functions/globals.py +84 -42
  59. pixeltable/functions/huggingface.py +31 -34
  60. pixeltable/functions/image.py +59 -45
  61. pixeltable/functions/json.py +0 -1
  62. pixeltable/functions/llama_cpp.py +106 -0
  63. pixeltable/functions/mistralai.py +2 -2
  64. pixeltable/functions/ollama.py +147 -0
  65. pixeltable/functions/openai.py +22 -25
  66. pixeltable/functions/replicate.py +72 -0
  67. pixeltable/functions/string.py +59 -50
  68. pixeltable/functions/timestamp.py +20 -20
  69. pixeltable/functions/together.py +2 -2
  70. pixeltable/functions/video.py +11 -20
  71. pixeltable/functions/whisper.py +2 -20
  72. pixeltable/globals.py +65 -74
  73. pixeltable/index/base.py +2 -2
  74. pixeltable/index/btree.py +20 -7
  75. pixeltable/index/embedding_index.py +12 -14
  76. pixeltable/io/__init__.py +1 -2
  77. pixeltable/io/external_store.py +11 -5
  78. pixeltable/io/fiftyone.py +178 -0
  79. pixeltable/io/globals.py +98 -2
  80. pixeltable/io/hf_datasets.py +1 -1
  81. pixeltable/io/label_studio.py +6 -6
  82. pixeltable/io/parquet.py +14 -13
  83. pixeltable/iterators/base.py +3 -2
  84. pixeltable/iterators/document.py +10 -8
  85. pixeltable/iterators/video.py +126 -60
  86. pixeltable/metadata/__init__.py +4 -3
  87. pixeltable/metadata/converters/convert_14.py +4 -2
  88. pixeltable/metadata/converters/convert_15.py +1 -1
  89. pixeltable/metadata/converters/convert_19.py +1 -0
  90. pixeltable/metadata/converters/convert_20.py +1 -1
  91. pixeltable/metadata/converters/convert_21.py +34 -0
  92. pixeltable/metadata/converters/util.py +54 -12
  93. pixeltable/metadata/notes.py +1 -0
  94. pixeltable/metadata/schema.py +40 -21
  95. pixeltable/plan.py +149 -165
  96. pixeltable/py.typed +0 -0
  97. pixeltable/store.py +57 -37
  98. pixeltable/tool/create_test_db_dump.py +6 -6
  99. pixeltable/tool/create_test_video.py +1 -1
  100. pixeltable/tool/doc_plugins/griffe.py +3 -34
  101. pixeltable/tool/embed_udf.py +1 -1
  102. pixeltable/tool/mypy_plugin.py +55 -0
  103. pixeltable/type_system.py +260 -61
  104. pixeltable/utils/arrow.py +10 -9
  105. pixeltable/utils/coco.py +4 -4
  106. pixeltable/utils/documents.py +16 -2
  107. pixeltable/utils/filecache.py +9 -9
  108. pixeltable/utils/formatter.py +10 -11
  109. pixeltable/utils/http_server.py +2 -5
  110. pixeltable/utils/media_store.py +6 -6
  111. pixeltable/utils/pytorch.py +10 -11
  112. pixeltable/utils/sql.py +2 -1
  113. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/METADATA +50 -13
  114. pixeltable-0.2.22.dist-info/RECORD +153 -0
  115. pixeltable/exec/media_validation_node.py +0 -43
  116. pixeltable/utils/help.py +0 -11
  117. pixeltable-0.2.20.dist-info/RECORD +0 -147
  118. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/LICENSE +0 -0
  119. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/WHEEL +0 -0
  120. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/entry_points.txt +0 -0
@@ -1,13 +1,15 @@
  import logging
  import math
+ from fractions import Fraction
  from pathlib import Path
- from typing import Any, Optional
+ from typing import Any, Optional, Sequence

- import cv2
+ import av  # type: ignore[import-untyped]
+ import pandas as pd
  import PIL.Image

- from pixeltable.exceptions import Error
- from pixeltable.type_system import ColumnType, FloatType, ImageType, IntType, VideoType
+ import pixeltable.exceptions as excs
+ import pixeltable.type_system as ts

  from .base import ComponentIterator

@@ -29,100 +31,164 @@ class FrameIterator(ComponentIterator):
          num_frames: Exact number of frames to extract. The frames will be spaced as evenly as possible. If
              `num_frames` is greater than the number of frames in the video, all frames will be extracted.
      """
+
+     # Input parameters
+     video_path: Path
+     fps: Optional[float]
+     num_frames: Optional[int]
+
+     # Video info
+     container: av.container.input.InputContainer
+     video_framerate: Fraction
+     video_time_base: Fraction
+     video_frame_count: int
+     video_start_time: int
+
+     # List of frame indices to be extracted, or None to extract all frames
+     frames_to_extract: Optional[list[int]]
+
+     # Next frame to extract, as an iterator `pos` index. If `frames_to_extract` is None, this is the same as the
+     # frame index in the video. Otherwise, the corresponding video index is `frames_to_extract[next_pos]`.
+     next_pos: int
+
      def __init__(self, video: str, *, fps: Optional[float] = None, num_frames: Optional[int] = None):
          if fps is not None and num_frames is not None:
-             raise Error('At most one of `fps` or `num_frames` may be specified')
+             raise excs.Error('At most one of `fps` or `num_frames` may be specified')

          video_path = Path(video)
          assert video_path.exists() and video_path.is_file()
          self.video_path = video_path
-         self.video_reader = cv2.VideoCapture(str(video_path))
+         self.container = av.open(str(video_path))
          self.fps = fps
          self.num_frames = num_frames
-         if not self.video_reader.isOpened():
-             raise Error(f'Failed to open video: {video}')

-         video_fps = int(self.video_reader.get(cv2.CAP_PROP_FPS))
-         if fps is not None and fps > video_fps:
-             raise Error(f'Video {video}: requested fps ({fps}) exceeds that of the video ({video_fps})')
-         num_video_frames = int(self.video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
-         if num_video_frames == 0:
-             raise Error(f'Video {video}: failed to get number of frames')
+         self.video_framerate = self.container.streams.video[0].average_rate
+         self.video_time_base = self.container.streams.video[0].time_base
+         self.video_start_time = self.container.streams.video[0].start_time or 0
+
+         # Determine the number of frames in the video
+         self.video_frame_count = self.container.streams.video[0].frames
+         if self.video_frame_count == 0:
+             # The video codec does not provide a frame count in the standard `frames` field. Try some other methods.
+             metadata: dict = self.container.streams.video[0].metadata
+             if 'NUMBER_OF_FRAMES' in metadata:
+                 self.video_frame_count = int(metadata['NUMBER_OF_FRAMES'])
+             elif 'DURATION' in metadata:
+                 # As a last resort, calculate the frame count from the stream duration.
+                 duration = metadata['DURATION']
+                 assert isinstance(duration, str)
+                 seconds = pd.to_timedelta(duration).total_seconds()
+                 # Usually the duration and framerate are precise enough for this calculation to be accurate, but if
+                 # we encounter a case where it's off by one due to a rounding error, that's ok; we only use this
+                 # to determine the positions of the sampled frames when `fps` or `num_frames` is specified.
+                 self.video_frame_count = round(seconds * self.video_framerate)
+             else:
+                 raise excs.Error(f'Video {video}: failed to get number of frames')

          if num_frames is not None:
              # specific number of frames
-             if num_frames > num_video_frames:
+             if num_frames > self.video_frame_count:
                  # Extract all frames
-                 self.frames_to_extract = range(num_video_frames)
+                 self.frames_to_extract = None
              else:
-                 spacing = float(num_video_frames) / float(num_frames)
+                 spacing = float(self.video_frame_count) / float(num_frames)
                  self.frames_to_extract = list(round(i * spacing) for i in range(num_frames))
                  assert len(self.frames_to_extract) == num_frames
          else:
              if fps is None or fps == 0.0:
                  # Extract all frames
-                 self.frames_to_extract = range(num_video_frames)
+                 self.frames_to_extract = None
+             elif fps > float(self.video_framerate):
+                 raise excs.Error(
+                     f'Video {video}: requested fps ({fps}) exceeds that of the video ({float(self.video_framerate)})'
+                 )
              else:
                  # Extract frames at the implied frequency
-                 freq = fps / video_fps
-                 n = math.ceil(num_video_frames * freq)  # number of frames to extract
+                 freq = fps / float(self.video_framerate)
+                 n = math.ceil(self.video_frame_count * freq)  # number of frames to extract
                  self.frames_to_extract = list(round(i / freq) for i in range(n))

-         # We need the list of frames as both a list (for set_pos) and a set (for fast lookups when
-         # there are lots of frames)
-         self.frames_set = set(self.frames_to_extract)
          _logger.debug(f'FrameIterator: path={self.video_path} fps={self.fps} num_frames={self.num_frames}')
-         self.next_frame_idx = 0
+         self.next_pos = 0

      @classmethod
-     def input_schema(cls) -> dict[str, ColumnType]:
+     def input_schema(cls) -> dict[str, ts.ColumnType]:
          return {
-             'video': VideoType(nullable=False),
-             'fps': FloatType(nullable=True),
-             'num_frames': IntType(nullable=True),
+             'video': ts.VideoType(nullable=False),
+             'fps': ts.FloatType(nullable=True),
+             'num_frames': ts.IntType(nullable=True),
          }

      @classmethod
-     def output_schema(cls, *args: Any, **kwargs: Any) -> tuple[dict[str, ColumnType], list[str]]:
+     def output_schema(cls, *args: Any, **kwargs: Any) -> tuple[dict[str, ts.ColumnType], list[str]]:
          return {
-             'frame_idx': IntType(),
-             'pos_msec': FloatType(),
-             'pos_frame': FloatType(),
-             'frame': ImageType(),
+             'frame_idx': ts.IntType(),
+             'pos_msec': ts.FloatType(),
+             'pos_frame': ts.IntType(),
+             'frame': ts.ImageType(),
          }, ['frame']

      def __next__(self) -> dict[str, Any]:
-         # jumping to the target frame here with video_reader.set() is far slower than just
-         # skipping the unwanted frames
+         # Determine the frame index in the video corresponding to the iterator index `next_pos`;
+         # the frame at this index is the one we want to extract next
+         if self.frames_to_extract is None:
+             next_video_idx = self.next_pos  # we're extracting all frames
+         elif self.next_pos >= len(self.frames_to_extract):
+             raise StopIteration
+         else:
+             next_video_idx = self.frames_to_extract[self.next_pos]
+
+         # We are searching for the frame at the index implied by `next_pos`. Step through the video until we
+         # find it. There are two reasons why it might not be the immediate next frame in the video:
+         # (1) `fps` or `num_frames` was specified as an iterator argument; or
+         # (2) we just did a seek, and the desired frame is not a keyframe.
+         # TODO: In case (1) it will usually be fastest to step through the frames until we find the one we're
+         #     looking for. But in some cases it may be faster to do a seek; for example, when `fps` is very
+         #     low and there are multiple keyframes in between each frame we want to extract (imagine extracting
+         #     10 frames from an hourlong video).
          while True:
-             pos_msec = self.video_reader.get(cv2.CAP_PROP_POS_MSEC)
-             pos_frame = self.video_reader.get(cv2.CAP_PROP_POS_FRAMES)
-             status, img = self.video_reader.read()
-             if not status:
-                 _logger.debug(f'releasing video reader for {self.video_path}')
-                 self.video_reader.release()
-                 self.video_reader = None
+             try:
+                 frame = next(self.container.decode(video=0))
+             except EOFError:
                  raise StopIteration
-             if pos_frame in self.frames_set:
-                 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-                 result = {
-                     'frame_idx': self.next_frame_idx,
-                     'pos_msec': pos_msec,
-                     'pos_frame': pos_frame,
-                     'frame': PIL.Image.fromarray(img),
-                 }
-                 self.next_frame_idx += 1
-                 return result
+             # Compute the index of the current frame in the video based on the presentation timestamp (pts);
+             # this ensures we have a canonical understanding of frame index, regardless of how we got here
+             # (seek or iteration)
+             pts = frame.pts - self.video_start_time
+             video_idx = round(pts * self.video_time_base * self.video_framerate)
+             assert isinstance(video_idx, int)
+             if video_idx < next_video_idx:
+                 # We haven't reached the desired frame yet
+                 continue
+
+             # Sanity check that we're at the right frame.
+             if video_idx != next_video_idx:
+                 raise excs.Error(f'Frame {next_video_idx} is missing from the video (video file is corrupt)')
+             img = frame.to_image()
+             assert isinstance(img, PIL.Image.Image)
+             pos_msec = float(pts * self.video_time_base * 1000)
+             result = {
+                 'frame_idx': self.next_pos,
+                 'pos_msec': pos_msec,
+                 'pos_frame': video_idx,
+                 'frame': img,
+             }
+             self.next_pos += 1
+             return result

      def close(self) -> None:
-         if self.video_reader is not None:
-             self.video_reader.release()
-             self.video_reader = None
+         self.container.close()

      def set_pos(self, pos: int) -> None:
          """Seek to frame idx"""
-         if pos == self.next_frame_idx:
-             return
-         _logger.debug(f'seeking to frame {pos}')
-         self.video_reader.set(cv2.CAP_PROP_POS_FRAMES, self.frames_to_extract[pos])
-         self.next_frame_idx = pos
+         if pos == self.next_pos:
+             return  # already there
+
+         video_idx = pos if self.frames_to_extract is None else self.frames_to_extract[pos]
+         _logger.debug(f'seeking to frame number {video_idx} (at iterator index {pos})')
+         # compute the frame position in time_base units
+         seek_pos = int(video_idx / self.video_framerate / self.video_time_base + self.video_start_time)
+         # This will seek to the nearest keyframe before the desired frame. If the frame being sought is not a keyframe,
+         # then the iterator will step forward to the desired frame on the subsequent call to next().
+         self.container.seek(seek_pos, backward=True, stream=self.container.streams.video[0])
+         self.next_pos = pos
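
The rewritten FrameIterator above converts between PyAV presentation timestamps (`pts`, expressed in `time_base` units) and frame numbers in both directions: `__next__()` derives `video_idx` from `pts`, and `set_pos()` inverts that calculation to get `seek_pos`. A minimal standalone sketch of the same arithmetic, using made-up stream parameters (25 fps, 1/12800 time base; not values taken from this diff):

    from fractions import Fraction

    framerate = Fraction(25, 1)     # hypothetical stream framerate
    time_base = Fraction(1, 12800)  # hypothetical time base (seconds per pts tick)
    start_time = 0                  # pts of the first frame

    def frame_index(pts: int) -> int:
        # pts * time_base is the timestamp in seconds; multiplying by the framerate gives the frame number
        return round((pts - start_time) * time_base * framerate)

    def seek_pts(video_idx: int) -> int:
        # inverse conversion, as used by set_pos(): frame number -> position in time_base units
        return int(video_idx / framerate / time_base + start_time)

    assert seek_pts(100) == 51200            # frame 100 sits 4 s into the stream: 4 / (1/12800) = 51200 ticks
    assert frame_index(seek_pts(100)) == 100

Because `container.seek()` only lands on the nearest preceding keyframe, `__next__()` still has to decode forward until the frame index computed from `pts` matches the requested one.
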
@@ -2,7 +2,7 @@ import dataclasses
  import importlib
  import os
  import pkgutil
- from typing import Callable, Dict
+ from typing import Callable

  import sqlalchemy as sql
  import sqlalchemy.orm as orm
@@ -10,7 +10,7 @@ import sqlalchemy.orm as orm
  from .schema import SystemInfo, SystemInfoMd

  # current version of the metadata; this is incremented whenever the metadata schema changes
- VERSION = 21
+ VERSION = 22


  def create_system_info(engine: sql.engine.Engine) -> None:
@@ -24,7 +24,7 @@ def create_system_info(engine: sql.engine.Engine) -> None:

  # conversion functions for upgrading the metadata schema from one version to the following
  # key: old schema version
- converter_cbs: Dict[int, Callable[[sql.engine.Engine], None]] = {}
+ converter_cbs: dict[int, Callable[[sql.engine.Engine], None]] = {}

  def register_converter(version: int) -> Callable[[Callable[[sql.engine.Engine], None]], None]:
      def decorator(fn: Callable[[sql.engine.Engine], None]) -> None:
@@ -41,6 +41,7 @@ def upgrade_md(engine: sql.engine.Engine) -> None:
      with orm.Session(engine) as session:
          system_info = session.query(SystemInfo).one().md
          md_version = system_info['schema_version']
+         assert isinstance(md_version, int)
          if md_version == VERSION:
              return
          while md_version < VERSION:
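
For orientation, `converter_cbs` and `register_converter` form a version-keyed registry: each converter upgrades the stored metadata by exactly one schema version, and `upgrade_md` applies them one step at a time until `VERSION` is reached. A simplified sketch of that pattern (illustrative only, not the actual pixeltable implementation; the real `upgrade_md` reads `schema_version` out of the persisted `SystemInfo` record, as the hunk above shows):

    from typing import Callable

    VERSION = 22
    converter_cbs: dict[int, Callable[[object], None]] = {}  # key: old schema version

    def register_converter(version: int) -> Callable[[Callable[[object], None]], None]:
        def decorator(fn: Callable[[object], None]) -> None:
            converter_cbs[version] = fn
        return decorator

    @register_converter(version=21)
    def _(engine: object) -> None:
        pass  # stand-in for the real 21 -> 22 conversion shown further below

    def upgrade_md(engine: object, md_version: int) -> int:
        # apply one converter per version step until the metadata is current
        while md_version < VERSION:
            converter_cbs[md_version](engine)
            md_version += 1
        return md_version

    assert upgrade_md(None, 21) == VERSION  # runs the registered 21 -> 22 converter once
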
@@ -1,11 +1,13 @@
+ from typing import Any
+
  import sqlalchemy as sql

- from pixeltable.metadata.schema import Table
  from pixeltable.metadata import register_converter
+ from pixeltable.metadata.schema import Table


  @register_converter(version=14)
  def _(engine: sql.engine.Engine) -> None:
-     default_remotes = {'remotes': []}
+     default_remotes: dict[str, Any] = {'remotes': []}
      with engine.begin() as conn:
          conn.execute(sql.update(Table).where(Table.md['remotes'] == None).values(md=Table.md.concat(default_remotes)))
@@ -3,7 +3,7 @@ import inspect
  import logging
  from typing import Any

- import cloudpickle
+ import cloudpickle  # type: ignore[import-untyped]
  import sqlalchemy as sql

  import pixeltable.func as func
@@ -44,3 +44,4 @@ def __update_timestamp_literals(k: Any, v: Any) -> Optional[tuple[Any, Any]]:
          dt_utc = dt.astimezone(datetime.timezone.utc)
          v['val'] = dt_utc.isoformat()
          return k, v
+     return None
@@ -35,7 +35,7 @@ def __substitute_md(k: Optional[str], v: Any) -> Optional[tuple[Optional[str], Any]]:
          # but it might actually be transformed into an InlineList when it is instantiated
          # (unfortunately, there is no way to disambiguate at this stage; see comments in
          # InlineArray._from_dict() for more details).
-         updated_v = {'_classname': 'InlineList' if v.get('is_json') else 'InlineArray'}
+         updated_v: dict[str, Any] = {'_classname': 'InlineList' if v.get('is_json') else 'InlineArray'}
          if len(updated_components) > 0:
              updated_v['components'] = updated_components
          return k, updated_v
@@ -0,0 +1,34 @@
+ from typing import Any, Optional
+ import sqlalchemy as sql
+
+ from pixeltable.metadata import register_converter
+ from pixeltable.metadata.converters.util import convert_table_schema_version_md, convert_table_md
+
+
+ @register_converter(version=21)
+ def _(engine: sql.engine.Engine) -> None:
+     convert_table_schema_version_md(
+         engine,
+         table_schema_version_md_updater=__update_table_schema_version,
+         schema_column_updater=__update_schema_column
+     )
+     convert_table_md(
+         engine,
+         substitution_fn=__substitute_md
+     )
+
+
+ def __update_table_schema_version(table_schema_version_md: dict) -> None:
+     table_schema_version_md['media_validation'] = 'on_write'  # MediaValidation.ON_WRITE
+
+
+ def __update_schema_column(schema_column: dict) -> None:
+     schema_column['media_validation'] = None
+
+
+ def __substitute_md(k: Optional[str], v: Any) -> Optional[tuple[Optional[str], Any]]:
+     if isinstance(v, dict) and '_classname' in v and v['_classname'] == 'ColumnRef':
+         if 'perform_validation' not in v:
+             v['perform_validation'] = False
+         return k, v
+     return None
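
The effect of this new converter (schema version 21 -> 22) on stored metadata can be illustrated with hypothetical fragments (field names and values other than those shown in the converter are invented for illustration): every table schema version gets a default `media_validation` strategy, every schema column gets an optional per-column override, and every serialized `ColumnRef` gains a `perform_validation` flag.

    # Hypothetical TableSchemaVersionMd dict before the conversion (only relevant fields shown)
    tsv_md = {'comment': '', 'columns': {'0': {'pos': 0, 'name': 'img'}}}

    # __update_table_schema_version / __update_schema_column
    tsv_md['media_validation'] = 'on_write'       # table-wide default (MediaValidation.ON_WRITE)
    for schema_col in tsv_md['columns'].values():
        schema_col['media_validation'] = None     # no per-column override

    # __substitute_md: serialized ColumnRef exprs get an explicit perform_validation flag
    expr_md = {'_classname': 'ColumnRef'}         # other ColumnRef fields omitted
    if expr_md.get('_classname') == 'ColumnRef' and 'perform_validation' not in expr_md:
        expr_md['perform_validation'] = False
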
@@ -4,7 +4,7 @@ from typing import Any, Callable, Optional

  import sqlalchemy as sql

- from pixeltable.metadata.schema import Table
+ from pixeltable.metadata.schema import Table, TableSchemaVersion

  __logger = logging.getLogger('pixeltable')

@@ -17,12 +17,12 @@ def convert_table_md(
      substitution_fn: Optional[Callable[[Optional[str], Any], Optional[tuple[Optional[str], Any]]]] = None
  ) -> None:
      """
-     Converts table metadata based on the specified conversion functions.
+     Converts schema.TableMd dicts based on the specified conversion functions.

      Args:
          engine: The SQLAlchemy engine.
-         table_md_updater: A function that updates the table metadata in place.
-         column_md_updater: A function that updates the column metadata in place.
+         table_md_updater: A function that updates schema.TableMd dicts in place.
+         column_md_updater: A function that updates schema.ColumnMd dicts in place.
          external_store_md_updater: A function that updates the external store metadata in place.
          substitution_fn: A function that substitutes metadata values. If specified, all metadata will be traversed
              recursively, and `substitution_fn` will be called once for each metadata entry. If the entry appears in
@@ -68,24 +68,66 @@ def __substitute_md_rec(
      substitution_fn: Callable[[Optional[str], Any], Optional[tuple[Optional[str], Any]]]
  ) -> Any:
      if isinstance(md, dict):
-         updated_md = {}
+         updated_dict: dict[str, Any] = {}
          for k, v in md.items():
+             assert isinstance(k, str)
              substitute = substitution_fn(k, v)
              if substitute is not None:
                  updated_k, updated_v = substitute
-                 updated_md[updated_k] = __substitute_md_rec(updated_v, substitution_fn)
+                 updated_dict[updated_k] = __substitute_md_rec(updated_v, substitution_fn)
              else:
-                 updated_md[k] = __substitute_md_rec(v, substitution_fn)
-         return updated_md
+                 updated_dict[k] = __substitute_md_rec(v, substitution_fn)
+         return updated_dict
      elif isinstance(md, list):
-         updated_md = []
+         updated_list: list[Any] = []
          for v in md:
              substitute = substitution_fn(None, v)
              if substitute is not None:
                  _, updated_v = substitute
-                 updated_md.append(__substitute_md_rec(updated_v, substitution_fn))
+                 updated_list.append(__substitute_md_rec(updated_v, substitution_fn))
              else:
-                 updated_md.append(__substitute_md_rec(v, substitution_fn))
-         return updated_md
+                 updated_list.append(__substitute_md_rec(v, substitution_fn))
+         return updated_list
      else:
          return md
+
+
+ def convert_table_schema_version_md(
+     engine: sql.engine.Engine,
+     table_schema_version_md_updater: Optional[Callable[[dict], None]] = None,
+     schema_column_updater: Optional[Callable[[dict], None]] = None
+ ) -> None:
+     """
+     Converts schema.TableSchemaVersionMd dicts based on the specified conversion functions.
+
+     Args:
+         engine: The SQLAlchemy engine.
+         table_schema_version_md_updater: A function that updates schema.TableSchemaVersionMd dicts in place.
+         schema_column_updater: A function that updates schema.SchemaColumn dicts in place.
+     """
+     with engine.begin() as conn:
+         stmt = sql.select(TableSchemaVersion.tbl_id, TableSchemaVersion.schema_version, TableSchemaVersion.md)
+         for row in conn.execute(stmt):
+             tbl_id, schema_version, md = row[0], row[1], row[2]
+             assert isinstance(md, dict)
+             updated_md = copy.deepcopy(md)
+             if table_schema_version_md_updater is not None:
+                 table_schema_version_md_updater(updated_md)
+             if schema_column_updater is not None:
+                 __update_schema_column(updated_md, schema_column_updater)
+             if updated_md != md:
+                 __logger.info(f'Updating TableSchemaVersion(tbl_id={tbl_id}, schema_version={schema_version})')
+                 update_stmt = (
+                     sql.update(TableSchemaVersion)
+                     .where(TableSchemaVersion.tbl_id == tbl_id)
+                     .where(TableSchemaVersion.schema_version == schema_version)
+                     .values(md=updated_md)
+                 )
+                 conn.execute(update_stmt)
+
+
+ def __update_schema_column(table_schema_version_md: dict, schema_column_updater: Callable[[dict], None]) -> None:
+     cols = table_schema_version_md['columns']
+     assert isinstance(cols, dict)
+     for schema_col in cols.values():
+         schema_column_updater(schema_col)
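
To make the `substitution_fn` contract concrete: `__substitute_md_rec` visits every dict entry and list element, hands each `(key, value)` pair to the callback, and splices in whatever replacement it returns (a `None` return leaves the entry unchanged), recursing into the result. A self-contained sketch of that traversal, decoupled from the SQLAlchemy plumbing above:

    from typing import Any, Callable, Optional

    SubstFn = Callable[[Optional[str], Any], Optional[tuple[Optional[str], Any]]]

    def substitute_rec(md: Any, fn: SubstFn) -> Any:
        # rebuild dicts and lists, letting fn replace any entry it recognizes
        if isinstance(md, dict):
            out: dict[str, Any] = {}
            for k, v in md.items():
                sub = fn(k, v)
                if sub is not None:
                    k, v = sub
                out[k] = substitute_rec(v, fn)
            return out
        if isinstance(md, list):
            out_list: list[Any] = []
            for v in md:
                sub = fn(None, v)
                out_list.append(substitute_rec(sub[1] if sub is not None else v, fn))
            return out_list
        return md

    # Example callback in the spirit of the version-21 converter: tag serialized ColumnRefs
    def add_flag(k: Optional[str], v: Any) -> Optional[tuple[Optional[str], Any]]:
        if isinstance(v, dict) and v.get('_classname') == 'ColumnRef':
            return k, {**v, 'perform_validation': False}
        return None

    print(substitute_rec({'expr': {'_classname': 'ColumnRef'}}, add_flag))
    # {'expr': {'_classname': 'ColumnRef', 'perform_validation': False}}
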
@@ -2,6 +2,7 @@
  # rather than as a comment, so that the existence of a description can be enforced by
  # the unit tests when new versions are added.
  VERSION_NOTES = {
+     22: 'TableMd/ColumnMd.media_validation',
      21: 'Separate InlineArray and InlineList',
      20: 'Store DB timestamps in UTC',
      19: 'UDF renames; ImageMemberAccess removal',
@@ -1,37 +1,48 @@
  import dataclasses
+ import typing
  import uuid
- from typing import Optional, List, get_type_hints, Type, Any, TypeVar, Tuple, Union
+ from typing import Any, Optional, TypeVar, Union, get_type_hints

  import sqlalchemy as sql
  import sqlalchemy.orm as orm
- from sqlalchemy import ForeignKey
- from sqlalchemy import Integer, BigInteger, LargeBinary
- from sqlalchemy.dialects.postgresql import UUID, JSONB
+ from sqlalchemy import BigInteger, ForeignKey, Integer, LargeBinary
+ from sqlalchemy.dialects.postgresql import JSONB, UUID
  from sqlalchemy.orm import declarative_base
+ from sqlalchemy.orm.decl_api import DeclarativeMeta

- Base = declarative_base()
+ # Base has to be marked explicitly as a type, in order to be used elsewhere as a type hint. But in addition to being
+ # a type, it's also a `DeclarativeMeta`. The following pattern enables us to expose both `Base` and `Base.metadata`
+ # outside of the module in a typesafe way.
+ Base: type = declarative_base()
+ assert isinstance(Base, DeclarativeMeta)
+ base_metadata = Base.metadata

  T = TypeVar('T')

- def md_from_dict(data_class_type: Type[T], data: Any) -> T:
+ def md_from_dict(data_class_type: type[T], data: Any) -> T:
      """Re-instantiate a dataclass instance that contains nested dataclasses from a dict."""
      if dataclasses.is_dataclass(data_class_type):
          fieldtypes = {f: t for f, t in get_type_hints(data_class_type).items()}
-         return data_class_type(**{f: md_from_dict(fieldtypes[f], data[f]) for f in data})
-     elif hasattr(data_class_type, '__origin__'):
-         if data_class_type.__origin__ is Union and type(None) in data_class_type.__args__:
+         return data_class_type(**{f: md_from_dict(fieldtypes[f], data[f]) for f in data})  # type: ignore[return-value]
+
+     origin = typing.get_origin(data_class_type)
+     if origin is not None:
+         type_args = typing.get_args(data_class_type)
+         if origin is Union and type(None) in type_args:
              # Handling Optional types
-             non_none_args = [arg for arg in data_class_type.__args__ if arg is not type(None)]
-             if len(non_none_args) == 1:
-                 return md_from_dict(non_none_args[0], data) if data is not None else None
-         elif data_class_type.__origin__ is list:
-             return [md_from_dict(data_class_type.__args__[0], elem) for elem in data]
-         elif data_class_type.__origin__ is dict:
-             key_type = data_class_type.__args__[0]
-             val_type = data_class_type.__args__[1]
-             return {key_type(key): md_from_dict(val_type, val) for key, val in data.items()}
-         elif data_class_type.__origin__ is tuple:
-             return tuple(md_from_dict(arg_type, elem) for arg_type, elem in zip(data_class_type.__args__, data))
+             non_none_args = [arg for arg in type_args if arg is not type(None)]
+             assert len(non_none_args) == 1
+             return md_from_dict(non_none_args[0], data) if data is not None else None
+         elif origin is list:
+             return [md_from_dict(type_args[0], elem) for elem in data]  # type: ignore[return-value]
+         elif origin is dict:
+             key_type = type_args[0]
+             val_type = type_args[1]
+             return {key_type(key): md_from_dict(val_type, val) for key, val in data.items()}  # type: ignore[return-value]
+         elif origin is tuple:
+             return tuple(md_from_dict(arg_type, elem) for arg_type, elem in zip(type_args, data))  # type: ignore[return-value]
+         else:
+             assert False
      else:
          return data

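The rewritten `md_from_dict` keeps its original behavior (rebuilding nested dataclasses, `Optional`s, lists, dicts and tuples from plain dict data) while using `typing.get_origin`/`get_args` instead of the private `__origin__`/`__args__` attributes. A small usage sketch with made-up dataclasses (not the real pixeltable metadata types), assuming `md_from_dict` as defined above is in scope:

    import dataclasses
    from typing import Optional

    @dataclasses.dataclass
    class ColumnInfo:  # hypothetical example type
        name: str
        media_validation: Optional[str]

    @dataclasses.dataclass
    class SchemaInfo:  # hypothetical example type
        comment: str
        columns: dict[int, ColumnInfo]

    data = {'comment': 'demo', 'columns': {'0': {'name': 'img', 'media_validation': None}}}
    md = md_from_dict(SchemaInfo, data)
    # -> SchemaInfo(comment='demo', columns={0: ColumnInfo(name='img', media_validation=None)})
    #    note that the dict key '0' is coerced back to an int via key_type(key)
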
@@ -115,7 +126,7 @@ class ViewMd:
      is_snapshot: bool

      # (table id, version); for mutable views, all versions are None
-     base_versions: List[Tuple[str, Optional[int]]]
+     base_versions: list[tuple[str, Optional[int]]]

      # filter predicate applied to the base table; view-only
      predicate: Optional[dict[str, Any]]
@@ -191,6 +202,10 @@ class SchemaColumn:
      pos: int
      name: str

+     # media validation strategy of this particular media column; if not set, TableMd.media_validation applies
+     # stores column.MediaValidation.name.lower()
+     media_validation: Optional[str]
+

  @dataclasses.dataclass
  class TableSchemaVersionMd:
@@ -203,6 +218,10 @@ class TableSchemaVersionMd:
      num_retained_versions: int
      comment: str

+     # default validation strategy for any media column of this table
+     # stores column.MediaValidation.name.lower()
+     media_validation: str
+

  # versioning: each table schema change results in a new record
  class TableSchemaVersion(Base):