pixeltable 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff shows the changes between two publicly available package versions that were released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (94)
  1. pixeltable/__init__.py +2 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -1
  4. pixeltable/catalog/column.py +41 -29
  5. pixeltable/catalog/globals.py +18 -0
  6. pixeltable/catalog/insertable_table.py +30 -10
  7. pixeltable/catalog/table.py +198 -86
  8. pixeltable/catalog/table_version.py +47 -53
  9. pixeltable/catalog/table_version_path.py +2 -2
  10. pixeltable/catalog/view.py +17 -18
  11. pixeltable/dataframe.py +27 -36
  12. pixeltable/env.py +7 -0
  13. pixeltable/exec/__init__.py +0 -1
  14. pixeltable/exec/aggregation_node.py +6 -3
  15. pixeltable/exec/cache_prefetch_node.py +189 -43
  16. pixeltable/exec/data_row_batch.py +5 -22
  17. pixeltable/exec/exec_context.py +2 -2
  18. pixeltable/exec/exec_node.py +3 -2
  19. pixeltable/exec/expr_eval_node.py +23 -16
  20. pixeltable/exec/in_memory_data_node.py +6 -3
  21. pixeltable/exec/sql_node.py +24 -25
  22. pixeltable/exprs/arithmetic_expr.py +12 -5
  23. pixeltable/exprs/array_slice.py +7 -7
  24. pixeltable/exprs/column_property_ref.py +37 -10
  25. pixeltable/exprs/column_ref.py +97 -14
  26. pixeltable/exprs/comparison.py +10 -5
  27. pixeltable/exprs/compound_predicate.py +8 -7
  28. pixeltable/exprs/data_row.py +27 -18
  29. pixeltable/exprs/expr.py +53 -52
  30. pixeltable/exprs/expr_set.py +5 -0
  31. pixeltable/exprs/function_call.py +32 -16
  32. pixeltable/exprs/globals.py +4 -1
  33. pixeltable/exprs/in_predicate.py +8 -7
  34. pixeltable/exprs/inline_expr.py +4 -4
  35. pixeltable/exprs/is_null.py +4 -4
  36. pixeltable/exprs/json_mapper.py +11 -12
  37. pixeltable/exprs/json_path.py +6 -11
  38. pixeltable/exprs/literal.py +5 -5
  39. pixeltable/exprs/method_ref.py +5 -4
  40. pixeltable/exprs/object_ref.py +2 -1
  41. pixeltable/exprs/row_builder.py +88 -36
  42. pixeltable/exprs/rowid_ref.py +12 -11
  43. pixeltable/exprs/similarity_expr.py +12 -7
  44. pixeltable/exprs/sql_element_cache.py +7 -5
  45. pixeltable/exprs/type_cast.py +8 -6
  46. pixeltable/exprs/variable.py +5 -4
  47. pixeltable/func/aggregate_function.py +9 -9
  48. pixeltable/func/expr_template_function.py +6 -5
  49. pixeltable/func/function.py +11 -10
  50. pixeltable/func/udf.py +6 -11
  51. pixeltable/functions/__init__.py +2 -2
  52. pixeltable/functions/globals.py +5 -7
  53. pixeltable/functions/huggingface.py +155 -45
  54. pixeltable/functions/llama_cpp.py +107 -0
  55. pixeltable/functions/mistralai.py +1 -1
  56. pixeltable/functions/ollama.py +147 -0
  57. pixeltable/functions/openai.py +1 -1
  58. pixeltable/functions/replicate.py +72 -0
  59. pixeltable/functions/string.py +9 -0
  60. pixeltable/functions/together.py +1 -1
  61. pixeltable/functions/util.py +5 -2
  62. pixeltable/globals.py +67 -26
  63. pixeltable/index/btree.py +16 -3
  64. pixeltable/index/embedding_index.py +4 -4
  65. pixeltable/io/__init__.py +1 -2
  66. pixeltable/io/fiftyone.py +178 -0
  67. pixeltable/io/globals.py +96 -2
  68. pixeltable/iterators/base.py +3 -2
  69. pixeltable/iterators/document.py +1 -1
  70. pixeltable/iterators/video.py +120 -63
  71. pixeltable/metadata/__init__.py +1 -1
  72. pixeltable/metadata/converters/convert_21.py +34 -0
  73. pixeltable/metadata/converters/util.py +45 -4
  74. pixeltable/metadata/notes.py +1 -0
  75. pixeltable/metadata/schema.py +8 -0
  76. pixeltable/plan.py +17 -15
  77. pixeltable/py.typed +0 -0
  78. pixeltable/store.py +7 -2
  79. pixeltable/tool/create_test_db_dump.py +1 -1
  80. pixeltable/tool/create_test_video.py +1 -1
  81. pixeltable/tool/embed_udf.py +1 -1
  82. pixeltable/tool/mypy_plugin.py +28 -5
  83. pixeltable/type_system.py +100 -36
  84. pixeltable/utils/coco.py +5 -5
  85. pixeltable/utils/documents.py +15 -1
  86. pixeltable/utils/formatter.py +12 -13
  87. pixeltable/utils/s3.py +6 -3
  88. {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/METADATA +158 -49
  89. pixeltable-0.2.23.dist-info/RECORD +153 -0
  90. pixeltable/exec/media_validation_node.py +0 -43
  91. pixeltable-0.2.21.dist-info/RECORD +0 -148
  92. {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/LICENSE +0 -0
  93. {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/WHEEL +0 -0
  94. {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/entry_points.txt +0 -0
pixeltable/exec/cache_prefetch_node.py
@@ -1,87 +1,226 @@
  from __future__ import annotations

- import concurrent.futures
+ import dataclasses
+ import itertools
  import logging
  import threading
  import urllib.parse
  import urllib.request
- from collections import defaultdict
+ from collections import deque
+ from concurrent import futures
  from pathlib import Path
- from typing import List, Optional, Any, Tuple, Dict
+ from typing import Optional, Any, Iterator
  from uuid import UUID

  import pixeltable.env as env
  import pixeltable.exceptions as excs
  import pixeltable.exprs as exprs
+ from pixeltable import catalog
  from pixeltable.utils.filecache import FileCache
+
  from .data_row_batch import DataRowBatch
  from .exec_node import ExecNode

  _logger = logging.getLogger('pixeltable')

+
  class CachePrefetchNode(ExecNode):
      """Brings files with external URLs into the cache

      TODO:
-     - maintain a queue of row batches, in order to overlap download and evaluation
      - adapting the number of download threads at runtime to maximize throughput
      """
-     def __init__(self, tbl_id: UUID, file_col_info: List[exprs.ColumnSlotIdx], input: ExecNode):
-         # []: we don't have anything to evaluate
+     BATCH_SIZE = 16
+     NUM_EXECUTOR_THREADS = 16
+
+     retain_input_order: bool  # if True, return rows in the exact order they were received
+     file_col_info: list[exprs.ColumnSlotIdx]
+     boto_client: Optional[Any]
+     boto_client_lock: threading.Lock
+
+     # execution state
+     batch_tbl_version: Optional[catalog.TableVersion]  # needed to construct output batches
+     num_returned_rows: int
+
+     # ready_rows: rows that are ready to be returned, ordered by row idx;
+     # the implied row idx of ready_rows[0] is num_returned_rows
+     ready_rows: deque[Optional[exprs.DataRow]]
+
+     in_flight_rows: dict[int, CachePrefetchNode.RowState]  # rows with in-flight urls; id(row) -> RowState
+     in_flight_requests: dict[futures.Future, str]  # in-flight requests for urls; future -> URL
+     in_flight_urls: dict[str, list[tuple[exprs.DataRow, exprs.ColumnSlotIdx]]]  # URL -> [(row, info)]
+     input_finished: bool
+     row_idx: Iterator[Optional[int]]
+
+     @dataclasses.dataclass
+     class RowState:
+         row: exprs.DataRow
+         idx: Optional[int]  # position in input stream; None if we don't retain input order
+         num_missing: int  # number of missing URLs in this row
+
+     def __init__(
+             self, tbl_id: UUID, file_col_info: list[exprs.ColumnSlotIdx], input: ExecNode,
+             retain_input_order: bool = True):
+         # input_/output_exprs=[]: we don't have anything to evaluate
          super().__init__(input.row_builder, [], [], input)
-         self.tbl_id = tbl_id
+         self.retain_input_order = retain_input_order
          self.file_col_info = file_col_info

          # clients for specific services are constructed as needed, because it's time-consuming
-         self.boto_client: Optional[Any] = None
+         self.boto_client = None
          self.boto_client_lock = threading.Lock()

-     def __next__(self) -> DataRowBatch:
-         input_batch = next(self.input)
+         self.batch_tbl_version = None
+         self.num_returned_rows = 0
+         self.ready_rows = deque()
+         self.in_flight_rows = {}
+         self.in_flight_requests = {}
+         self.in_flight_urls = {}
+         self.input_finished = False
+         self.row_idx = itertools.count() if retain_input_order else itertools.repeat(None)
+
+     def __iter__(self) -> Iterator[DataRowBatch]:
+         input_iter = iter(self.input)
+         with futures.ThreadPoolExecutor(max_workers=self.NUM_EXECUTOR_THREADS) as executor:
+             # we create enough in-flight requests to fill the first batch
+             while not self.input_finished and self.__num_pending_rows() < self.BATCH_SIZE:
+                 self.__submit_input_batch(input_iter, executor)
+
+             while True:
+                 # try to assemble a full batch of output rows
+                 if not self.__has_ready_batch() and len(self.in_flight_requests) > 0:
+                     self.__wait_for_requests()
+
+                 # try to create enough in-flight requests to fill the next batch
+                 while not self.input_finished and self.__num_pending_rows() < self.BATCH_SIZE:
+                     self.__submit_input_batch(input_iter, executor)
+
+                 if len(self.ready_rows) > 0:
+                     # create DataRowBatch from the first BATCH_SIZE ready rows
+                     batch = DataRowBatch(self.batch_tbl_version, self.row_builder)
+                     rows = [self.ready_rows.popleft() for _ in range(min(self.BATCH_SIZE, len(self.ready_rows)))]
+                     for row in rows:
+                         assert row is not None
+                         batch.add_row(row)
+                     self.num_returned_rows += len(rows)
+                     _logger.debug(f'returning {len(rows)} rows')
+                     yield batch
+
+                 if self.input_finished and self.__num_pending_rows() == 0:
+                     return
+
+     def __num_pending_rows(self) -> int:
+         return len(self.in_flight_rows) + len(self.ready_rows)
+
+     def __has_ready_batch(self) -> bool:
+         """True if there are >= BATCH_SIZES entries in ready_rows and the first BATCH_SIZE ones are all non-None"""
+         return (
+             sum(int(row is not None) for row in itertools.islice(self.ready_rows, self.BATCH_SIZE)) == self.BATCH_SIZE
+         )
+
+     def __ready_prefix_len(self) -> int:
+         """Length of the non-None prefix of ready_rows (= what we can return right now)"""
+         return sum(1 for _ in itertools.takewhile(lambda x: x is not None, self.ready_rows))
+
+     def __add_ready_row(self, row: exprs.DataRow, row_idx: Optional[int]) -> None:
+         if row_idx is None:
+             self.ready_rows.append(row)
+         else:
+             # extend ready_rows to accommodate row_idx
+             idx = row_idx - self.num_returned_rows
+             if idx >= len(self.ready_rows):
+                 self.ready_rows.extend([None] * (idx - len(self.ready_rows) + 1))
+             self.ready_rows[idx] = row
+
+     def __wait_for_requests(self) -> None:
+         """Wait for in-flight requests to complete until we have a full batch of rows"""
+         file_cache = FileCache.get()
+         _logger.debug(f'waiting for requests; ready_batch_size={self.__ready_prefix_len()}')
+         while not self.__has_ready_batch() and len(self.in_flight_requests) > 0:
+             done, _ = futures.wait(self.in_flight_requests, return_when=futures.FIRST_COMPLETED)
+             for f in done:
+                 url = self.in_flight_requests.pop(f)
+                 tmp_path, exc = f.result()
+                 local_path: Optional[Path] = None
+                 if tmp_path is not None:
+                     # register the file with the cache for the first column in which it's missing
+                     assert url in self.in_flight_urls
+                     _, info = self.in_flight_urls[url][0]
+                     local_path = file_cache.add(info.col.tbl.id, info.col.id, url, tmp_path)
+                     _logger.debug(f'cached {url} as {local_path}')
+
+                 # add the local path/exception to the slots that reference the url
+                 for row, info in self.in_flight_urls.pop(url):
+                     if exc is not None:
+                         self.row_builder.set_exc(row, info.slot_idx, exc)
+                     else:
+                         assert local_path is not None
+                         row.set_file_path(info.slot_idx, str(local_path))
+                     state = self.in_flight_rows[id(row)]
+                     state.num_missing -= 1
+                     if state.num_missing == 0:
+                         del self.in_flight_rows[id(row)]
+                         self.__add_ready_row(row, state.idx)
+                         _logger.debug(f'row {state.idx} is ready (ready_batch_size={self.__ready_prefix_len()})')
+
+     def __submit_input_batch(self, input: Iterator[DataRowBatch], executor: futures.ThreadPoolExecutor) -> None:
+         assert not self.input_finished
+         input_batch = next(input, None)
+         if input_batch is None:
+             self.input_finished = True
+             return
+         if self.batch_tbl_version is None:
+             self.batch_tbl_version = input_batch.tbl

-         # collect external URLs that aren't already cached, and set DataRow.file_paths for those that are
          file_cache = FileCache.get()
-         cache_misses: List[Tuple[exprs.DataRow, exprs.ColumnSlotIdx]] = []
-         missing_url_rows: Dict[str, List[exprs.DataRow]] = defaultdict(list)  # URL -> rows in which it's missing
+
+         # URLs from this input batch that aren't already in the file cache;
+         # we use a list to make sure we submit urls in the order in which they appear in the output, which minimizes
+         # the time it takes to get the next batch together
+         cache_misses: list[str] = []
+
+         url_pos: dict[str, int] = {}  # url -> row_idx; used for logging
          for row in input_batch:
+             # identify missing local files in input batch, or fill in their paths if they're already cached
+             num_missing = 0
+             row_idx = next(self.row_idx)
+
              for info in self.file_col_info:
                  url = row.file_urls[info.slot_idx]
                  if url is None or row.file_paths[info.slot_idx] is not None:
                      # nothing to do
                      continue
-                 if url in missing_url_rows:
-                     missing_url_rows[url].append(row)
+                 locations = self.in_flight_urls.get(url)
+                 if locations is not None:
+                     # we've already seen this
+                     locations.append((row, info))
+                     num_missing += 1
                      continue
+
                  local_path = file_cache.lookup(url)
                  if local_path is None:
-                     cache_misses.append((row, info))
-                     missing_url_rows[url].append(row)
+                     cache_misses.append(url)
+                     self.in_flight_urls[url] = [(row, info)]
+                     num_missing += 1
+                     if url not in url_pos:
+                         url_pos[url] = row_idx
                  else:
                      row.set_file_path(info.slot_idx, str(local_path))

-         # download the cache misses in parallel
-         # TODO: set max_workers to maximize throughput
-         futures: Dict[concurrent.futures.Future, Tuple[exprs.DataRow, exprs.ColumnSlotIdx]] = {}
-         with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
-             for row, info in cache_misses:
-                 futures[executor.submit(self._fetch_url, row, info.slot_idx)] = (row, info)
-             for future in concurrent.futures.as_completed(futures):
-                 # TODO: does this need to deal with recoverable errors (such as retry after throttling)?
-                 tmp_path = future.result()
-                 if tmp_path is None:
-                     continue
-                 row, info = futures[future]
-                 url = row.file_urls[info.slot_idx]
-                 local_path = file_cache.add(self.tbl_id, info.col.id, url, tmp_path)
-                 _logger.debug(f'PrefetchNode: cached {url} as {local_path}')
-                 for row in missing_url_rows[url]:
-                     row.set_file_path(info.slot_idx, str(local_path))
+             if num_missing > 0:
+                 self.in_flight_rows[id(row)] = self.RowState(row, row_idx, num_missing)
+             else:
+                 self.__add_ready_row(row, row_idx)

-         return input_batch
+         _logger.debug(f'submitting {len(cache_misses)} urls')
+         for url in cache_misses:
+             f = executor.submit(self.__fetch_url, url)
+             _logger.debug(f'submitted {url} for idx {url_pos[url]}')
+             self.in_flight_requests[f] = url

-     def _fetch_url(self, row: exprs.DataRow, slot_idx: int) -> Optional[str]:
+     def __fetch_url(self, url: str) -> tuple[Optional[Path], Optional[Exception]]:
          """Fetches a remote URL into Env.tmp_dir and returns its path"""
-         url = row.file_urls[slot_idx]
+         _logger.debug(f'fetching url={url} thread_name={threading.current_thread().name}')
          parsed = urllib.parse.urlparse(url)
          # Use len(parsed.scheme) > 1 here to ensure we're not being passed
          # a Windows filename
@@ -93,24 +232,31 @@ class CachePrefetchNode(ExecNode):
          extension = p.suffix
          tmp_path = env.Env.get().create_tmp_path(extension=extension)
          try:
+             _logger.debug(f'Downloading {url} to {tmp_path}')
              if parsed.scheme == 's3':
                  from pixeltable.utils.s3 import get_client
                  with self.boto_client_lock:
                      if self.boto_client is None:
-                         self.boto_client = get_client()
-                 self.boto_client.download_file(parsed.netloc, parsed.path.lstrip('/'), str(tmp_path))
+                         config = {
+                             'max_pool_connections': self.NUM_EXECUTOR_THREADS + 4,  # +4: leave some headroom
+                             'connect_timeout': 5,
+                             'read_timeout': 30,
+                             'retries': {'max_attempts': 3, 'mode': 'adaptive'},
+                         }
+                         self.boto_client = get_client(**config)
+                 self.boto_client.download_file(parsed.netloc, parsed.path.lstrip('/'), str(tmp_path))
              elif parsed.scheme == 'http' or parsed.scheme == 'https':
                  with urllib.request.urlopen(url) as resp, open(tmp_path, 'wb') as f:
                      data = resp.read()
                      f.write(data)
              else:
                  assert False, f'Unsupported URL scheme: {parsed.scheme}'
-             return tmp_path
+             _logger.debug(f'Downloaded {url} to {tmp_path}')
+             return tmp_path, None
          except Exception as e:
              # we want to add the file url to the exception message
              exc = excs.Error(f'Failed to download {url}: {e}')
-             self.row_builder.set_exc(row, slot_idx, exc)
+             _logger.debug(f'Failed to download {url}: {e}', exc_info=e)
              if not self.ctx.ignore_errors:
                  raise exc from None  # suppress original exception
-             return None
-
+             return None, exc
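The rewritten CachePrefetchNode changes the execution model: instead of downloading all cache misses for one input batch and then returning that same batch, it keeps a pool of in-flight downloads, parks completed rows in a deque ordered by their input position, and yields a batch once a contiguous prefix of rows is ready, so download and evaluation overlap. Below is a minimal, self-contained sketch of that ordered-prefetch pattern; the names (ordered_prefetch, fetch) are hypothetical and the code is independent of pixeltable's DataRowBatch/ExecNode machinery.

import itertools
from collections import deque
from concurrent import futures
from typing import Callable, Iterable, Iterator, Optional

def ordered_prefetch(
    urls: Iterable[str], fetch: Callable[[str], bytes], batch_size: int = 16, num_threads: int = 16
) -> Iterator[list[bytes]]:
    """Download urls in parallel, but yield results in input order, up to batch_size at a time."""
    ready: deque[Optional[bytes]] = deque()    # ready[0] corresponds to input index num_returned
    in_flight: dict[futures.Future, int] = {}  # future -> input index
    num_returned = 0
    indexed_urls = zip(itertools.count(), urls)
    exhausted = False

    with futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        def submit_some() -> None:
            # keep enough requests in flight to fill the next batch
            nonlocal exhausted
            while not exhausted and len(in_flight) + len(ready) < batch_size:
                nxt = next(indexed_urls, None)
                if nxt is None:
                    exhausted = True
                    return
                idx, url = nxt
                in_flight[executor.submit(fetch, url)] = idx

        submit_some()
        while in_flight or ready or not exhausted:
            if in_flight:
                # wait for at least one download, then slot results by input index
                done, _ = futures.wait(in_flight, return_when=futures.FIRST_COMPLETED)
                for f in done:
                    idx = in_flight.pop(f)
                    slot = idx - num_returned
                    if slot >= len(ready):
                        ready.extend([None] * (slot - len(ready) + 1))
                    ready[slot] = f.result()
            submit_some()
            # emit the contiguous prefix of completed downloads
            batch: list[bytes] = []
            while ready and ready[0] is not None and len(batch) < batch_size:
                batch.append(ready.popleft())
                num_returned += 1
            if batch:
                yield batch

With fetch bound to, say, a urllib.request.urlopen read, this yields payload lists in the same order as urls, which mirrors how the node's retain_input_order=True mode pairs ready_rows with num_returned_rows.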
pixeltable/exec/data_row_batch.py
@@ -1,5 +1,5 @@
  from __future__ import annotations
- from typing import List, Iterator, Optional
+ from typing import Iterator, Optional
  import logging

  import pixeltable.exprs as exprs
@@ -49,12 +49,12 @@ class DataRowBatch:
      def __len__(self) -> int:
          return len(self.rows)

-     def __getitem__(self, index: object) -> exprs.DataRow:
+     def __getitem__(self, index: int) -> exprs.DataRow:
          return self.rows[index]

      def flush_imgs(
-             self, idx_range: Optional[slice] = None, stored_img_info: Optional[List[exprs.ColumnSlotIdx]] = None,
-             flushed_slot_idxs: Optional[List[int]] = None
+             self, idx_range: Optional[slice] = None, stored_img_info: Optional[list[exprs.ColumnSlotIdx]] = None,
+             flushed_slot_idxs: Optional[list[int]] = None
      ) -> None:
          """Flushes images in the given range of rows."""
          assert self.tbl is not None
@@ -74,21 +74,4 @@ class DataRowBatch:
                  row.flush_img(slot_idx)

      def __iter__(self) -> Iterator[exprs.DataRow]:
-         return DataRowBatchIterator(self)
-
-
- class DataRowBatchIterator:
-     """
-     Iterator over a DataRowBatch.
-     """
-     def __init__(self, batch: DataRowBatch):
-         self.row_batch = batch
-         self.index = 0
-
-     def __next__(self) -> exprs.DataRow:
-         if self.index >= len(self.row_batch.rows):
-             raise StopIteration
-         row = self.row_batch.rows[self.index]
-         self.index += 1
-         return row
-
+         return iter(self.rows)
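The DataRowBatch change is behavior-preserving: an __iter__ that returns the underlying list's own iterator is equivalent to the removed hand-written DataRowBatchIterator. A tiny sketch with generic names (not pixeltable code):

from typing import Iterator

class Batch:
    def __init__(self, rows: list[int]) -> None:
        self.rows = rows

    def __iter__(self) -> Iterator[int]:
        # delegating to the list's iterator replaces a custom __next__-based class
        return iter(self.rows)

assert list(Batch([1, 2, 3])) == [1, 2, 3]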
pixeltable/exec/exec_context.py
@@ -1,4 +1,4 @@
- from typing import Optional, List
+ from typing import Optional

  import sqlalchemy as sql

@@ -8,7 +8,7 @@ class ExecContext:
      """Class for execution runtime constants"""
      def __init__(
              self, row_builder: exprs.RowBuilder, *, show_pbar: bool = False, batch_size: int = 0,
-             pk_clause: Optional[List[sql.ClauseElement]] = None, num_computed_exprs: int = 0,
+             pk_clause: Optional[list[sql.ClauseElement]] = None, num_computed_exprs: int = 0,
              ignore_errors: bool = False
      ):
          self.show_pbar = show_pbar
pixeltable/exec/exec_node.py
@@ -1,9 +1,10 @@
  from __future__ import annotations

  import abc
- from typing import Iterable, Optional, List, TYPE_CHECKING, Iterator
+ from typing import TYPE_CHECKING, Iterable, Iterator, Optional

  import pixeltable.exprs as exprs
+
  from .data_row_batch import DataRowBatch
  from .exec_context import ExecContext

@@ -42,7 +43,7 @@ class ExecNode(abc.ABC):
          if self.input is not None:
              self.input.set_ctx(ctx)

-     def set_stored_img_cols(self, stored_img_cols: List[exprs.ColumnSlotIdx]) -> None:
+     def set_stored_img_cols(self, stored_img_cols: list[exprs.ColumnSlotIdx]) -> None:
          self.stored_img_cols = stored_img_cols
          # propagate batch size to the source
          if self.input is not None:
pixeltable/exec/expr_eval_node.py
@@ -3,11 +3,11 @@ import sys
  import time
  import warnings
  from dataclasses import dataclass
- from typing import Iterable, List, Optional
+ from typing import Iterable, Optional

  from tqdm import TqdmWarning, tqdm

- import pixeltable.exprs as exprs
+ from pixeltable import exprs
  from pixeltable.func import CallableFunction

  from .data_row_batch import DataRowBatch
@@ -22,10 +22,10 @@ class ExprEvalNode(ExecNode):
      @dataclass
      class Cohort:
          """List of exprs that form an evaluation context and contain calls to at most one external function"""
-         exprs: List[exprs.Expr]
+         exprs_: list[exprs.Expr]
          batched_fn: Optional[CallableFunction]
-         segment_ctxs: List['exprs.RowBuilder.EvalCtx']
-         target_slot_idxs: List[int]
+         segment_ctxs: list['exprs.RowBuilder.EvalCtx']
+         target_slot_idxs: list[int]
          batch_size: int = 8

      def __init__(
@@ -38,7 +38,7 @@ class ExprEvalNode(ExecNode):
          # we're only materializing exprs that are not already in the input
          self.target_exprs = [e for e in output_exprs if e.slot_idx not in input_slot_idxs]
          self.pbar: Optional[tqdm] = None
-         self.cohorts: List[List[ExprEvalNode.Cohort]] = []
+         self.cohorts: list[ExprEvalNode.Cohort] = []
          self._create_cohorts()

      def __next__(self) -> DataRowBatch:
@@ -83,11 +83,13 @@ class ExprEvalNode(ExecNode):
          all_exprs = self.row_builder.get_dependencies(self.target_exprs)
          # break up all_exprs into cohorts such that each cohort contains calls to at most one external function;
          # seed the cohorts with only the ext fn calls
-         cohorts: List[List[exprs.Expr]] = []
+         cohorts: list[list[exprs.Expr]] = []
          current_batched_fn: Optional[CallableFunction] = None
          for e in all_exprs:
              if not self._is_batched_fn_call(e):
                  continue
+             assert isinstance(e, exprs.FunctionCall)
+             assert isinstance(e.fn, CallableFunction)
              if current_batched_fn is None or current_batched_fn != e.fn:
                  # create a new cohort
                  cohorts.append([])
@@ -96,9 +98,9 @@ class ExprEvalNode(ExecNode):

          # expand the cohorts to include all exprs that are in the same evaluation context as the external calls;
          # cohorts are evaluated in order, so we can exclude the target slots from preceding cohorts and input slots
-         exclude = set([e.slot_idx for e in self.input_exprs])
-         all_target_slot_idxs = set([e.slot_idx for e in self.target_exprs])
-         target_slot_idxs: List[List[int]] = []  # the ones materialized by each cohort
+         exclude = set(e.slot_idx for e in self.input_exprs)
+         all_target_slot_idxs = set(e.slot_idx for e in self.target_exprs)
+         target_slot_idxs: list[list[int]] = []  # the ones materialized by each cohort
          for i in range(len(cohorts)):
              cohorts[i] = self.row_builder.get_dependencies(
                  cohorts[i], exclude=[self.row_builder.unique_exprs[slot_idx] for slot_idx in exclude])
@@ -106,7 +108,7 @@ class ExprEvalNode(ExecNode):
                  [e.slot_idx for e in cohorts[i] if e.slot_idx in all_target_slot_idxs])
              exclude.update(target_slot_idxs[-1])

-         all_cohort_slot_idxs = set([e.slot_idx for cohort in cohorts for e in cohort])
+         all_cohort_slot_idxs = set(e.slot_idx for cohort in cohorts for e in cohort)
          remaining_slot_idxs = set(all_target_slot_idxs) - all_cohort_slot_idxs
          if len(remaining_slot_idxs) > 0:
              cohorts.append(self.row_builder.get_dependencies(
@@ -164,11 +166,12 @@ class ExprEvalNode(ExecNode):
                      rows[row_idx], segment_ctx, self.ctx.profile, ignore_errors=self.ctx.ignore_errors)
              else:
                  fn_call = segment_ctx.exprs[0]
+                 assert isinstance(fn_call, exprs.FunctionCall)
                  # make a batched external function call
-                 arg_batches = [[] for _ in range(len(fn_call.args))]
-                 kwarg_batches = {k: [] for k in fn_call.kwargs.keys()}
+                 arg_batches: list[list[exprs.Expr]] = [[] for _ in range(len(fn_call.args))]
+                 kwarg_batches: dict[str, list[exprs.Expr]] = {k: [] for k in fn_call.kwargs.keys()}

-                 valid_batch_idxs: List[int] = []  # rows with exceptions are not valid
+                 valid_batch_idxs: list[int] = []  # rows with exceptions are not valid
                  for row_idx in range(batch_start_idx, batch_start_idx + num_batch_rows):
                      row = rows[row_idx]
                      if row.has_exc(fn_call.slot_idx):
@@ -176,12 +179,15 @@ class ExprEvalNode(ExecNode):
                          continue
                      valid_batch_idxs.append(row_idx)
                      args, kwargs = fn_call._make_args(row)
-                     [arg_batches[i].append(args[i]) for i in range(len(args))]
-                     [kwarg_batches[k].append(kwargs[k]) for k in kwargs.keys()]
+                     for i in range(len(args)):
+                         arg_batches[i].append(args[i])
+                     for k in kwargs.keys():
+                         kwarg_batches[k].append(kwargs[k])
                  num_valid_batch_rows = len(valid_batch_idxs)

                  if ext_batch_size is None:
                      # we need to choose a batch size based on the args
+                     assert isinstance(fn_call.fn, CallableFunction)
                      sample_args = [arg_batches[i][0] for i in range(len(arg_batches))]
                      ext_batch_size = fn_call.fn.get_batch_size(*sample_args)

@@ -201,6 +207,7 @@ class ExprEvalNode(ExecNode):
                      for k in kwarg_batches.keys()
                  }
                  start_ts = time.perf_counter()
+                 assert isinstance(fn_call.fn, CallableFunction)
                  result_batch = fn_call.fn.exec_batch(*call_args, **call_kwargs)
                  self.ctx.profile.eval_time[fn_call.slot_idx] += time.perf_counter() - start_ts
                  self.ctx.profile.eval_count[fn_call.slot_idx] += num_ext_batch_rows
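Two of the expr_eval_node.py edits are idiom cleanups with unchanged behavior: comprehensions that were executed only for their side effects become explicit for loops, and set([...]) becomes set(...) with a generator argument, which avoids building a throwaway list. A small before/after illustration with made-up data (not pixeltable code):

args = ['a', 'b', 'c']
arg_batches: list[list[str]] = [[] for _ in args]

# before: [arg_batches[i].append(args[i]) for i in range(len(args))]
# builds (and discards) a list of None purely for the append side effect

# after: a plain loop states the intent and allocates nothing extra
for i in range(len(args)):
    arg_batches[i].append(args[i])

assert arg_batches == [['a'], ['b'], ['c']]

# set([x for x in args]) materializes an intermediate list; set(x for x in args) does not
assert set(x for x in args) == {'a', 'b', 'c'}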
pixeltable/exec/in_memory_data_node.py
@@ -1,5 +1,5 @@
  import logging
- from typing import Any, Optional, Iterator
+ from typing import Any, Iterator, Optional

  import pixeltable.catalog as catalog
  import pixeltable.exprs as exprs
@@ -23,12 +23,15 @@ class InMemoryDataNode(ExecNode):
      start_row_id: int
      output_rows: Optional[DataRowBatch]

+     # output_exprs is declared in the superclass, but we redeclare it here with a more specific type
+     output_exprs: list[exprs.ColumnRef]
+
      def __init__(
              self, tbl: catalog.TableVersion, rows: list[dict[str, Any]],
              row_builder: exprs.RowBuilder, start_row_id: int,
      ):
-         # we materialize all output slots
-         output_exprs = [e for e in row_builder.get_output_exprs() if isinstance(e, exprs.ColumnRef)]
+         # we materialize the input slots
+         output_exprs = list(row_builder.input_exprs)
          super().__init__(row_builder, output_exprs, [], None)
          assert tbl.is_insertable()
          self.tbl = tbl
pixeltable/exec/sql_node.py
@@ -1,13 +1,14 @@
  import logging
  import warnings
  from decimal import Decimal
- from typing import Optional, Iterable, Iterator, NamedTuple
+ from typing import Iterable, Iterator, NamedTuple, Optional
  from uuid import UUID

  import sqlalchemy as sql

  import pixeltable.catalog as catalog
  import pixeltable.exprs as exprs
+
  from .data_row_batch import DataRowBatch
  from .exec_node import ExecNode

@@ -100,7 +101,7 @@ class SqlNode(ExecNode):
          # minimize the number of tables that need to be joined to the target table
          self.retarget_rowid_refs(tbl, self.select_list)

-         assert self.sql_elements.contains(self.select_list)
+         assert self.sql_elements.contains_all(self.select_list)
          self.set_pk = set_pk
          self.num_pk_cols = 0
          if set_pk:
@@ -120,13 +121,13 @@ class SqlNode(ExecNode):
      def _create_stmt(self) -> sql.Select:
          """Create Select from local state"""

-         assert self.sql_elements.contains(self.select_list)
+         assert self.sql_elements.contains_all(self.select_list)
          sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
          if self.set_pk:
              sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
          stmt = sql.select(*sql_select_list)

-         order_by_clause: list[sql.ClauseElement] = []
+         order_by_clause: list[sql.ColumnElement] = []
          for e, asc in self.order_by_clause:
              if isinstance(e, exprs.SimilarityExpr):
                  order_by_clause.append(e.as_order_by_clause(asc))
@@ -141,7 +142,7 @@ class SqlNode(ExecNode):
          return stmt

      def _ordering_tbl_ids(self) -> set[UUID]:
-         return exprs.Expr.list_tbl_ids(e for e, _ in self.order_by_clause)
+         return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)

      def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
          """
@@ -182,9 +183,9 @@ class SqlNode(ExecNode):
          """
          # we need to include at least the root
          if refd_tbl_ids is None:
-             refd_tbl_ids = {}
+             refd_tbl_ids = set()
          if exact_version_only is None:
-             exact_version_only = {}
+             exact_version_only = set()
          candidates = tbl.get_tbl_versions()
          assert len(candidates) > 0
          joined_tbls: list[catalog.TableVersion] = [candidates[0]]
@@ -193,6 +194,7 @@ class SqlNode(ExecNode):
                  joined_tbls.append(tbl)

          first = True
+         prev_tbl: catalog.TableVersion
          for tbl in joined_tbls[::-1]:
              if first:
                  stmt = stmt.select_from(tbl.store_tbl.sa_tbl)
@@ -239,22 +241,19 @@ class SqlNode(ExecNode):
      def __iter__(self) -> Iterator[DataRowBatch]:
          # run the query; do this here rather than in _open(), exceptions are only expected during iteration
          assert self.ctx.conn is not None
-         try:
-             with warnings.catch_warnings(record=True) as w:
-                 stmt = self._create_stmt()
-                 try:
-                     # log stmt, if possible
-                     stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
-                     _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
-                 except Exception as e:
-                     pass
-                 self._log_explain(stmt)
-
-                 result_cursor = self.ctx.conn.execute(stmt)
-                 for warning in w:
-                     pass
-         except Exception as e:
-             raise e
+         with warnings.catch_warnings(record=True) as w:
+             stmt = self._create_stmt()
+             try:
+                 # log stmt, if possible
+                 stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
+                 _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
+             except Exception:
+                 pass
+             self._log_explain(stmt)
+
+             result_cursor = self.ctx.conn.execute(stmt)
+             for warning in w:
+                 pass

          tbl_version = self.tbl.tbl_version if self.tbl is not None else None
          output_batch = DataRowBatch(tbl_version, self.row_builder)
@@ -350,7 +349,7 @@ class SqlScanNode(SqlNode):
      def _create_stmt(self) -> sql.Select:
          stmt = super()._create_stmt()
          where_clause_tbl_ids = self.where_clause.tbl_ids() if self.where_clause is not None else set()
-         refd_tbl_ids = exprs.Expr.list_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
+         refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
          stmt = self.create_from_clause(
              self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only})

@@ -386,7 +385,7 @@ class SqlLookupNode(SqlNode):

      def _create_stmt(self) -> sql.Select:
          stmt = super()._create_stmt()
-         refd_tbl_ids = exprs.Expr.list_tbl_ids(self.select_list) | self._ordering_tbl_ids()
+         refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | self._ordering_tbl_ids()
          stmt = self.create_from_clause(self.tbl, stmt, refd_tbl_ids)
          stmt = stmt.where(self.where_clause)
          return stmt
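In sql_node.py, the create_from_clause change from {} to set() for the refd_tbl_ids and exact_version_only defaults is a fix rather than a style edit: {} is an empty dict, so set-style operations on the old default behaved differently from an empty set (how much that mattered in practice depends on how the defaults were used downstream). A quick illustration:

empty_dict = {}   # an empty dict, not a set
empty_set = set()

assert isinstance(empty_dict, dict) and isinstance(empty_set, set)

# union with a real set works only for the set default
assert (empty_set | {1, 2}) == {1, 2}
try:
    empty_dict | {1, 2}
except TypeError:
    pass  # dict | set is not defined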