pixeltable 0.1.0__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (147)
  1. pixeltable/__init__.py +34 -6
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +590 -30
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +359 -45
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +195 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +34 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +256 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +122 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +418 -182
  88. pixeltable/tests/conftest.py +146 -88
  89. pixeltable/tests/functions/test_fireworks.py +42 -0
  90. pixeltable/tests/functions/test_functions.py +60 -0
  91. pixeltable/tests/functions/test_huggingface.py +158 -0
  92. pixeltable/tests/functions/test_openai.py +152 -0
  93. pixeltable/tests/functions/test_together.py +111 -0
  94. pixeltable/tests/test_audio.py +65 -0
  95. pixeltable/tests/test_catalog.py +27 -0
  96. pixeltable/tests/test_client.py +14 -14
  97. pixeltable/tests/test_component_view.py +370 -0
  98. pixeltable/tests/test_dataframe.py +439 -0
  99. pixeltable/tests/test_dirs.py +78 -62
  100. pixeltable/tests/test_document.py +120 -0
  101. pixeltable/tests/test_exprs.py +592 -135
  102. pixeltable/tests/test_function.py +297 -67
  103. pixeltable/tests/test_migration.py +43 -0
  104. pixeltable/tests/test_nos.py +54 -0
  105. pixeltable/tests/test_snapshot.py +208 -0
  106. pixeltable/tests/test_table.py +1195 -263
  107. pixeltable/tests/test_transactional_directory.py +42 -0
  108. pixeltable/tests/test_types.py +5 -11
  109. pixeltable/tests/test_video.py +151 -34
  110. pixeltable/tests/test_view.py +530 -0
  111. pixeltable/tests/utils.py +320 -45
  112. pixeltable/tool/create_test_db_dump.py +149 -0
  113. pixeltable/tool/create_test_video.py +81 -0
  114. pixeltable/type_system.py +445 -124
  115. pixeltable/utils/__init__.py +17 -46
  116. pixeltable/utils/arrow.py +98 -0
  117. pixeltable/utils/clip.py +12 -15
  118. pixeltable/utils/coco.py +136 -0
  119. pixeltable/utils/documents.py +39 -0
  120. pixeltable/utils/filecache.py +195 -0
  121. pixeltable/utils/help.py +11 -0
  122. pixeltable/utils/hf_datasets.py +157 -0
  123. pixeltable/utils/media_store.py +76 -0
  124. pixeltable/utils/parquet.py +167 -0
  125. pixeltable/utils/pytorch.py +91 -0
  126. pixeltable/utils/s3.py +13 -0
  127. pixeltable/utils/sql.py +17 -0
  128. pixeltable/utils/transactional_directory.py +35 -0
  129. pixeltable-0.2.4.dist-info/LICENSE +18 -0
  130. pixeltable-0.2.4.dist-info/METADATA +127 -0
  131. pixeltable-0.2.4.dist-info/RECORD +132 -0
  132. {pixeltable-0.1.0.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +1 -1
  133. pixeltable/catalog.py +0 -1421
  134. pixeltable/exprs.py +0 -1745
  135. pixeltable/function.py +0 -269
  136. pixeltable/functions/clip.py +0 -10
  137. pixeltable/functions/pil/__init__.py +0 -23
  138. pixeltable/functions/tf.py +0 -21
  139. pixeltable/index.py +0 -57
  140. pixeltable/tests/test_dict.py +0 -24
  141. pixeltable/tests/test_functions.py +0 -11
  142. pixeltable/tests/test_tf.py +0 -69
  143. pixeltable/tf.py +0 -33
  144. pixeltable/utils/tf.py +0 -33
  145. pixeltable/utils/video.py +0 -32
  146. pixeltable-0.1.0.dist-info/METADATA +0 -34
  147. pixeltable-0.1.0.dist-info/RECORD +0 -36
pixeltable/exec/cache_prefetch_node.py
@@ -0,0 +1,116 @@
+ from __future__ import annotations
+
+ import concurrent.futures
+ import logging
+ import threading
+ import urllib.parse
+ import urllib.request
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import List, Optional, Any, Tuple, Dict
+ from uuid import UUID
+
+ import pixeltable.env as env
+ import pixeltable.exceptions as excs
+ import pixeltable.exprs as exprs
+ from pixeltable.utils.filecache import FileCache
+ from .data_row_batch import DataRowBatch
+ from .exec_node import ExecNode
+
+ _logger = logging.getLogger('pixeltable')
+
+ class CachePrefetchNode(ExecNode):
+     """Brings files with external URLs into the cache
+
+     TODO:
+     - maintain a queue of row batches, in order to overlap download and evaluation
+     - adapt the number of download threads at runtime to maximize throughput
+     """
+     def __init__(self, tbl_id: UUID, file_col_info: List[exprs.ColumnSlotIdx], input: ExecNode):
+         # []: we don't have anything to evaluate
+         super().__init__(input.row_builder, [], [], input)
+         self.tbl_id = tbl_id
+         self.file_col_info = file_col_info
+
+         # clients for specific services are constructed as needed, because doing so is time-consuming
+         self.boto_client: Optional[Any] = None
+         self.boto_client_lock = threading.Lock()
+
+     def __next__(self) -> DataRowBatch:
+         input_batch = next(self.input)
+
+         # collect external URLs that aren't already cached, and set DataRow.file_paths for those that are
+         file_cache = FileCache.get()
+         cache_misses: List[Tuple[exprs.DataRow, exprs.ColumnSlotIdx]] = []
+         missing_url_rows: Dict[str, List[exprs.DataRow]] = defaultdict(list)  # URL -> rows in which it's missing
+         for row in input_batch:
+             for info in self.file_col_info:
+                 url = row.file_urls[info.slot_idx]
+                 if url is None or row.file_paths[info.slot_idx] is not None:
+                     # nothing to do
+                     continue
+                 if url in missing_url_rows:
+                     missing_url_rows[url].append(row)
+                     continue
+                 local_path = file_cache.lookup(url)
+                 if local_path is None:
+                     cache_misses.append((row, info))
+                     missing_url_rows[url].append(row)
+                 else:
+                     row.set_file_path(info.slot_idx, str(local_path))
+
+         # download the cache misses in parallel
+         # TODO: set max_workers to maximize throughput
+         futures: Dict[concurrent.futures.Future, Tuple[exprs.DataRow, exprs.ColumnSlotIdx]] = {}
+         with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
+             for row, info in cache_misses:
+                 futures[executor.submit(self._fetch_url, row, info.slot_idx)] = (row, info)
+             for future in concurrent.futures.as_completed(futures):
+                 # TODO: does this need to deal with recoverable errors (such as retry after throttling)?
+                 tmp_path = future.result()
+                 if tmp_path is None:
+                     continue
+                 row, info = futures[future]
+                 url = row.file_urls[info.slot_idx]
+                 local_path = file_cache.add(self.tbl_id, info.col.id, url, tmp_path)
+                 _logger.debug(f'PrefetchNode: cached {url} as {local_path}')
+                 for row in missing_url_rows[url]:
+                     row.set_file_path(info.slot_idx, str(local_path))
+
+         return input_batch
+
+     def _fetch_url(self, row: exprs.DataRow, slot_idx: int) -> Optional[Path]:
+         """Fetches a remote URL into Env.tmp_dir and returns its path"""
+         url = row.file_urls[slot_idx]
+         parsed = urllib.parse.urlparse(url)
+         # len(parsed.scheme) > 1 ensures we're not being passed a Windows filename (a drive letter is a single char)
+         assert len(parsed.scheme) > 1 and parsed.scheme != 'file'
+         # preserve the file extension, if there is one
+         extension = ''
+         if parsed.path != '':
+             p = Path(urllib.parse.unquote(parsed.path))
+             extension = p.suffix
+         tmp_path = env.Env.get().create_tmp_path(extension=extension)
+         try:
+             if parsed.scheme == 's3':
+                 from pixeltable.utils.s3 import get_client
+                 with self.boto_client_lock:
+                     if self.boto_client is None:
+                         self.boto_client = get_client()
+                 self.boto_client.download_file(parsed.netloc, parsed.path.lstrip('/'), str(tmp_path))
+             elif parsed.scheme == 'http' or parsed.scheme == 'https':
+                 with urllib.request.urlopen(url) as resp, open(tmp_path, 'wb') as f:
+                     f.write(resp.read())
+             else:
+                 assert False, f'Unsupported URL scheme: {parsed.scheme}'
+             return tmp_path
+         except Exception as e:
+             # add the file URL to the exception message
+             exc = excs.Error(f'Failed to download {url}: {e}')
+             self.row_builder.set_exc(row, slot_idx, exc)
+             if not self.ctx.ignore_errors:
+                 raise exc from None  # suppress the original exception
+             return None
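The interesting part of this node is the download pattern rather than the pixeltable plumbing: collect the cache misses, dedupe by URL (the same URL can appear in many rows), fetch concurrently, then resolve every row that was waiting on a URL. A minimal standalone sketch of that pattern, using only the standard library (prefetch and fetch are illustrative names, not part of the package):

import concurrent.futures
import tempfile
import urllib.parse
import urllib.request
from pathlib import Path
from typing import Dict, List, Optional

def prefetch(urls: List[str], cache: Dict[str, Path]) -> Dict[str, Optional[Path]]:
    misses = {url for url in urls if url not in cache}  # dedupe across rows

    def fetch(url: str) -> Optional[Path]:
        suffix = Path(urllib.parse.urlparse(url).path).suffix  # preserve the extension
        try:
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
                with urllib.request.urlopen(url) as resp:
                    f.write(resp.read())
            return Path(f.name)
        except Exception:
            return None  # the real node records a per-row error instead

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        futures = {executor.submit(fetch, url): url for url in misses}
        for future in concurrent.futures.as_completed(futures):
            local_path = future.result()
            if local_path is not None:
                cache[futures[future]] = local_path

    # every row that referenced a URL can now be pointed at its local file
    return {url: cache.get(url) for url in urls}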
pixeltable/exec/component_iteration_node.py
@@ -0,0 +1,79 @@
+ from typing import Generator, Optional
+
+ from .data_row_batch import DataRowBatch
+ from .exec_node import ExecNode
+ import pixeltable.catalog as catalog
+ import pixeltable.exprs as exprs
+ import pixeltable.exceptions as excs
+
+
+ class ComponentIterationNode(ExecNode):
+     """Expands each row from a base table into one row per component returned by an iterator
+
+     Returns row batches of OUTPUT_BATCH_SIZE size.
+     """
+     OUTPUT_BATCH_SIZE = 1024
+
+     def __init__(self, view: catalog.TableVersion, input: ExecNode):
+         assert view.is_component_view()
+         super().__init__(input.row_builder, [], [], input)
+         self.view = view
+         iterator_args = [view.iterator_args.copy()]
+         self.row_builder.substitute_exprs(iterator_args)
+         self.iterator_args = iterator_args[0]
+         assert isinstance(self.iterator_args, exprs.InlineDict)
+         self.iterator_args_ctx = self.row_builder.create_eval_ctx([self.iterator_args])
+         self.iterator_output_schema, self.unstored_column_names = \
+             self.view.iterator_cls.output_schema(**self.iterator_args.to_dict())
+         self.iterator_output_fields = list(self.iterator_output_schema.keys())
+         self.iterator_output_cols = \
+             {field_name: self.view.cols_by_name[field_name] for field_name in self.iterator_output_fields}
+         # referenced iterator output fields
+         self.refd_output_slot_idxs = {
+             e.col.name: e.slot_idx for e in self.row_builder.unique_exprs
+             if isinstance(e, exprs.ColumnRef) and e.col.name in self.iterator_output_fields
+         }
+         self._output: Optional[Generator[DataRowBatch, None, None]] = None
+
+     def _output_batches(self) -> Generator[DataRowBatch, None, None]:
+         output_batch = DataRowBatch(self.view, self.row_builder)
+         for input_batch in self.input:
+             for input_row in input_batch:
+                 self.row_builder.eval(input_row, self.iterator_args_ctx)
+                 iterator_args = input_row[self.iterator_args.slot_idx]
+                 iterator = self.view.iterator_cls(**iterator_args)
+                 for pos, component_dict in enumerate(iterator):
+                     output_row = output_batch.add_row()
+                     input_row.copy(output_row)
+                     # we're expanding the input and need to add the iterator position to the pk
+                     pk = output_row.pk[:-1] + (pos,) + output_row.pk[-1:]
+                     output_row.set_pk(pk)
+
+                     # verify and copy component_dict fields to their respective slots in output_row
+                     for field_name, field_val in component_dict.items():
+                         if field_name not in self.iterator_output_fields:
+                             raise excs.Error(
+                                 f'Invalid field name {field_name} in output of {self.view.iterator_cls.__name__}')
+                         if field_name not in self.refd_output_slot_idxs:
+                             # nothing references this field, so we don't need to materialize it
+                             continue
+                         output_col = self.iterator_output_cols[field_name]
+                         output_col.col_type.validate_literal(field_val)
+                         output_row[self.refd_output_slot_idxs[field_name]] = field_val
+                     if len(component_dict) != len(self.iterator_output_fields):
+                         missing_fields = set(self.iterator_output_fields) - set(component_dict.keys())
+                         raise excs.Error(
+                             f'Invalid output of {self.view.iterator_cls.__name__}: '
+                             f'missing fields {", ".join(missing_fields)}')
+
+                     if len(output_batch) == self.OUTPUT_BATCH_SIZE:
+                         yield output_batch
+                         output_batch = DataRowBatch(self.view, self.row_builder)
+
+         if len(output_batch) > 0:
+             yield output_batch
+
+     def __next__(self) -> DataRowBatch:
+         if self._output is None:
+             self._output = self._output_batches()
+         return next(self._output)
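Stripped of the catalog and row-builder machinery, the node implements a generic expand-and-rebatch loop: every input row is expanded into one output row per component, and output batches are cut at a fixed size regardless of where the input batches ended. A sketch under those assumptions (expand and components are illustrative names):

from typing import Any, Callable, Iterable, Iterator, List

def expand(
    batches: Iterable[List[Any]], components: Callable[[Any], Iterable[Any]],
    batch_size: int = 1024,
) -> Iterator[List[Any]]:
    out: List[Any] = []
    for batch in batches:
        for row in batch:
            for pos, component in enumerate(components(row)):
                # pos plays the role of the iterator position added to the pk above
                out.append((row, pos, component))
                if len(out) == batch_size:
                    yield out
                    out = []
    if out:  # final, possibly short batch
        yield out

# each of three rows expands into two components; output is rebatched in fours
for batch in expand([['a', 'b'], ['c']], lambda r: [r + '0', r + '1'], batch_size=4):
    print(batch)
# [('a', 0, 'a0'), ('a', 1, 'a1'), ('b', 0, 'b0'), ('b', 1, 'b1')]
# [('c', 0, 'c0'), ('c', 1, 'c1')]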
pixeltable/exec/data_row_batch.py
@@ -0,0 +1,95 @@
+ from __future__ import annotations
+ from typing import List, Iterator, Optional
+ import logging
+
+ import pixeltable.exprs as exprs
+ import pixeltable.catalog as catalog
+ from pixeltable.utils.media_store import MediaStore
+
+
+ _logger = logging.getLogger('pixeltable')
+
+ class DataRowBatch:
+     """Set of DataRows, indexed by rowid.
+
+     Contains the metadata needed to initialize DataRows.
+     """
+     def __init__(self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, num_rows: int = 0):
+         self.tbl_id = tbl.id
+         self.tbl_version = tbl.version
+         self.row_builder = row_builder
+         self.img_slot_idxs = [e.slot_idx for e in row_builder.unique_exprs if e.col_type.is_image_type()]
+         # non-image media slots
+         self.media_slot_idxs = [
+             e.slot_idx for e in row_builder.unique_exprs
+             if e.col_type.is_media_type() and not e.col_type.is_image_type()
+         ]
+         self.array_slot_idxs = [e.slot_idx for e in row_builder.unique_exprs if e.col_type.is_array_type()]
+         self.rows = [
+             exprs.DataRow(row_builder.num_materialized, self.img_slot_idxs, self.media_slot_idxs, self.array_slot_idxs)
+             for _ in range(num_rows)
+         ]
+
+     def add_row(self, row: Optional[exprs.DataRow] = None) -> exprs.DataRow:
+         if row is None:
+             row = exprs.DataRow(
+                 self.row_builder.num_materialized, self.img_slot_idxs, self.media_slot_idxs, self.array_slot_idxs)
+         self.rows.append(row)
+         return row
+
+     def pop_row(self) -> exprs.DataRow:
+         return self.rows.pop()
+
+     def set_row_ids(self, row_ids: List[int]) -> None:
+         """Sets pks for rows in batch"""
+         assert len(row_ids) == len(self.rows)
+         for row, row_id in zip(self.rows, row_ids):
+             row.set_pk((row_id, self.tbl_version))
+
+     def __len__(self) -> int:
+         return len(self.rows)
+
+     def __getitem__(self, index: int) -> exprs.DataRow:
+         return self.rows[index]
+
+     def flush_imgs(
+             self, idx_range: Optional[slice] = None, stored_img_info: Optional[List[exprs.ColumnSlotIdx]] = None,
+             flushed_slot_idxs: Optional[List[int]] = None
+     ) -> None:
+         """Flushes images in the given range of rows."""
+         if stored_img_info is None:
+             stored_img_info = []
+         if flushed_slot_idxs is None:
+             flushed_slot_idxs = []
+         if len(stored_img_info) == 0 and len(flushed_slot_idxs) == 0:
+             return
+         if idx_range is None:
+             idx_range = slice(0, len(self.rows))
+         for row in self.rows[idx_range]:
+             for info in stored_img_info:
+                 filepath = str(MediaStore.prepare_media_path(self.tbl_id, info.col.id, self.tbl_version))
+                 row.flush_img(info.slot_idx, filepath)
+             for slot_idx in flushed_slot_idxs:
+                 row.flush_img(slot_idx)
+
+     def __iter__(self) -> Iterator[exprs.DataRow]:
+         return DataRowBatchIterator(self)
+
+
+ class DataRowBatchIterator:
+     """Iterator over a DataRowBatch."""
+     def __init__(self, batch: DataRowBatch):
+         self.row_batch = batch
+         self.index = 0
+
+     def __next__(self) -> exprs.DataRow:
+         if self.index >= len(self.row_batch.rows):
+             raise StopIteration
+         row = self.row_batch.rows[self.index]
+         self.index += 1
+         return row
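DataRowBatch and DataRowBatchIterator are a textbook container/iterator pair: the container owns the rows and the indexing, the iterator holds only a cursor, so two concurrent loops over the same batch don't interfere. A toy version of the same shape (illustrative classes, not the pixeltable ones):

from typing import Iterator, List

class Batch:
    def __init__(self, rows: List[int]):
        self.rows = rows

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, index: int) -> int:
        return self.rows[index]

    def __iter__(self) -> Iterator[int]:
        return BatchIterator(self)

class BatchIterator:
    def __init__(self, batch: Batch):
        self.batch = batch
        self.index = 0  # the cursor lives in the iterator, not the container

    def __next__(self) -> int:
        if self.index >= len(self.batch):
            raise StopIteration
        value = self.batch[self.index]
        self.index += 1
        return value

assert [x for x in Batch([1, 2, 3])] == [1, 2, 3]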
pixeltable/exec/exec_context.py
@@ -0,0 +1,22 @@
+ from typing import Optional, List
+
+ import sqlalchemy as sql
+
+ import pixeltable.exprs as exprs
+
+ class ExecContext:
+     """Runtime state shared by all nodes of an execution tree"""
+     def __init__(
+             self, row_builder: exprs.RowBuilder, *, show_pbar: bool = False, batch_size: int = 0,
+             pk_clause: Optional[List[sql.ClauseElement]] = None, num_computed_exprs: int = 0,
+             ignore_errors: bool = False
+     ):
+         self.show_pbar = show_pbar
+         self.batch_size = batch_size
+         self.profile = exprs.ExecProfile(row_builder)
+         # num_rows is used to compute the total number of computed cells shown by the progress bar
+         self.num_rows: Optional[int] = None
+         self.conn: Optional[sql.engine.Connection] = None  # if present, use this to execute SQL queries
+         self.pk_clause = pk_clause
+         self.num_computed_exprs = num_computed_exprs
+         self.ignore_errors = ignore_errors
pixeltable/exec/exec_node.py
@@ -0,0 +1,61 @@
+ from __future__ import annotations
+ from typing import Iterable, Optional, List
+ import abc
+
+ from .data_row_batch import DataRowBatch
+ from .exec_context import ExecContext
+ import pixeltable.exprs as exprs
+
+ class ExecNode(abc.ABC):
+     """Base class of all execution nodes"""
+     def __init__(
+             self, row_builder: exprs.RowBuilder, output_exprs: Iterable[exprs.Expr],
+             input_exprs: Iterable[exprs.Expr], input: Optional[ExecNode] = None):
+         self.row_builder = row_builder
+         self.input = input
+         # we flush all image slots that aren't part of our output but are needed to create our output
+         output_slot_idxs = {e.slot_idx for e in output_exprs}
+         output_dependencies = row_builder.get_dependencies(output_exprs, exclude=input_exprs)
+         self.flushed_img_slots = [
+             e.slot_idx for e in output_dependencies
+             if e.col_type.is_image_type() and e.slot_idx not in output_slot_idxs
+         ]
+         self.stored_img_cols: List[exprs.ColumnSlotIdx] = []
+         self.ctx: Optional[ExecContext] = None  # all nodes of a tree share the same context
+
+     def set_ctx(self, ctx: ExecContext) -> None:
+         self.ctx = ctx
+         if self.input is not None:
+             self.input.set_ctx(ctx)
+
+     def set_stored_img_cols(self, stored_img_cols: List[exprs.ColumnSlotIdx]) -> None:
+         self.stored_img_cols = stored_img_cols
+         # propagate to the source
+         if self.input is not None:
+             self.input.set_stored_img_cols(stored_img_cols)
+
+     def __iter__(self) -> ExecNode:
+         return self
+
+     @abc.abstractmethod
+     def __next__(self) -> DataRowBatch:
+         pass
+
+     def open(self) -> None:
+         """Bottom-up initialization of nodes for execution. Must be called before __next__."""
+         if self.input is not None:
+             self.input.open()
+         self._open()
+
+     def close(self) -> None:
+         """Frees node resources top-down after execution. Must be called after the final __next__."""
+         self._close()
+         if self.input is not None:
+             self.input.close()
+
+     def _open(self) -> None:
+         pass
+
+     def _close(self) -> None:
+         pass
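The base class fixes the execution protocol: nodes form a chain, open() initializes bottom-up, iteration pulls DataRowBatches through the chain, and close() tears down top-down. A minimal sketch of a two-node chain driven through that lifecycle (hypothetical node classes; the real trees are assembled elsewhere in the package, e.g. pixeltable/plan.py in the file list above):

from typing import Iterator, List, Optional

class Node:
    def __init__(self, input: Optional['Node'] = None):
        self.input = input

    def __iter__(self) -> 'Node':
        return self

    def __next__(self) -> List[int]:
        raise NotImplementedError

    def open(self) -> None:  # bottom-up, like ExecNode.open
        if self.input is not None:
            self.input.open()
        self._open()

    def close(self) -> None:  # top-down, like ExecNode.close
        self._close()
        if self.input is not None:
            self.input.close()

    def _open(self) -> None:
        pass

    def _close(self) -> None:
        pass

class Source(Node):
    def __init__(self, batches: List[List[int]]):
        super().__init__(None)
        self.batches = batches

    def _open(self) -> None:
        self.it: Iterator[List[int]] = iter(self.batches)

    def __next__(self) -> List[int]:
        return next(self.it)  # StopIteration propagates up the chain

class Double(Node):
    def __next__(self) -> List[int]:
        return [2 * x for x in next(self.input)]

root = Double(Source([[1, 2], [3]]))
root.open()
print([batch for batch in root])  # [[2, 4], [6]]
root.close()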
pixeltable/exec/expr_eval_node.py
@@ -0,0 +1,217 @@
+ import sys
+ import warnings
+ from typing import List, Optional
+ from dataclasses import dataclass
+ import logging
+ import time
+
+ from tqdm import tqdm, TqdmWarning
+
+ from .data_row_batch import DataRowBatch
+ from .exec_node import ExecNode
+ import pixeltable.exprs as exprs
+ import pixeltable.func as func
+
+
+ _logger = logging.getLogger('pixeltable')
+
+ class ExprEvalNode(ExecNode):
+     """Materializes expressions"""
+     @dataclass
+     class Cohort:
+         """List of exprs that form an evaluation context and contain calls to at most one external function"""
+         exprs: List[exprs.Expr]
+         ext_function: Optional[func.BatchedFunction]
+         segment_ctxs: List[exprs.RowBuilder.EvalCtx]
+         target_slot_idxs: List[int]
+         batch_size: int = 8
+
+     def __init__(
+             self, row_builder: exprs.RowBuilder, output_exprs: List[exprs.Expr], input_exprs: List[exprs.Expr],
+             input: ExecNode
+     ):
+         super().__init__(row_builder, output_exprs, input_exprs, input)
+         self.input_exprs = input_exprs
+         input_slot_idxs = {e.slot_idx for e in input_exprs}
+         # we're only materializing exprs that are not already in the input
+         self.target_exprs = [e for e in output_exprs if e.slot_idx not in input_slot_idxs]
+         self.pbar: Optional[tqdm] = None
+         self.cohorts: List[ExprEvalNode.Cohort] = []
+         self._create_cohorts()
+
+     def __next__(self) -> DataRowBatch:
+         input_batch = next(self.input)
+         # compute target exprs
+         for cohort in self.cohorts:
+             self._exec_cohort(cohort, input_batch)
+         _logger.debug(f'ExprEvalNode: returning {len(input_batch)} rows')
+         return input_batch
+
+     def _open(self) -> None:
+         warnings.simplefilter("ignore", category=TqdmWarning)
+         if self.ctx.show_pbar:
+             self.pbar = tqdm(
+                 total=len(self.target_exprs) * self.ctx.num_rows,
+                 desc='Computing cells',
+                 unit=' cells',
+                 ncols=100,
+                 file=sys.stdout
+             )
+
+     def _close(self) -> None:
+         if self.pbar is not None:
+             self.pbar.close()
+
+     def _get_batched_fn(self, expr: exprs.Expr) -> Optional[func.BatchedFunction]:
+         if not isinstance(expr, exprs.FunctionCall):
+             return None
+         return expr.fn if isinstance(expr.fn, func.BatchedFunction) else None
+
+     def _is_ext_call(self, expr: exprs.Expr) -> bool:
+         return self._get_batched_fn(expr) is not None
+
+     def _create_cohorts(self) -> None:
+         all_exprs = self.row_builder.get_dependencies(self.target_exprs)
+         # break up all_exprs into cohorts such that each cohort contains calls to at most one external function;
+         # seed the cohorts with only the ext fn calls
+         cohorts: List[List[exprs.Expr]] = []
+         current_ext_function: Optional[func.BatchedFunction] = None
+         for e in all_exprs:
+             if not self._is_ext_call(e):
+                 continue
+             if current_ext_function is None or current_ext_function != e.fn:
+                 # create a new cohort
+                 cohorts.append([])
+                 current_ext_function = e.fn
+             cohorts[-1].append(e)
+
+         # expand the cohorts to include all exprs that are in the same evaluation context as the external calls;
+         # cohorts are evaluated in order, so we can exclude the target slots of preceding cohorts and input slots
+         exclude = set([e.slot_idx for e in self.input_exprs])
+         all_target_slot_idxs = set([e.slot_idx for e in self.target_exprs])
+         target_slot_idxs: List[List[int]] = []  # the ones materialized by each cohort
+         for i in range(len(cohorts)):
+             cohorts[i] = self.row_builder.get_dependencies(
+                 cohorts[i], exclude=[self.row_builder.unique_exprs[slot_idx] for slot_idx in exclude])
+             target_slot_idxs.append(
+                 [e.slot_idx for e in cohorts[i] if e.slot_idx in all_target_slot_idxs])
+             exclude.update(target_slot_idxs[-1])
+
+         all_cohort_slot_idxs = set([e.slot_idx for cohort in cohorts for e in cohort])
+         remaining_slot_idxs = set(all_target_slot_idxs) - all_cohort_slot_idxs
+         if len(remaining_slot_idxs) > 0:
+             cohorts.append(self.row_builder.get_dependencies(
+                 [self.row_builder.unique_exprs[slot_idx] for slot_idx in remaining_slot_idxs],
+                 exclude=[self.row_builder.unique_exprs[slot_idx] for slot_idx in exclude]))
+             target_slot_idxs.append(list(remaining_slot_idxs))
+         # we need to have captured all target slots at this point
+         assert all_target_slot_idxs == set().union(*target_slot_idxs)
+
+         for i in range(len(cohorts)):
+             cohort = cohorts[i]
+             # segment the cohort into sublists that contain either a single ext. function call
+             # or no ext. function calls (i.e., only computed cols)
+             assert len(cohort) > 0
+             # create the first segment here, so we can avoid checking for an empty list in the loop
+             segments = [[cohort[0]]]
+             is_ext_segment = self._is_ext_call(cohort[0])
+             ext_fn: Optional[func.BatchedFunction] = self._get_batched_fn(cohort[0])
+             for e in cohort[1:]:
+                 if self._is_ext_call(e):
+                     segments.append([e])
+                     is_ext_segment = True
+                     ext_fn = self._get_batched_fn(e)
+                 else:
+                     if is_ext_segment:
+                         # start a new segment
+                         segments.append([])
+                         is_ext_segment = False
+                     segments[-1].append(e)
+
+             # we create the EvalCtxs manually because create_eval_ctx() would repeat the dependencies of each segment
+             segment_ctxs = [
+                 exprs.RowBuilder.EvalCtx(
+                     slot_idxs=[e.slot_idx for e in s], exprs=s, target_slot_idxs=[], target_exprs=[])
+                 for s in segments
+             ]
+             cohort_info = self.Cohort(cohort, ext_fn, segment_ctxs, target_slot_idxs[i])
+             self.cohorts.append(cohort_info)
+
+     def _exec_cohort(self, cohort: Cohort, rows: DataRowBatch) -> None:
+         """Compute the cohort for the entire input batch by dividing it up into sub-batches"""
+         batch_start_idx = 0  # start row of the current sub-batch
+         # for multi-resolution models, we re-assess the correct ext fn batch size for each input batch
+         ext_batch_size = cohort.ext_function.get_batch_size() if cohort.ext_function is not None else None
+         if ext_batch_size is not None:
+             cohort.batch_size = ext_batch_size
+
+         while batch_start_idx < len(rows):
+             num_batch_rows = min(cohort.batch_size, len(rows) - batch_start_idx)
+             for segment_ctx in cohort.segment_ctxs:
+                 if not self._is_ext_call(segment_ctx.exprs[0]):
+                     # compute batch row-wise
+                     for row_idx in range(batch_start_idx, batch_start_idx + num_batch_rows):
+                         self.row_builder.eval(
+                             rows[row_idx], segment_ctx, self.ctx.profile, ignore_errors=self.ctx.ignore_errors)
+                 else:
+                     fn_call = segment_ctx.exprs[0]
+                     # make a batched external function call
+                     arg_batches = [[] for _ in range(len(fn_call.args))]
+                     kwarg_batches = {k: [] for k in fn_call.kwargs.keys()}
+
+                     valid_batch_idxs: List[int] = []  # rows with exceptions are not valid
+                     for row_idx in range(batch_start_idx, batch_start_idx + num_batch_rows):
+                         row = rows[row_idx]
+                         if row.has_exc(fn_call.slot_idx):
+                             # one of our inputs had an exception, skip this row
+                             continue
+                         valid_batch_idxs.append(row_idx)
+                         args, kwargs = fn_call._make_args(row)
+                         for i in range(len(args)):
+                             arg_batches[i].append(args[i])
+                         for k in kwargs.keys():
+                             kwarg_batches[k].append(kwargs[k])
+                     num_valid_batch_rows = len(valid_batch_idxs)
+
+                     if ext_batch_size is None:
+                         # we need to choose a batch size based on the args
+                         sample_args = [arg_batches[i][0] for i in range(len(arg_batches))]
+                         ext_batch_size = fn_call.fn.get_batch_size(*sample_args)
+
+                     num_remaining_batch_rows = num_valid_batch_rows
+                     while num_remaining_batch_rows > 0:
+                         # we make ext. fn calls in batches of ext_batch_size
+                         num_ext_batch_rows = min(ext_batch_size, num_remaining_batch_rows)
+                         ext_batch_offset = num_valid_batch_rows - num_remaining_batch_rows  # offset into args, not rows
+                         call_args = [
+                             arg_batches[i][ext_batch_offset:ext_batch_offset + num_ext_batch_rows]
+                             for i in range(len(arg_batches))
+                         ]
+                         call_kwargs = {
+                             k: kwarg_batches[k][ext_batch_offset:ext_batch_offset + num_ext_batch_rows]
+                             for k in kwarg_batches.keys()
+                         }
+                         start_ts = time.perf_counter()
+                         result_batch = fn_call.fn.invoke(call_args, call_kwargs)
+                         self.ctx.profile.eval_time[fn_call.slot_idx] += time.perf_counter() - start_ts
+                         self.ctx.profile.eval_count[fn_call.slot_idx] += num_ext_batch_rows
+
+                         # move the results into the row batch
+                         for result_idx in range(len(result_batch)):
+                             row_idx = valid_batch_idxs[ext_batch_offset + result_idx]
+                             row = rows[row_idx]
+                             row[fn_call.slot_idx] = result_batch[result_idx]
+
+                         num_remaining_batch_rows -= num_ext_batch_rows
+
+                     # switch to the ext fn batch size
+                     cohort.batch_size = ext_batch_size
+
+             # make sure images for stored cols have been saved to files before moving on to the next batch
+             rows.flush_imgs(
+                 slice(batch_start_idx, batch_start_idx + num_batch_rows), self.stored_img_cols, self.flushed_img_slots)
+             if self.pbar is not None:
+                 self.pbar.update(num_batch_rows * len(cohort.target_slot_idxs))
+             batch_start_idx += num_batch_rows
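At its core, the external-call path in _exec_cohort is a gather/invoke/scatter loop: pack arguments from the rows that have no pending exceptions, call the batched function on fixed-size slices, and write each result back to its originating row. A condensed sketch of that loop (illustrative names; the real code gathers positional and keyword args per FunctionCall):

from typing import Callable, Dict, List, Optional

def eval_batched(
    rows: List[Dict[str, Optional[int]]], arg_key: str, out_key: str,
    fn: Callable[[List[int]], List[int]], ext_batch_size: int,
) -> None:
    # gather: rows whose input is an error (None here) are skipped, as with has_exc above
    valid_idxs = [i for i, r in enumerate(rows) if r[arg_key] is not None]
    args = [rows[i][arg_key] for i in valid_idxs]
    offset = 0
    while offset < len(args):
        call_args = args[offset:offset + ext_batch_size]
        results = fn(call_args)  # one batched external call
        for j, result in enumerate(results):  # scatter results back by row index
            rows[valid_idxs[offset + j]][out_key] = result
        offset += len(call_args)

rows = [{'x': 1, 'y': None}, {'x': None, 'y': None}, {'x': 3, 'y': None}]
eval_batched(rows, 'x', 'y', lambda xs: [x * x for x in xs], ext_batch_size=2)
print(rows)  # [{'x': 1, 'y': 1}, {'x': None, 'y': None}, {'x': 3, 'y': 9}]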
pixeltable/exec/in_memory_data_node.py
@@ -0,0 +1,69 @@
+ from typing import List, Dict, Any, Optional
+ import logging
+
+ from .data_row_batch import DataRowBatch
+ from .exec_node import ExecNode
+ import pixeltable.catalog as catalog
+ import pixeltable.exprs as exprs
+ from pixeltable.utils.media_store import MediaStore
+
+
+ _logger = logging.getLogger('pixeltable')
+
+ class InMemoryDataNode(ExecNode):
+     """Outputs in-memory data as a row batch of a particular table"""
+     def __init__(
+             self, tbl: catalog.TableVersionPath, rows: List[Dict[str, Any]],
+             row_builder: exprs.RowBuilder, start_row_id: int,
+     ):
+         super().__init__(row_builder, [], [], None)
+         assert tbl.is_insertable()
+         self.tbl = tbl
+         self.input_rows = rows
+         self.start_row_id = start_row_id
+         self.has_returned_data = False
+         self.output_rows: Optional[DataRowBatch] = None
+
+     def _open(self) -> None:
+         """Create row batch and populate with self.input_rows"""
+         column_info = {info.col.name: info for info in self.row_builder.output_slot_idxs()}
+         # stored columns that are not computed
+         inserted_column_names = set([
+             info.col.name for info in self.row_builder.output_slot_idxs()
+             if info.col.is_stored and not info.col.is_computed
+         ])
+
+         self.output_rows = DataRowBatch(self.tbl, self.row_builder, len(self.input_rows))
+         for row_idx, input_row in enumerate(self.input_rows):
+             # populate the output row with the values provided in the input row
+             for col_name, val in input_row.items():
+                 col_info = column_info.get(col_name)
+                 assert col_info is not None
+
+                 if col_info.col.col_type.is_image_type() and isinstance(val, bytes):
+                     # this is a literal image, i.e., a sequence of bytes; save it as a media file and store the path
+                     path = str(MediaStore.prepare_media_path(self.tbl.id, col_info.col.id, self.tbl.version))
+                     with open(path, 'wb') as f:
+                         f.write(val)
+                     val = path
+                 self.output_rows[row_idx][col_info.slot_idx] = val
+
+             # set the remaining stored non-computed columns to null
+             null_col_names = inserted_column_names - set(input_row.keys())
+             for col_name in null_col_names:
+                 col_info = column_info.get(col_name)
+                 assert col_info is not None
+                 self.output_rows[row_idx][col_info.slot_idx] = None
+
+         self.output_rows.set_row_ids([self.start_row_id + i for i in range(len(self.output_rows))])
+         self.ctx.num_rows = len(self.output_rows)
+
+     def __next__(self) -> DataRowBatch:
+         if self.has_returned_data:
+             raise StopIteration
+         self.has_returned_data = True
+         _logger.debug(f'InMemoryDataNode: created row batch with {len(self.output_rows)} output_rows')
+         return self.output_rows
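What _open does, reduced to plain data structures, is map each input dict onto a fixed slot layout and null out the insertable columns the dict omits. A small sketch of that mapping (illustrative names and layout, not the pixeltable API):

from typing import Any, Dict, List

def build_rows(
    input_rows: List[Dict[str, Any]], slot_idx: Dict[str, int], insertable: List[str]
) -> List[List[Any]]:
    out: List[List[Any]] = []
    for input_row in input_rows:
        row: List[Any] = [None] * len(slot_idx)
        for col_name, val in input_row.items():  # values land in their column's slot
            row[slot_idx[col_name]] = val
        for col_name in set(insertable) - set(input_row.keys()):
            row[slot_idx[col_name]] = None  # already None here; explicit in the real node
        out.append(row)
    return out

print(build_rows([{'a': 1}, {'a': 2, 'b': 'x'}], {'a': 0, 'b': 1}, ['a', 'b']))
# [[1, None], [2, 'x']]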
+