pixeltable-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable may be problematic.

Files changed (119)
  1. pixeltable/__init__.py +53 -0
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/__init__.py +13 -0
  4. pixeltable/catalog/catalog.py +159 -0
  5. pixeltable/catalog/column.py +181 -0
  6. pixeltable/catalog/dir.py +32 -0
  7. pixeltable/catalog/globals.py +33 -0
  8. pixeltable/catalog/insertable_table.py +192 -0
  9. pixeltable/catalog/named_function.py +36 -0
  10. pixeltable/catalog/path.py +58 -0
  11. pixeltable/catalog/path_dict.py +139 -0
  12. pixeltable/catalog/schema_object.py +39 -0
  13. pixeltable/catalog/table.py +695 -0
  14. pixeltable/catalog/table_version.py +1026 -0
  15. pixeltable/catalog/table_version_path.py +133 -0
  16. pixeltable/catalog/view.py +203 -0
  17. pixeltable/dataframe.py +749 -0
  18. pixeltable/env.py +466 -0
  19. pixeltable/exceptions.py +17 -0
  20. pixeltable/exec/__init__.py +10 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +94 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +73 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +226 -0
  31. pixeltable/exprs/__init__.py +25 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +114 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +199 -0
  39. pixeltable/exprs/expr.py +594 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +382 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +96 -0
  44. pixeltable/exprs/in_predicate.py +96 -0
  45. pixeltable/exprs/inline_array.py +109 -0
  46. pixeltable/exprs/inline_dict.py +103 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +66 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +329 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/similarity_expr.py +65 -0
  56. pixeltable/exprs/type_cast.py +53 -0
  57. pixeltable/exprs/variable.py +45 -0
  58. pixeltable/ext/__init__.py +5 -0
  59. pixeltable/ext/functions/yolox.py +92 -0
  60. pixeltable/func/__init__.py +7 -0
  61. pixeltable/func/aggregate_function.py +197 -0
  62. pixeltable/func/callable_function.py +113 -0
  63. pixeltable/func/expr_template_function.py +99 -0
  64. pixeltable/func/function.py +141 -0
  65. pixeltable/func/function_registry.py +227 -0
  66. pixeltable/func/globals.py +46 -0
  67. pixeltable/func/nos_function.py +202 -0
  68. pixeltable/func/signature.py +162 -0
  69. pixeltable/func/udf.py +164 -0
  70. pixeltable/functions/__init__.py +95 -0
  71. pixeltable/functions/eval.py +215 -0
  72. pixeltable/functions/fireworks.py +34 -0
  73. pixeltable/functions/huggingface.py +167 -0
  74. pixeltable/functions/image.py +16 -0
  75. pixeltable/functions/openai.py +289 -0
  76. pixeltable/functions/pil/image.py +147 -0
  77. pixeltable/functions/string.py +13 -0
  78. pixeltable/functions/together.py +143 -0
  79. pixeltable/functions/util.py +52 -0
  80. pixeltable/functions/video.py +62 -0
  81. pixeltable/globals.py +425 -0
  82. pixeltable/index/__init__.py +2 -0
  83. pixeltable/index/base.py +51 -0
  84. pixeltable/index/embedding_index.py +168 -0
  85. pixeltable/io/__init__.py +3 -0
  86. pixeltable/io/hf_datasets.py +188 -0
  87. pixeltable/io/pandas.py +148 -0
  88. pixeltable/io/parquet.py +192 -0
  89. pixeltable/iterators/__init__.py +3 -0
  90. pixeltable/iterators/base.py +52 -0
  91. pixeltable/iterators/document.py +432 -0
  92. pixeltable/iterators/video.py +88 -0
  93. pixeltable/metadata/__init__.py +58 -0
  94. pixeltable/metadata/converters/convert_10.py +18 -0
  95. pixeltable/metadata/converters/convert_12.py +3 -0
  96. pixeltable/metadata/converters/convert_13.py +41 -0
  97. pixeltable/metadata/schema.py +234 -0
  98. pixeltable/plan.py +620 -0
  99. pixeltable/store.py +424 -0
  100. pixeltable/tool/create_test_db_dump.py +184 -0
  101. pixeltable/tool/create_test_video.py +81 -0
  102. pixeltable/type_system.py +846 -0
  103. pixeltable/utils/__init__.py +17 -0
  104. pixeltable/utils/arrow.py +98 -0
  105. pixeltable/utils/clip.py +18 -0
  106. pixeltable/utils/coco.py +136 -0
  107. pixeltable/utils/documents.py +69 -0
  108. pixeltable/utils/filecache.py +195 -0
  109. pixeltable/utils/help.py +11 -0
  110. pixeltable/utils/http_server.py +70 -0
  111. pixeltable/utils/media_store.py +76 -0
  112. pixeltable/utils/pytorch.py +91 -0
  113. pixeltable/utils/s3.py +13 -0
  114. pixeltable/utils/sql.py +17 -0
  115. pixeltable/utils/transactional_directory.py +35 -0
  116. pixeltable-0.0.0.dist-info/LICENSE +18 -0
  117. pixeltable-0.0.0.dist-info/METADATA +131 -0
  118. pixeltable-0.0.0.dist-info/RECORD +119 -0
  119. pixeltable-0.0.0.dist-info/WHEEL +4 -0
pixeltable/env.py ADDED
@@ -0,0 +1,466 @@
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import glob
5
+ import http.server
6
+ import importlib
7
+ import importlib.util
8
+ import logging
9
+ import os
10
+ import socketserver
11
+ import sys
12
+ import threading
13
+ import uuid
14
+ import warnings
15
+ from pathlib import Path
16
+ from typing import Callable, Optional, Dict, Any, List
17
+
18
+ import pgserver
19
+ import sqlalchemy as sql
20
+ import yaml
21
+ from sqlalchemy_utils.functions import database_exists, create_database, drop_database
22
+ from tqdm import TqdmWarning
23
+
24
+ import pixeltable.exceptions as excs
25
+ from pixeltable import metadata
26
+ from pixeltable.utils.http_server import make_server
27
+
28
+
29
+ class Env:
30
+ """
31
+ Store for runtime globals.
32
+ """
33
+
34
+ _instance: Optional[Env] = None
35
+ _log_fmt_str = '%(asctime)s %(levelname)s %(name)s %(filename)s:%(lineno)d: %(message)s'
36
+
37
+ @classmethod
38
+ def get(cls) -> Env:
39
+ if cls._instance is None:
40
+ env = Env()
41
+ env._set_up()
42
+ env._upgrade_metadata()
43
+ cls._instance = env
44
+ return cls._instance
45
+
46
+ def __init__(self):
47
+ self._home: Optional[Path] = None
48
+ self._media_dir: Optional[Path] = None # computed media files
49
+ self._file_cache_dir: Optional[Path] = None # cached media files with external URL
50
+ self._dataset_cache_dir: Optional[Path] = None # cached datasets (eg, pytorch or COCO)
51
+ self._log_dir: Optional[Path] = None # log files
52
+ self._tmp_dir: Optional[Path] = None # any tmp files
53
+ self._sa_engine: Optional[sql.engine.base.Engine] = None
54
+ self._pgdata_dir: Optional[Path] = None
55
+ self._db_name: Optional[str] = None
56
+ self._db_server: Optional[pgserver.PostgresServer] = None
57
+ self._db_url: Optional[str] = None
58
+
59
+ # info about installed packages that are utilized by some parts of the code;
60
+ # package name -> version; version == []: package is installed, but we haven't determined the version yet
61
+ self._installed_packages: Dict[str, Optional[List[int]]] = {}
62
+ self._nos_client: Optional[Any] = None
63
+ self._spacy_nlp: Optional[Any] = None # spacy.Language
64
+ self._httpd: Optional[http.server.ThreadingHTTPServer] = None
65
+ self._http_address: Optional[str] = None
66
+
67
+ self._registered_clients: dict[str, Any] = {}
68
+
69
+ # logging-related state
70
+ self._logger = logging.getLogger('pixeltable')
71
+ self._logger.setLevel(logging.DEBUG) # allow everything to pass, we filter in _log_filter()
72
+ self._logger.propagate = False
73
+ self._logger.addFilter(self._log_filter)
74
+ self._default_log_level = logging.INFO
75
+ self._logfilename: Optional[str] = None
76
+ self._log_to_stdout = False
77
+ self._module_log_level: Dict[str, int] = {} # module name -> log level
78
+
79
+ # config
80
+ self._config_file: Optional[Path] = None
81
+ self._config: Optional[Dict[str, Any]] = None
82
+
83
+ # create logging handler to also log to stdout
84
+ self._stdout_handler = logging.StreamHandler(stream=sys.stdout)
85
+ self._stdout_handler.setFormatter(logging.Formatter(self._log_fmt_str))
86
+ self._initialized = False
87
+
88
+ @property
89
+ def config(self):
90
+ return self._config
91
+
92
+ @property
93
+ def db_url(self) -> str:
94
+ assert self._db_url is not None
95
+ return self._db_url
96
+
97
+ @property
98
+ def http_address(self) -> str:
99
+ assert self._http_address is not None
100
+ return self._http_address
101
+
102
+ def configure_logging(
103
+ self,
104
+ *,
105
+ to_stdout: Optional[bool] = None,
106
+ level: Optional[int] = None,
107
+ add: Optional[str] = None,
108
+ remove: Optional[str] = None,
109
+ ) -> None:
110
+ """Configure logging.
111
+
112
+ Args:
113
+ to_stdout: if True, also log to stdout
114
+ level: default log level
115
+ add: comma-separated list of 'module name:log level' pairs; ex.: add='video:10'
116
+ remove: comma-separated list of module names
117
+ """
118
+ if to_stdout is not None:
119
+ self.log_to_stdout(to_stdout)
120
+ if level is not None:
121
+ self.set_log_level(level)
122
+ if add is not None:
123
+ for module, level in [t.split(':') for t in add.split(',')]:
124
+ self.set_module_log_level(module, int(level))
125
+ if remove is not None:
126
+ for module in remove.split(','):
127
+ self.set_module_log_level(module, None)
128
+ if to_stdout is None and level is None and add is None and remove is None:
129
+ self.print_log_config()
130
+
131
+ def print_log_config(self) -> None:
132
+ print(f'logging to {self._logfilename}')
133
+ print(f'{"" if self._log_to_stdout else "not "}logging to stdout')
134
+ print(f'default log level: {logging.getLevelName(self._default_log_level)}')
135
+ print(
136
+ f'module log levels: '
137
+ f'{",".join([name + ":" + logging.getLevelName(val) for name, val in self._module_log_level.items()])}'
138
+ )
139
+
140
+ def log_to_stdout(self, enable: bool = True) -> None:
141
+ self._log_to_stdout = enable
142
+ if enable:
143
+ self._logger.addHandler(self._stdout_handler)
144
+ else:
145
+ self._logger.removeHandler(self._stdout_handler)
146
+
147
+ def set_log_level(self, level: int) -> None:
148
+ self._default_log_level = level
149
+
150
+ def set_module_log_level(self, module: str, level: Optional[int]) -> None:
151
+ if level is None:
152
+ self._module_log_level.pop(module, None)
153
+ else:
154
+ self._module_log_level[module] = level
155
+
156
+ def is_installed_package(self, package_name: str) -> bool:
157
+ return self._installed_packages[package_name] is not None
158
+
159
+ def _log_filter(self, record: logging.LogRecord) -> bool:
160
+ if record.name == 'pixeltable':
161
+ # accept log messages from a configured pixeltable module (at any level of the module hierarchy)
162
+ path_parts = list(Path(record.pathname).parts)
163
+ path_parts.reverse()
164
+ max_idx = path_parts.index('pixeltable')
165
+ for module_name in path_parts[:max_idx]:
166
+ if module_name in self._module_log_level and record.levelno >= self._module_log_level[module_name]:
167
+ return True
168
+ if record.levelno >= self._default_log_level:
169
+ return True
170
+ else:
171
+ return False
172
+
173
+ def _set_up(self, echo: bool = False, reinit_db: bool = False) -> None:
174
+ if self._initialized:
175
+ return
176
+
177
+ # Disable spurious warnings
178
+ warnings.simplefilter('ignore', category=TqdmWarning)
179
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
180
+
181
+ self._initialized = True
182
+ home = Path(os.environ.get('PIXELTABLE_HOME', str(Path.home() / '.pixeltable')))
183
+ assert self._home is None or self._home == home
184
+ self._home = home
185
+ self._config_file = Path(os.environ.get('PIXELTABLE_CONFIG', str(self._home / 'config.yaml')))
186
+ self._media_dir = self._home / 'media'
187
+ self._file_cache_dir = self._home / 'file_cache'
188
+ self._dataset_cache_dir = self._home / 'dataset_cache'
189
+ self._log_dir = self._home / 'logs'
190
+ self._tmp_dir = self._home / 'tmp'
191
+
192
+ # Read in the config
193
+ if os.path.isfile(self._config_file):
194
+ with open(self._config_file, 'r') as stream:
195
+ try:
196
+ self._config = yaml.safe_load(stream)
197
+ except yaml.YAMLError as exc:
198
+ self._logger.error(f'Could not read config file: {self._config_file}')
199
+ self._config = {}
200
+ else:
201
+ self._config = {}
202
+
203
+ if self._home.exists() and not self._home.is_dir():
204
+ raise RuntimeError(f'{self._home} is not a directory')
205
+
206
+ if not self._home.exists():
207
+ # we don't have our logger set up yet, so print to stdout
208
+ print(f'Creating a Pixeltable instance at: {self._home}')
209
+ self._home.mkdir()
210
+ # TODO (aaron-siegel) This is the existing behavior, but it seems scary. If something happens to
211
+ # self._home, it will cause the DB to be destroyed even if pgdata is in an alternate location.
212
+ # PROPOSAL: require `reinit_db` to be set explicitly to destroy the DB.
213
+ reinit_db = True
214
+
215
+ if not self._media_dir.exists():
216
+ self._media_dir.mkdir()
217
+ if not self._file_cache_dir.exists():
218
+ self._file_cache_dir.mkdir()
219
+ if not self._dataset_cache_dir.exists():
220
+ self._dataset_cache_dir.mkdir()
221
+ if not self._log_dir.exists():
222
+ self._log_dir.mkdir()
223
+ if not self._tmp_dir.exists():
224
+ self._tmp_dir.mkdir()
225
+
226
+ # configure _logger to log to a file
227
+ self._logfilename = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + '.log'
228
+ fh = logging.FileHandler(self._log_dir / self._logfilename, mode='w')
229
+ fh.setFormatter(logging.Formatter(self._log_fmt_str))
230
+ self._logger.addHandler(fh)
231
+
232
+ # configure sqlalchemy logging
233
+ sql_logger = logging.getLogger('sqlalchemy.engine')
234
+ sql_logger.setLevel(logging.INFO)
235
+ sql_logger.addHandler(fh)
236
+ sql_logger.propagate = False
237
+
238
+ # configure pyav logging
239
+ av_logfilename = self._logfilename.replace('.log', '_av.log')
240
+ av_fh = logging.FileHandler(self._log_dir / av_logfilename, mode='w')
241
+ av_fh.setFormatter(logging.Formatter(self._log_fmt_str))
242
+ av_logger = logging.getLogger('libav')
243
+ av_logger.addHandler(av_fh)
244
+ av_logger.propagate = False
245
+
246
+ # configure web-server logging
247
+ http_logfilename = self._logfilename.replace('.log', '_http.log')
248
+ http_fh = logging.FileHandler(self._log_dir / http_logfilename, mode='w')
249
+ http_fh.setFormatter(logging.Formatter(self._log_fmt_str))
250
+ http_logger = logging.getLogger('pixeltable.http.server')
251
+ http_logger.addHandler(http_fh)
252
+ http_logger.propagate = False
253
+
254
+ # empty tmp dir
255
+ for path in glob.glob(f'{self._tmp_dir}/*'):
256
+ os.remove(path)
257
+
258
+ self._db_name = os.environ.get('PIXELTABLE_DB', 'pixeltable')
259
+ self._pgdata_dir = Path(os.environ.get('PIXELTABLE_PGDATA', str(self._home / 'pgdata')))
260
+
261
+ # in pgserver.get_server(): cleanup_mode=None will leave db on for debugging purposes
262
+ self._db_server = pgserver.get_server(self._pgdata_dir, cleanup_mode=None)
263
+ self._db_url = self._db_server.get_uri(database=self._db_name)
264
+
265
+ if reinit_db:
266
+ if database_exists(self.db_url):
267
+ drop_database(self.db_url)
268
+
269
+ if not database_exists(self.db_url):
270
+ self._logger.info(f'creating database at {self.db_url}')
271
+ create_database(self.db_url)
272
+ self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True)
273
+ from pixeltable.metadata import schema
274
+
275
+ schema.Base.metadata.create_all(self._sa_engine)
276
+ metadata.create_system_info(self._sa_engine)
277
+ # enable pgvector
278
+ with self._sa_engine.begin() as conn:
279
+ conn.execute(sql.text('CREATE EXTENSION vector'))
280
+ else:
281
+ self._logger.info(f'found database {self.db_url}')
282
+ if self._sa_engine is None:
283
+ self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True)
284
+
285
+ print(f'Connected to Pixeltable database at: {self.db_url}')
286
+
287
+ # we now have a home directory and db; start other services
288
+ self._set_up_runtime()
289
+ self.log_to_stdout(False)
290
+
291
+ def _upgrade_metadata(self) -> None:
292
+ metadata.upgrade_md(self._sa_engine)
293
+
294
+ def _create_nos_client(self) -> None:
295
+ import nos
296
+
297
+ self._logger.info('connecting to NOS')
298
+ nos.init(logging_level=logging.DEBUG)
299
+ self._nos_client = nos.client.InferenceClient()
300
+ self._logger.info('waiting for NOS')
301
+ self._nos_client.WaitForServer()
302
+
303
+ # now that we have a client, we can create the module
304
+ import importlib
305
+
306
+ try:
307
+ importlib.import_module('pixeltable.functions.nos')
308
+ # it's already been created
309
+ return
310
+ except ImportError:
311
+ pass
312
+ from pixeltable.functions.util import create_nos_modules
313
+
314
+ _ = create_nos_modules()
315
+
316
+ def get_client(self, name: str, init: Callable, environ: Optional[str] = None) -> Any:
317
+ """
318
+ Gets the client with the specified name, using `init` to construct one if necessary.
319
+
320
+ - name: The name of the client
321
+ - init: A `Callable` with signature `fn(api_key: str) -> Any` that constructs a client object
322
+ - environ: The name of the environment variable to use for the API key, if no API key is found in config
323
+ (defaults to f'{name.upper()}_API_KEY')
324
+ """
325
+ if name in self._registered_clients:
326
+ return self._registered_clients[name]
327
+
328
+ if environ is None:
329
+ environ = f'{name.upper()}_API_KEY'
330
+
331
+ if name in self._config and 'api_key' in self._config[name]:
332
+ api_key = self._config[name]['api_key']
333
+ else:
334
+ api_key = os.environ.get(environ)
335
+ if api_key is None or api_key == '':
336
+ raise excs.Error(f'`{name}` client not initialized (no API key configured).')
337
+
338
+ client = init(api_key)
339
+ self._registered_clients[name] = client
340
+ self._logger.info(f'Initialized `{name}` client.')
341
+ return client
342
+
343
+ def _start_web_server(self) -> None:
344
+ """
345
+ The http server root is the file system root.
346
+ e.g., /home/media/foo.mp4 is served at http://127.0.0.1:{port}/home/media/foo.mp4.
347
+ On Windows, the server translates paths such as http://127.0.0.1:{port}/c:/media/foo.mp4.
348
+ This arrangement enables serving media hosted within _home,
349
+ as well as external media inserted into pixeltable or produced by pixeltable.
350
+ The port is chosen dynamically to prevent conflicts.
351
+ """
352
+ # Port 0 means OS picks one for us.
353
+ self._httpd = make_server('127.0.0.1', 0)
354
+ port = self._httpd.server_address[1]
355
+ self._http_address = f'http://127.0.0.1:{port}'
356
+
357
+ def run_server():
358
+ logging.log(logging.INFO, f'running web server at {self._http_address}')
359
+ self._httpd.serve_forever()
360
+
361
+ # Run the server in a separate thread
362
+ thread = threading.Thread(target=run_server, daemon=True)
363
+ thread.start()
364
+
365
+ def _set_up_runtime(self) -> None:
366
+ """Check for and start runtime services"""
367
+ self._start_web_server()
368
+ self._check_installed_packages()
369
+
370
+ def _check_installed_packages(self) -> None:
371
+ def check(package: str) -> None:
372
+ if importlib.util.find_spec(package) is not None:
373
+ self._installed_packages[package] = []
374
+ else:
375
+ self._installed_packages[package] = None
376
+
377
+ check('datasets')
378
+ check('torch')
379
+ check('torchvision')
380
+ check('transformers')
381
+ check('sentence_transformers')
382
+ check('yolox')
383
+ check('boto3')
384
+ check('fitz') # pymupdf
385
+ check('pyarrow')
386
+ check('spacy') # TODO: deal with en-core-web-sm
387
+ if self.is_installed_package('spacy'):
388
+ import spacy
389
+
390
+ self._spacy_nlp = spacy.load('en_core_web_sm')
391
+ check('tiktoken')
392
+ check('openai')
393
+ check('together')
394
+ check('fireworks')
395
+ check('nos')
396
+ if self.is_installed_package('nos'):
397
+ self._create_nos_client()
398
+ check('openpyxl')
399
+
400
+ def require_package(self, package: str, min_version: Optional[List[int]] = None) -> None:
401
+ assert package in self._installed_packages
402
+ if self._installed_packages[package] is None:
403
+ raise excs.Error(f'Package {package} is not installed')
404
+ if min_version is None:
405
+ return
406
+
407
+ # check whether we have a version >= the required one
408
+ if self._installed_packages[package] == []:
409
+ m = importlib.import_module(package)
410
+ module_version = [int(x) for x in m.__version__.split('.')]
411
+ self._installed_packages[package] = module_version
412
+ installed_version = self._installed_packages[package]
413
+ # pad min_version with zeros so it can be compared element-wise with installed_version
414
+ normalized_min_version = min_version + [0] * max(0, len(installed_version) - len(min_version))
415
+ if any([a < b for a, b in zip(installed_version, normalized_min_version)]):
416
+ raise excs.Error(
417
+ (
418
+ f'The installed version of package {package} is {".".join([str(v) for v in installed_version])}, '
419
+ f'but version >={".".join([str(v) for v in min_version])} is required'
420
+ )
421
+ )
422
+
423
+ def num_tmp_files(self) -> int:
424
+ return len(glob.glob(f'{self._tmp_dir}/*'))
425
+
426
+ def create_tmp_path(self, extension: str = '') -> Path:
427
+ return self._tmp_dir / f'{uuid.uuid4()}{extension}'
428
+
429
+ @property
430
+ def home(self) -> Path:
431
+ assert self._home is not None
432
+ return self._home
433
+
434
+ @property
435
+ def media_dir(self) -> Path:
436
+ assert self._media_dir is not None
437
+ return self._media_dir
438
+
439
+ @property
440
+ def file_cache_dir(self) -> Path:
441
+ assert self._file_cache_dir is not None
442
+ return self._file_cache_dir
443
+
444
+ @property
445
+ def dataset_cache_dir(self) -> Path:
446
+ assert self._dataset_cache_dir is not None
447
+ return self._dataset_cache_dir
448
+
449
+ @property
450
+ def tmp_dir(self) -> Path:
451
+ assert self._tmp_dir is not None
452
+ return self._tmp_dir
453
+
454
+ @property
455
+ def engine(self) -> sql.engine.base.Engine:
456
+ assert self._sa_engine is not None
457
+ return self._sa_engine
458
+
459
+ @property
460
+ def nos_client(self) -> Any:
461
+ return self._nos_client
462
+
463
+ @property
464
+ def spacy_nlp(self) -> Any:
465
+ assert self._spacy_nlp is not None
466
+ return self._spacy_nlp
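
For orientation, here is a minimal usage sketch of the Env API added above. It only exercises methods visible in this diff (Env.get(), configure_logging(), get_client()); the config.yaml keys and the lambda-based client constructor are illustrative assumptions, not part of the package.

# Hypothetical usage sketch (not part of the diff):
import logging
from pixeltable.env import Env

env = Env.get()                            # first call creates ~/.pixeltable, the Postgres DB, and log files
env.configure_logging(to_stdout=True, level=logging.INFO, add='video:10')  # DEBUG only for the 'video' module
env.configure_logging()                    # with no arguments, prints the current logging configuration

# get_client() caches one client object per service name; the API key is read from the
# '<name>: {api_key: ...}' section of config.yaml or, failing that, from the <NAME>_API_KEY env var.
# client = env.get_client('openai', init=lambda api_key: SomeClient(api_key))  # SomeClient is a placeholder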
pixeltable/exceptions.py ADDED
@@ -0,0 +1,17 @@
1
+ from typing import List, Any
2
+ from types import TracebackType
3
+ from dataclasses import dataclass
4
+
5
+
6
+ class Error(Exception):
7
+ pass
8
+
9
+
10
+ @dataclass
11
+ class ExprEvalError(Exception):
12
+ expr: Any # exprs.Expr, but we're not importing pixeltable.exprs to avoid circular imports
13
+ expr_msg: str
14
+ exc: Exception
15
+ exc_tb: TracebackType
16
+ input_vals: List[Any]
17
+ row_num: int
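
ExprEvalError above is raised by the exec nodes later in this diff (see AggregationNode._update_agg_state). A small sketch of how a caller might surface its fields; the reporting format is entirely illustrative.

# Illustrative only: reporting an ExprEvalError caught from an exec node.
import traceback
import pixeltable.exceptions as excs

def report(err: excs.ExprEvalError) -> None:
    print(f'row {err.row_num}: error in {err.expr_msg}: {err.exc}')   # which expr failed, and why
    print(f'input values: {err.input_vals}')                          # the slot values fed to the expr
    traceback.print_tb(err.exc_tb)                                    # original traceback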
pixeltable/exec/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .aggregation_node import AggregationNode
2
+ from .cache_prefetch_node import CachePrefetchNode
3
+ from .component_iteration_node import ComponentIterationNode
4
+ from .exec_context import ExecContext
5
+ from .exec_node import ExecNode
6
+ from .expr_eval_node import ExprEvalNode
7
+ from .in_memory_data_node import InMemoryDataNode
8
+ from .sql_scan_node import SqlScanNode
9
+ from .media_validation_node import MediaValidationNode
10
+ from .data_row_batch import DataRowBatch
pixeltable/exec/aggregation_node.py ADDED
@@ -0,0 +1,78 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import sys
5
+ from typing import List, Optional, Any
6
+
7
+ import pixeltable.catalog as catalog
8
+ import pixeltable.exceptions as excs
9
+ import pixeltable.exprs as exprs
10
+ from .data_row_batch import DataRowBatch
11
+ from .exec_node import ExecNode
12
+
13
+ _logger = logging.getLogger('pixeltable')
14
+
15
+ class AggregationNode(ExecNode):
16
+ def __init__(
17
+ self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by: List[exprs.Expr],
18
+ agg_fn_calls: List[exprs.FunctionCall], input_exprs: List[exprs.Expr], input: ExecNode
19
+ ):
20
+ super().__init__(row_builder, group_by + agg_fn_calls, input_exprs, input)
21
+ self.input = input
22
+ self.group_by = group_by
23
+ self.input_exprs = input_exprs
24
+ self.agg_fn_calls = agg_fn_calls
25
+ self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=input_exprs)
26
+ self.output_batch = DataRowBatch(tbl, row_builder, 0)
27
+
28
+ def _reset_agg_state(self, row_num: int) -> None:
29
+ for fn_call in self.agg_fn_calls:
30
+ try:
31
+ fn_call.reset_agg()
32
+ except Exception as e:
33
+ _, _, exc_tb = sys.exc_info()
34
+ expr_msg = f'init() function of the aggregate {fn_call}'
35
+ raise excs.ExprEvalError(fn_call, expr_msg, e, exc_tb, [], row_num)
36
+
37
+ def _update_agg_state(self, row: exprs.DataRow, row_num: int) -> None:
38
+ for fn_call in self.agg_fn_calls:
39
+ try:
40
+ fn_call.update(row)
41
+ except Exception as e:
42
+ _, _, exc_tb = sys.exc_info()
43
+ expr_msg = f'update() function of the aggregate {fn_call}'
44
+ input_vals = [row[d.slot_idx] for d in fn_call.dependencies()]
45
+ raise excs.ExprEvalError(fn_call, expr_msg, e, exc_tb, input_vals, row_num)
46
+
47
+ def __next__(self) -> DataRowBatch:
48
+ if self.output_batch is None:
49
+ raise StopIteration
50
+
51
+ prev_row: Optional[exprs.DataRow] = None
52
+ current_group: Optional[List[Any]] = None # the values of the group-by exprs
53
+ num_input_rows = 0
54
+ for row_batch in self.input:
55
+ num_input_rows += len(row_batch)
56
+ for row in row_batch:
57
+ group = [row[e.slot_idx] for e in self.group_by]
58
+ if current_group is None:
59
+ current_group = group
60
+ self._reset_agg_state(0)
61
+ if group != current_group:
62
+ # we're entering a new group, emit a row for the previous one
63
+ self.row_builder.eval(prev_row, self.agg_fn_eval_ctx, profile=self.ctx.profile)
64
+ self.output_batch.add_row(prev_row)
65
+ current_group = group
66
+ self._reset_agg_state(0)
67
+ self._update_agg_state(row, 0)
68
+ prev_row = row
69
+ # emit the last group
70
+ self.row_builder.eval(prev_row, self.agg_fn_eval_ctx, profile=self.ctx.profile)
71
+ self.output_batch.add_row(prev_row)
72
+
73
+ result = self.output_batch
74
+ result.flush_imgs(None, self.stored_img_cols, self.flushed_img_slots)
75
+ self.output_batch = None
76
+ _logger.debug(f'AggregationNode: consumed {num_input_rows} rows, returning {len(result.rows)} rows')
77
+ return result
78
+
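The node above implements streaming group-by: it assumes input rows arrive sorted on the group-by expressions, resets the aggregate state whenever the group key changes, and emits one output row per group. A self-contained sketch of that pattern, with plain tuples instead of DataRowBatch and all names hypothetical:

# Standalone illustration of the sorted-input group-by pattern used by AggregationNode.
def aggregate_sorted(rows, key, init, update, emit):
    state, current, out = None, None, []
    for row in rows:
        k = key(row)
        if current is None:
            current, state = k, init()          # first group
        elif k != current:
            out.append(emit(current, state))    # close the previous group
            current, state = k, init()
        state = update(state, row)
    if current is not None:
        out.append(emit(current, state))        # close the last group
    return out

# aggregate_sorted([('a', 1), ('a', 2), ('b', 3)], key=lambda r: r[0],
#                  init=lambda: 0, update=lambda s, r: s + r[1], emit=lambda k, s: (k, s))
# -> [('a', 3), ('b', 3)]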
pixeltable/exec/cache_prefetch_node.py ADDED
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ import concurrent.futures
4
+ import logging
5
+ import threading
6
+ import urllib.parse
7
+ import urllib.request
8
+ from collections import defaultdict
9
+ from pathlib import Path
10
+ from typing import List, Optional, Any, Tuple, Dict
11
+ from uuid import UUID
12
+
13
+ import pixeltable.env as env
14
+ import pixeltable.exceptions as excs
15
+ import pixeltable.exprs as exprs
16
+ from pixeltable.utils.filecache import FileCache
17
+ from .data_row_batch import DataRowBatch
18
+ from .exec_node import ExecNode
19
+
20
+ _logger = logging.getLogger('pixeltable')
21
+
22
+ class CachePrefetchNode(ExecNode):
23
+ """Brings files with external URLs into the cache
24
+
25
+ TODO:
26
+ - maintain a queue of row batches, in order to overlap download and evaluation
27
+ - adapting the number of download threads at runtime to maximize throughput
28
+ """
29
+ def __init__(self, tbl_id: UUID, file_col_info: List[exprs.ColumnSlotIdx], input: ExecNode):
30
+ # []: we don't have anything to evaluate
31
+ super().__init__(input.row_builder, [], [], input)
32
+ self.tbl_id = tbl_id
33
+ self.file_col_info = file_col_info
34
+
35
+ # clients for specific services are constructed as needed, because it's time-consuming
36
+ self.boto_client: Optional[Any] = None
37
+ self.boto_client_lock = threading.Lock()
38
+
39
+ def __next__(self) -> DataRowBatch:
40
+ input_batch = next(self.input)
41
+
42
+ # collect external URLs that aren't already cached, and set DataRow.file_paths for those that are
43
+ file_cache = FileCache.get()
44
+ cache_misses: List[Tuple[exprs.DataRow, exprs.ColumnSlotIdx]] = []
45
+ missing_url_rows: Dict[str, List[exprs.DataRow]] = defaultdict(list) # URL -> rows in which it's missing
46
+ for row in input_batch:
47
+ for info in self.file_col_info:
48
+ url = row.file_urls[info.slot_idx]
49
+ if url is None or row.file_paths[info.slot_idx] is not None:
50
+ # nothing to do
51
+ continue
52
+ if url in missing_url_rows:
53
+ missing_url_rows[url].append(row)
54
+ continue
55
+ local_path = file_cache.lookup(url)
56
+ if local_path is None:
57
+ cache_misses.append((row, info))
58
+ missing_url_rows[url].append(row)
59
+ else:
60
+ row.set_file_path(info.slot_idx, str(local_path))
61
+
62
+ # download the cache misses in parallel
63
+ # TODO: set max_workers to maximize throughput
64
+ futures: Dict[concurrent.futures.Future, Tuple[exprs.DataRow, exprs.ColumnSlotIdx]] = {}
65
+ with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
66
+ for row, info in cache_misses:
67
+ futures[executor.submit(self._fetch_url, row, info.slot_idx)] = (row, info)
68
+ for future in concurrent.futures.as_completed(futures):
69
+ # TODO: does this need to deal with recoverable errors (such as retry after throttling)?
70
+ tmp_path = future.result()
71
+ if tmp_path is None:
72
+ continue
73
+ row, info = futures[future]
74
+ url = row.file_urls[info.slot_idx]
75
+ local_path = file_cache.add(self.tbl_id, info.col.id, url, tmp_path)
76
+ _logger.debug(f'PrefetchNode: cached {url} as {local_path}')
77
+ for row in missing_url_rows[url]:
78
+ row.set_file_path(info.slot_idx, str(local_path))
79
+
80
+ return input_batch
81
+
82
+ def _fetch_url(self, row: exprs.DataRow, slot_idx: int) -> Optional[str]:
83
+ """Fetches a remote URL into Env.tmp_dir and returns its path"""
84
+ url = row.file_urls[slot_idx]
85
+ parsed = urllib.parse.urlparse(url)
86
+ # Use len(parsed.scheme) > 1 here to ensure we're not being passed
87
+ # a Windows filename
88
+ assert len(parsed.scheme) > 1 and parsed.scheme != 'file'
89
+ # preserve the file extension, if there is one
90
+ extension = ''
91
+ if parsed.path != '':
92
+ p = Path(urllib.parse.unquote(urllib.request.url2pathname(parsed.path)))
93
+ extension = p.suffix
94
+ tmp_path = env.Env.get().create_tmp_path(extension=extension)
95
+ try:
96
+ if parsed.scheme == 's3':
97
+ from pixeltable.utils.s3 import get_client
98
+ with self.boto_client_lock:
99
+ if self.boto_client is None:
100
+ self.boto_client = get_client()
101
+ self.boto_client.download_file(parsed.netloc, parsed.path.lstrip('/'), str(tmp_path))
102
+ elif parsed.scheme == 'http' or parsed.scheme == 'https':
103
+ with urllib.request.urlopen(url) as resp, open(tmp_path, 'wb') as f:
104
+ data = resp.read()
105
+ f.write(data)
106
+ else:
107
+ assert False, f'Unsupported URL scheme: {parsed.scheme}'
108
+ return tmp_path
109
+ except Exception as e:
110
+ # we want to add the file url to the exception message
111
+ exc = excs.Error(f'Failed to download {url}: {e}')
112
+ self.row_builder.set_exc(row, slot_idx, exc)
113
+ if not self.ctx.ignore_errors:
114
+ raise exc from None # suppress original exception
115
+ return None
116
+
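To summarize the control flow of CachePrefetchNode: cached URLs are resolved immediately, each distinct missing URL is downloaded once on a thread pool, and the resulting local path is written back into every row that references that URL. A simplified, self-contained sketch of that flow, using a plain dict in place of FileCache; the function and field names here are hypothetical.

# Simplified sketch of the prefetch flow (not part of the diff).
import concurrent.futures
from collections import defaultdict

def prefetch(rows, cache, fetch, max_workers=16):
    """rows: list of (row_dict, url); cache: dict url -> local path; fetch: callable url -> local path."""
    missing = defaultdict(list)                       # url -> rows still waiting for it
    for row, url in rows:
        path = cache.get(url)
        if path is not None:
            row['local_path'] = path                  # cache hit: resolve immediately
        else:
            missing[url].append(row)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {pool.submit(fetch, url): url for url in missing}   # one download per distinct URL
        for fut in concurrent.futures.as_completed(futures):
            url = futures[fut]
            path = fut.result()                       # download errors would surface here
            cache[url] = path
            for row in missing[url]:                  # back-fill every row that referenced this URL
                row['local_path'] = path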