pixeltable 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff shows the changes between package versions as published to their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (40)
  1. pixeltable/__init__.py +2 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/column.py +8 -22
  4. pixeltable/catalog/insertable_table.py +26 -8
  5. pixeltable/catalog/table.py +179 -83
  6. pixeltable/catalog/table_version.py +13 -39
  7. pixeltable/catalog/table_version_path.py +2 -2
  8. pixeltable/catalog/view.py +2 -2
  9. pixeltable/dataframe.py +20 -28
  10. pixeltable/env.py +2 -0
  11. pixeltable/exec/cache_prefetch_node.py +189 -43
  12. pixeltable/exec/data_row_batch.py +3 -3
  13. pixeltable/exec/exec_context.py +2 -2
  14. pixeltable/exec/exec_node.py +2 -2
  15. pixeltable/exec/expr_eval_node.py +8 -8
  16. pixeltable/exprs/arithmetic_expr.py +9 -4
  17. pixeltable/exprs/column_ref.py +4 -0
  18. pixeltable/exprs/comparison.py +5 -0
  19. pixeltable/exprs/json_path.py +1 -1
  20. pixeltable/func/aggregate_function.py +8 -8
  21. pixeltable/func/expr_template_function.py +6 -5
  22. pixeltable/func/udf.py +6 -11
  23. pixeltable/functions/huggingface.py +145 -25
  24. pixeltable/functions/llama_cpp.py +3 -2
  25. pixeltable/functions/mistralai.py +1 -1
  26. pixeltable/functions/openai.py +1 -1
  27. pixeltable/functions/together.py +1 -1
  28. pixeltable/functions/util.py +5 -2
  29. pixeltable/globals.py +55 -6
  30. pixeltable/plan.py +1 -1
  31. pixeltable/tool/create_test_db_dump.py +1 -1
  32. pixeltable/type_system.py +83 -35
  33. pixeltable/utils/coco.py +5 -5
  34. pixeltable/utils/formatter.py +3 -3
  35. pixeltable/utils/s3.py +6 -3
  36. {pixeltable-0.2.22.dist-info → pixeltable-0.2.24.dist-info}/METADATA +119 -46
  37. {pixeltable-0.2.22.dist-info → pixeltable-0.2.24.dist-info}/RECORD +40 -40
  38. {pixeltable-0.2.22.dist-info → pixeltable-0.2.24.dist-info}/LICENSE +0 -0
  39. {pixeltable-0.2.22.dist-info → pixeltable-0.2.24.dist-info}/WHEEL +0 -0
  40. {pixeltable-0.2.22.dist-info → pixeltable-0.2.24.dist-info}/entry_points.txt +0 -0
pixeltable/catalog/table_version.py CHANGED
@@ -193,8 +193,6 @@ class TableVersion:
             col.id = pos
             col.schema_version_add = 0
             cols_by_name[col.name] = col
-            if col.value_expr is None and col.compute_func is not None:
-                cls._create_value_expr(col, base_path)
             if col.is_computed:
                 col.check_value_expr()
 
@@ -494,37 +492,35 @@ class TableVersion:
         self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
         _logger.info(f'Dropped index {idx_md.name} on table {self.name}')
 
-    def add_column(self, col: Column, print_stats: bool, on_error: Literal['abort', 'ignore']) -> UpdateStatus:
+    def add_columns(self, cols: Iterable[Column], print_stats: bool, on_error: Literal['abort', 'ignore']) -> UpdateStatus:
         """Adds a column to the table.
         """
         assert not self.is_snapshot
-        assert is_valid_identifier(col.name)
-        assert col.stored is not None
-        assert col.name not in self.cols_by_name
-        col.tbl = self
-        col.id = self.next_col_id
-        self.next_col_id += 1
-
-        if col.compute_func is not None:
-            # create value_expr from compute_func
-            self._create_value_expr(col, self.path)
+        assert all(is_valid_identifier(col.name) for col in cols)
+        assert all(col.stored is not None for col in cols)
+        assert all(col.name not in self.cols_by_name for col in cols)
+        for col in cols:
+            col.tbl = self
+            col.id = self.next_col_id
+            self.next_col_id += 1
 
         # we're creating a new schema version
         self.version += 1
         preceding_schema_version = self.schema_version
         self.schema_version = self.version
         with Env.get().engine.begin() as conn:
-            status = self._add_columns([col], conn, print_stats=print_stats, on_error=on_error)
-            _ = self._add_default_index(col, conn)
+            status = self._add_columns(cols, conn, print_stats=print_stats, on_error=on_error)
+            for col in cols:
+                _ = self._add_default_index(col, conn)
             self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
-        _logger.info(f'Added column {col.name} to table {self.name}, new version: {self.version}')
+        _logger.info(f'Added columns {[col.name for col in cols]} to table {self.name}, new version: {self.version}')
 
         msg = (
             f'Added {status.num_rows} column value{"" if status.num_rows == 1 else "s"} '
             f'with {status.num_excs} error{"" if status.num_excs == 1 else "s"}.'
         )
         print(msg)
-        _logger.info(f'Column {col.name}: {msg}')
+        _logger.info(f'Columns {[col.name for col in cols]}: {msg}')
         return status
 
     def _add_columns(
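
The rename from add_column to add_columns turns the one-column API into a bulk one: all columns share a single schema-version bump, one transaction, and one backfill pass. A minimal caller sketch, assuming this TableVersion API; the Column constructor arguments and names below are illustrative, not taken from the diff:

    new_cols = [
        Column('summary', StringType(nullable=True)),   # hypothetical columns
        Column('n_tokens', IntType(nullable=True)),
    ]
    status = tbl_version.add_columns(new_cols, print_stats=False, on_error='abort')
    # status is an UpdateStatus; its num_rows/num_excs feed the message printed above
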
@@ -1140,28 +1136,6 @@ class TableVersion:
         names = [c.name for c in self.cols_by_name.values() if c.is_computed]
         return names
 
-    @classmethod
-    def _create_value_expr(cls, col: Column, path: pxt.catalog.TableVersionPath) -> None:
-        """
-        Create col.value_expr, given col.compute_func.
-        Interprets compute_func's parameters to be references to columns and construct ColumnRefs as args.
-        Does not update Column.dependent_cols.
-        """
-        assert col.value_expr is None
-        assert col.compute_func is not None
-        from pixeltable import exprs
-        params = inspect.signature(col.compute_func).parameters
-        args: list[exprs.ColumnRef] = []
-        for param_name in params:
-            param = path.get_column(param_name)
-            if param is None:
-                raise excs.Error(
-                    f'Column {col.name}: Callable parameter refers to an unknown column: {param_name}')
-            args.append(exprs.ColumnRef(param))
-        fn = func.make_function(
-            col.compute_func, return_type=col.col_type, param_types=[arg.col_type for arg in args])
-        col.set_value_expr(fn(*args))
-
     def _record_refd_columns(self, col: Column) -> None:
         """Update Column.dependent_cols for all cols referenced in col.value_expr.
         """
pixeltable/catalog/table_version_path.py CHANGED
@@ -81,13 +81,13 @@ class TableVersionPath:
             return None
         return self.base.find_tbl_version(id)
 
-    def __getattr__(self, col_name: str) -> exprs.ColumnRef:
+    def get_column_ref(self, col_name: str) -> exprs.ColumnRef:
         """Return a ColumnRef for the given column name."""
         from pixeltable.exprs import ColumnRef
         if col_name not in self.tbl_version.cols_by_name:
             if self.base is None:
                 raise AttributeError(f'Column {col_name} unknown')
-            return getattr(self.base, col_name)
+            return self.base.get_column_ref(col_name)
         col = self.tbl_version.cols_by_name[col_name]
         return ColumnRef(col)
 
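
Replacing __getattr__ with an explicit get_column_ref takes column lookup off the attribute-resolution path, so a misspelled attribute now fails fast instead of silently being treated as a column name. An illustrative before/after, assuming a TableVersionPath instance named path:

    ref = path.get_column_ref('img')   # returns exprs.ColumnRef, per the hunk above
    # previously: path.img -- __getattr__ resolved any unknown attribute as a column
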
pixeltable/catalog/view.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Iterable, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional
 from uuid import UUID
 
 import sqlalchemy.orm as orm
@@ -216,7 +216,7 @@ class View(Table):
 
     def insert(
         self, rows: Optional[Iterable[dict[str, Any]]] = None, /, *, print_stats: bool = False,
-        fail_on_exception: bool = True, **kwargs: Any
+        on_error: Literal['abort', 'ignore'] = 'abort', **kwargs: Any
    ) -> UpdateStatus:
         raise excs.Error(f'{self._display_name()} {self._name!r}: cannot insert into view')
 
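
View.insert still raises unconditionally; the signature change only keeps it aligned with the writable-table API, where the boolean fail_on_exception becomes on_error: Literal['abort', 'ignore'] (matching add_columns above). A hedged sketch of the corresponding call on a regular table, with illustrative table and row data:

    t = pxt.get_table('films')   # hypothetical table
    # 'ignore' records per-row errors instead of aborting the whole insert
    t.insert([{'title': 'Heat'}, {'title': 'Alien'}], on_error='ignore')
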
pixeltable/dataframe.py CHANGED
@@ -8,7 +8,7 @@ import logging
 import mimetypes
 import traceback
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Sequence, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, Optional, Sequence, Union
 
 import pandas as pd
 import pandas.io.formats.style
@@ -34,14 +34,6 @@ __all__ = ['DataFrame']
 _logger = logging.getLogger('pixeltable')
 
 
-def _create_source_tag(file_path: str) -> str:
-    src_url = get_file_uri(Env.get().http_address, file_path)
-    mime = mimetypes.guess_type(src_url)[0]
-    # if mime is None, the attribute string would not be valid html.
-    mime_attr = f'type="{mime}"' if mime is not None else ''
-    return f'<source src="{src_url}" {mime_attr} />'
-
-
 class DataFrameResultSet:
     def __init__(self, rows: list[list[Any]], schema: dict[str, ColumnType]):
         self._rows = rows
@@ -77,7 +69,7 @@ class DataFrameResultSet:
     def to_pandas(self) -> pd.DataFrame:
         return pd.DataFrame.from_records(self._rows, columns=self._col_names)
 
-    def _row_to_dict(self, row_idx: int) -> Dict[str, Any]:
+    def _row_to_dict(self, row_idx: int) -> dict[str, Any]:
         return {self._col_names[i]: self._rows[row_idx][i] for i in range(len(self._col_names))}
 
     def __getitem__(self, index: Any) -> Any:
@@ -111,22 +103,22 @@
 #     def __init__(self, tbl: catalog.TableVersion):
 #         self.tbl = tbl
 #         # output of the SQL scan stage
-#         self.sql_scan_output_exprs: List[exprs.Expr] = []
+#         self.sql_scan_output_exprs: list[exprs.Expr] = []
 #         # output of the agg stage
-#         self.agg_output_exprs: List[exprs.Expr] = []
+#         self.agg_output_exprs: list[exprs.Expr] = []
 #         # Where clause of the Select stmt of the SQL scan stage
 #         self.sql_where_clause: Optional[sql.ClauseElement] = None
 #         # filter predicate applied to input rows of the SQL scan stage
 #         self.filter: Optional[exprs.Predicate] = None
 #         self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
-#         self.agg_fn_calls: List[exprs.FunctionCall] = []  # derived from unique_exprs
+#         self.agg_fn_calls: list[exprs.FunctionCall] = []  # derived from unique_exprs
 #         self.has_frame_col: bool = False  # True if we're referencing the frame col
 #
 #         self.evaluator: Optional[exprs.Evaluator] = None
-#         self.sql_scan_eval_ctx: List[exprs.Expr] = []  # needed to materialize output of SQL scan stage
-#         self.agg_eval_ctx: List[exprs.Expr] = []  # needed to materialize output of agg stage
-#         self.filter_eval_ctx: List[exprs.Expr] = []
-#         self.group_by_eval_ctx: List[exprs.Expr] = []
+#         self.sql_scan_eval_ctx: list[exprs.Expr] = []  # needed to materialize output of SQL scan stage
+#         self.agg_eval_ctx: list[exprs.Expr] = []  # needed to materialize output of agg stage
+#         self.filter_eval_ctx: list[exprs.Expr] = []
+#         self.group_by_eval_ctx: list[exprs.Expr] = []
 #
 #     def finalize_exec(self) -> None:
 #         """
@@ -142,11 +134,11 @@ class DataFrame:
     def __init__(
         self,
         tbl: catalog.TableVersionPath,
-        select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]] = None,
+        select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]] = None,
         where_clause: Optional[exprs.Expr] = None,
-        group_by_clause: Optional[List[exprs.Expr]] = None,
+        group_by_clause: Optional[list[exprs.Expr]] = None,
         grouping_tbl: Optional[catalog.TableVersion] = None,
-        order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None,  # List[(expr, asc)]
+        order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None,  # list[(expr, asc)]
         limit: Optional[int] = None,
     ):
         self.tbl = tbl
@@ -174,7 +166,7 @@
     @classmethod
     def _select_list_check_rep(
         cls,
-        select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
+        select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]],
     ) -> None:
         """Validate basic select list types."""
         if select_list is None:  # basic check for valid select list
@@ -411,8 +403,8 @@
 
     def _description(self) -> pd.DataFrame:
         """see DataFrame.describe()"""
-        heading_vals: List[str] = []
-        info_vals: List[str] = []
+        heading_vals: list[str] = []
+        info_vals: list[str] = []
         if self.select_list is not None:
             assert len(self.select_list) > 0
             heading_vals.append('Select')
@@ -497,7 +489,7 @@
 
         # check user provided names do not conflict among themselves
         # or with auto-generated ones
-        seen: Set[str] = set()
+        seen: set[str] = set()
         _, names = DataFrame._normalize_select_list(self.tbl, select_list)
         for name in names:
             if name in seen:
@@ -540,7 +532,7 @@
         if self.group_by_clause is not None:
             raise excs.Error(f'Group-by already specified')
         grouping_tbl: Optional[catalog.TableVersion] = None
-        group_by_clause: Optional[List[exprs.Expr]] = None
+        group_by_clause: Optional[list[exprs.Expr]] = None
         for item in grouping_items:
             if isinstance(item, catalog.Table):
                 if len(grouping_items) > 1:
@@ -618,7 +610,7 @@
    def __getitem__(self, index: Union[exprs.Expr, Sequence[exprs.Expr]]) -> DataFrame:
         """
         Allowed:
-        - [List[Expr]]/[Tuple[Expr]]: setting the select list
+        - [list[Expr]]/[tuple[Expr]]: setting the select list
         - [Expr]: setting a single-col select list
         """
         if isinstance(index, exprs.Expr):
@@ -627,7 +619,7 @@
             return self.select(*index)
         raise TypeError(f'Invalid index type: {type(index)}')
 
-    def as_dict(self) -> Dict[str, Any]:
+    def as_dict(self) -> dict[str, Any]:
         """
         Returns:
             Dictionary representing this dataframe.
@@ -649,7 +641,7 @@
         return d
 
     @classmethod
-    def from_dict(cls, d: Dict[str, Any]) -> 'DataFrame':
+    def from_dict(cls, d: dict[str, Any]) -> 'DataFrame':
         tbl = catalog.TableVersionPath.from_dict(d['tbl'])
         select_list = [(exprs.Expr.from_dict(e), name) for e, name in d['select_list']] \
             if d['select_list'] is not None else None
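
Most of the dataframe.py churn is mechanical: typing.List/Dict/Set/Tuple annotations become the builtin generics available since Python 3.9 (PEP 585), shrinking the typing imports accordingly. The two spellings are equivalent at runtime:

    from typing import List, Optional, Tuple
    old_style: Optional[List[Tuple[str, bool]]] = None
    new_style: Optional[list[tuple[str, bool]]] = None   # no List/Tuple imports needed
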
pixeltable/env.py CHANGED
@@ -506,11 +506,13 @@
         self.__register_package('openpyxl')
         self.__register_package('pyarrow')
         self.__register_package('replicate')
+        self.__register_package('sentencepiece')
         self.__register_package('sentence_transformers', library_name='sentence-transformers')
         self.__register_package('spacy')
         self.__register_package('tiktoken')
         self.__register_package('together')
         self.__register_package('torch')
+        self.__register_package('torchaudio')
         self.__register_package('torchvision')
         self.__register_package('transformers')
         self.__register_package('whisper', library_name='openai-whisper')
pixeltable/exec/cache_prefetch_node.py CHANGED
@@ -1,87 +1,226 @@
 from __future__ import annotations
 
-import concurrent.futures
+import dataclasses
+import itertools
 import logging
 import threading
 import urllib.parse
 import urllib.request
-from collections import defaultdict
+from collections import deque
+from concurrent import futures
 from pathlib import Path
-from typing import List, Optional, Any, Tuple, Dict
+from typing import Optional, Any, Iterator
 from uuid import UUID
 
 import pixeltable.env as env
 import pixeltable.exceptions as excs
 import pixeltable.exprs as exprs
+from pixeltable import catalog
 from pixeltable.utils.filecache import FileCache
+
 from .data_row_batch import DataRowBatch
 from .exec_node import ExecNode
 
 _logger = logging.getLogger('pixeltable')
 
+
 class CachePrefetchNode(ExecNode):
     """Brings files with external URLs into the cache
 
     TODO:
-    - maintain a queue of row batches, in order to overlap download and evaluation
     - adapting the number of download threads at runtime to maximize throughput
     """
-    def __init__(self, tbl_id: UUID, file_col_info: List[exprs.ColumnSlotIdx], input: ExecNode):
-        # []: we don't have anything to evaluate
+    BATCH_SIZE = 16
+    NUM_EXECUTOR_THREADS = 16
+
+    retain_input_order: bool  # if True, return rows in the exact order they were received
+    file_col_info: list[exprs.ColumnSlotIdx]
+    boto_client: Optional[Any]
+    boto_client_lock: threading.Lock
+
+    # execution state
+    batch_tbl_version: Optional[catalog.TableVersion]  # needed to construct output batches
+    num_returned_rows: int
+
+    # ready_rows: rows that are ready to be returned, ordered by row idx;
+    # the implied row idx of ready_rows[0] is num_returned_rows
+    ready_rows: deque[Optional[exprs.DataRow]]
+
+    in_flight_rows: dict[int, CachePrefetchNode.RowState]  # rows with in-flight urls; id(row) -> RowState
+    in_flight_requests: dict[futures.Future, str]  # in-flight requests for urls; future -> URL
+    in_flight_urls: dict[str, list[tuple[exprs.DataRow, exprs.ColumnSlotIdx]]]  # URL -> [(row, info)]
+    input_finished: bool
+    row_idx: Iterator[Optional[int]]
+
+    @dataclasses.dataclass
+    class RowState:
+        row: exprs.DataRow
+        idx: Optional[int]  # position in input stream; None if we don't retain input order
+        num_missing: int  # number of missing URLs in this row
+
+    def __init__(
+            self, tbl_id: UUID, file_col_info: list[exprs.ColumnSlotIdx], input: ExecNode,
+            retain_input_order: bool = True):
+        # input_/output_exprs=[]: we don't have anything to evaluate
         super().__init__(input.row_builder, [], [], input)
-        self.tbl_id = tbl_id
+        self.retain_input_order = retain_input_order
         self.file_col_info = file_col_info
 
         # clients for specific services are constructed as needed, because it's time-consuming
-        self.boto_client: Optional[Any] = None
+        self.boto_client = None
         self.boto_client_lock = threading.Lock()
 
-    def __next__(self) -> DataRowBatch:
-        input_batch = next(self.input)
+        self.batch_tbl_version = None
+        self.num_returned_rows = 0
+        self.ready_rows = deque()
+        self.in_flight_rows = {}
+        self.in_flight_requests = {}
+        self.in_flight_urls = {}
+        self.input_finished = False
+        self.row_idx = itertools.count() if retain_input_order else itertools.repeat(None)
+
+    def __iter__(self) -> Iterator[DataRowBatch]:
+        input_iter = iter(self.input)
+        with futures.ThreadPoolExecutor(max_workers=self.NUM_EXECUTOR_THREADS) as executor:
+            # we create enough in-flight requests to fill the first batch
+            while not self.input_finished and self.__num_pending_rows() < self.BATCH_SIZE:
+                self.__submit_input_batch(input_iter, executor)
+
+            while True:
+                # try to assemble a full batch of output rows
+                if not self.__has_ready_batch() and len(self.in_flight_requests) > 0:
+                    self.__wait_for_requests()
+
+                # try to create enough in-flight requests to fill the next batch
+                while not self.input_finished and self.__num_pending_rows() < self.BATCH_SIZE:
+                    self.__submit_input_batch(input_iter, executor)
+
+                if len(self.ready_rows) > 0:
+                    # create DataRowBatch from the first BATCH_SIZE ready rows
+                    batch = DataRowBatch(self.batch_tbl_version, self.row_builder)
+                    rows = [self.ready_rows.popleft() for _ in range(min(self.BATCH_SIZE, len(self.ready_rows)))]
+                    for row in rows:
+                        assert row is not None
+                        batch.add_row(row)
+                    self.num_returned_rows += len(rows)
+                    _logger.debug(f'returning {len(rows)} rows')
+                    yield batch
+
+                if self.input_finished and self.__num_pending_rows() == 0:
+                    return
+
+    def __num_pending_rows(self) -> int:
+        return len(self.in_flight_rows) + len(self.ready_rows)
+
+    def __has_ready_batch(self) -> bool:
+        """True if there are >= BATCH_SIZES entries in ready_rows and the first BATCH_SIZE ones are all non-None"""
+        return (
+            sum(int(row is not None) for row in itertools.islice(self.ready_rows, self.BATCH_SIZE)) == self.BATCH_SIZE
+        )
+
+    def __ready_prefix_len(self) -> int:
+        """Length of the non-None prefix of ready_rows (= what we can return right now)"""
+        return sum(1 for _ in itertools.takewhile(lambda x: x is not None, self.ready_rows))
+
+    def __add_ready_row(self, row: exprs.DataRow, row_idx: Optional[int]) -> None:
+        if row_idx is None:
+            self.ready_rows.append(row)
+        else:
+            # extend ready_rows to accommodate row_idx
+            idx = row_idx - self.num_returned_rows
+            if idx >= len(self.ready_rows):
+                self.ready_rows.extend([None] * (idx - len(self.ready_rows) + 1))
+            self.ready_rows[idx] = row
+
+    def __wait_for_requests(self) -> None:
+        """Wait for in-flight requests to complete until we have a full batch of rows"""
+        file_cache = FileCache.get()
+        _logger.debug(f'waiting for requests; ready_batch_size={self.__ready_prefix_len()}')
+        while not self.__has_ready_batch() and len(self.in_flight_requests) > 0:
+            done, _ = futures.wait(self.in_flight_requests, return_when=futures.FIRST_COMPLETED)
+            for f in done:
+                url = self.in_flight_requests.pop(f)
+                tmp_path, exc = f.result()
+                local_path: Optional[Path] = None
+                if tmp_path is not None:
+                    # register the file with the cache for the first column in which it's missing
+                    assert url in self.in_flight_urls
+                    _, info = self.in_flight_urls[url][0]
+                    local_path = file_cache.add(info.col.tbl.id, info.col.id, url, tmp_path)
+                    _logger.debug(f'cached {url} as {local_path}')
+
+                # add the local path/exception to the slots that reference the url
+                for row, info in self.in_flight_urls.pop(url):
+                    if exc is not None:
+                        self.row_builder.set_exc(row, info.slot_idx, exc)
+                    else:
+                        assert local_path is not None
+                        row.set_file_path(info.slot_idx, str(local_path))
+                    state = self.in_flight_rows[id(row)]
+                    state.num_missing -= 1
+                    if state.num_missing == 0:
+                        del self.in_flight_rows[id(row)]
+                        self.__add_ready_row(row, state.idx)
+                        _logger.debug(f'row {state.idx} is ready (ready_batch_size={self.__ready_prefix_len()})')
+
+    def __submit_input_batch(self, input: Iterator[DataRowBatch], executor: futures.ThreadPoolExecutor) -> None:
+        assert not self.input_finished
+        input_batch = next(input, None)
+        if input_batch is None:
+            self.input_finished = True
+            return
+        if self.batch_tbl_version is None:
+            self.batch_tbl_version = input_batch.tbl
 
-        # collect external URLs that aren't already cached, and set DataRow.file_paths for those that are
         file_cache = FileCache.get()
-        cache_misses: List[Tuple[exprs.DataRow, exprs.ColumnSlotIdx]] = []
-        missing_url_rows: Dict[str, List[exprs.DataRow]] = defaultdict(list)  # URL -> rows in which it's missing
+
+        # URLs from this input batch that aren't already in the file cache;
+        # we use a list to make sure we submit urls in the order in which they appear in the output, which minimizes
+        # the time it takes to get the next batch together
+        cache_misses: list[str] = []
+
+        url_pos: dict[str, int] = {}  # url -> row_idx; used for logging
         for row in input_batch:
+            # identify missing local files in input batch, or fill in their paths if they're already cached
+            num_missing = 0
+            row_idx = next(self.row_idx)
+
             for info in self.file_col_info:
                 url = row.file_urls[info.slot_idx]
                 if url is None or row.file_paths[info.slot_idx] is not None:
                     # nothing to do
                     continue
-                if url in missing_url_rows:
-                    missing_url_rows[url].append(row)
+                locations = self.in_flight_urls.get(url)
+                if locations is not None:
+                    # we've already seen this
+                    locations.append((row, info))
+                    num_missing += 1
                     continue
+
                 local_path = file_cache.lookup(url)
                 if local_path is None:
-                    cache_misses.append((row, info))
-                    missing_url_rows[url].append(row)
+                    cache_misses.append(url)
+                    self.in_flight_urls[url] = [(row, info)]
+                    num_missing += 1
+                    if url not in url_pos:
+                        url_pos[url] = row_idx
                 else:
                     row.set_file_path(info.slot_idx, str(local_path))
 
-        # download the cache misses in parallel
-        # TODO: set max_workers to maximize throughput
-        futures: Dict[concurrent.futures.Future, Tuple[exprs.DataRow, exprs.ColumnSlotIdx]] = {}
-        with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
-            for row, info in cache_misses:
-                futures[executor.submit(self._fetch_url, row, info.slot_idx)] = (row, info)
-            for future in concurrent.futures.as_completed(futures):
-                # TODO: does this need to deal with recoverable errors (such as retry after throttling)?
-                tmp_path = future.result()
-                if tmp_path is None:
-                    continue
-                row, info = futures[future]
-                url = row.file_urls[info.slot_idx]
-                local_path = file_cache.add(self.tbl_id, info.col.id, url, tmp_path)
-                _logger.debug(f'PrefetchNode: cached {url} as {local_path}')
-                for row in missing_url_rows[url]:
-                    row.set_file_path(info.slot_idx, str(local_path))
+            if num_missing > 0:
+                self.in_flight_rows[id(row)] = self.RowState(row, row_idx, num_missing)
+            else:
+                self.__add_ready_row(row, row_idx)
 
-        return input_batch
+        _logger.debug(f'submitting {len(cache_misses)} urls')
+        for url in cache_misses:
+            f = executor.submit(self.__fetch_url, url)
+            _logger.debug(f'submitted {url} for idx {url_pos[url]}')
+            self.in_flight_requests[f] = url
 
-    def _fetch_url(self, row: exprs.DataRow, slot_idx: int) -> Optional[Path]:
+    def __fetch_url(self, url: str) -> tuple[Optional[Path], Optional[Exception]]:
         """Fetches a remote URL into Env.tmp_dir and returns its path"""
-        url = row.file_urls[slot_idx]
+        _logger.debug(f'fetching url={url} thread_name={threading.current_thread().name}')
         parsed = urllib.parse.urlparse(url)
         # Use len(parsed.scheme) > 1 here to ensure we're not being passed
         # a Windows filename
@@ -93,24 +232,31 @@ class CachePrefetchNode(ExecNode):
         extension = p.suffix
         tmp_path = env.Env.get().create_tmp_path(extension=extension)
         try:
+            _logger.debug(f'Downloading {url} to {tmp_path}')
             if parsed.scheme == 's3':
                 from pixeltable.utils.s3 import get_client
                 with self.boto_client_lock:
                     if self.boto_client is None:
-                        self.boto_client = get_client()
-                self.boto_client.download_file(parsed.netloc, parsed.path.lstrip('/'), str(tmp_path))
+                        config = {
+                            'max_pool_connections': self.NUM_EXECUTOR_THREADS + 4,  # +4: leave some headroom
+                            'connect_timeout': 5,
+                            'read_timeout': 30,
+                            'retries': {'max_attempts': 3, 'mode': 'adaptive'},
+                        }
+                        self.boto_client = get_client(**config)
+                self.boto_client.download_file(parsed.netloc, parsed.path.lstrip('/'), str(tmp_path))
             elif parsed.scheme == 'http' or parsed.scheme == 'https':
                 with urllib.request.urlopen(url) as resp, open(tmp_path, 'wb') as f:
                     data = resp.read()
                     f.write(data)
             else:
                 assert False, f'Unsupported URL scheme: {parsed.scheme}'
-            return tmp_path
+            _logger.debug(f'Downloaded {url} to {tmp_path}')
+            return tmp_path, None
         except Exception as e:
             # we want to add the file url to the exception message
             exc = excs.Error(f'Failed to download {url}: {e}')
-            self.row_builder.set_exc(row, slot_idx, exc)
+            _logger.debug(f'Failed to download {url}: {e}', exc_info=e)
             if not self.ctx.ignore_errors:
                 raise exc from None  # suppress original exception
-            return None
-
+            return None, exc
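
The rewritten CachePrefetchNode overlaps downloads with evaluation instead of blocking per input batch: __submit_input_batch keeps enough URLs in flight to fill the next batch, and completed rows land in ready_rows, a deque indexed relative to num_returned_rows with None holes for rows whose downloads are still pending, which is what lets retain_input_order emit rows in their original order. A self-contained sketch of just that reordering scheme (all names here are mine, not the module's):

    from collections import deque

    ready: deque = deque()   # None marks a slot whose item is still in flight
    num_returned = 0

    def add_ready(item, idx: int) -> None:
        pos = idx - num_returned                  # position relative to the emitted prefix
        if pos >= len(ready):
            ready.extend([None] * (pos - len(ready) + 1))
        ready[pos] = item

    def pop_ready_prefix() -> list:
        global num_returned
        out = []
        while ready and ready[0] is not None:     # only a hole-free prefix may be emitted
            out.append(ready.popleft())
        num_returned += len(out)
        return out

    add_ready('b', 1)                             # finished out of order: nothing emittable yet
    assert pop_ready_prefix() == []
    add_ready('a', 0)
    assert pop_ready_prefix() == ['a', 'b']
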
pixeltable/exec/data_row_batch.py CHANGED
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import List, Iterator, Optional
+from typing import Iterator, Optional
 import logging
 
 import pixeltable.exprs as exprs
@@ -53,8 +53,8 @@ class DataRowBatch:
         return self.rows[index]
 
     def flush_imgs(
-        self, idx_range: Optional[slice] = None, stored_img_info: Optional[List[exprs.ColumnSlotIdx]] = None,
-        flushed_slot_idxs: Optional[List[int]] = None
+        self, idx_range: Optional[slice] = None, stored_img_info: Optional[list[exprs.ColumnSlotIdx]] = None,
+        flushed_slot_idxs: Optional[list[int]] = None
     ) -> None:
         """Flushes images in the given range of rows."""
         assert self.tbl is not None
pixeltable/exec/exec_context.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional
 
 import sqlalchemy as sql
 
@@ -8,7 +8,7 @@ class ExecContext:
     """Class for execution runtime constants"""
     def __init__(
         self, row_builder: exprs.RowBuilder, *, show_pbar: bool = False, batch_size: int = 0,
-        pk_clause: Optional[List[sql.ClauseElement]] = None, num_computed_exprs: int = 0,
+        pk_clause: Optional[list[sql.ClauseElement]] = None, num_computed_exprs: int = 0,
         ignore_errors: bool = False
     ):
         self.show_pbar = show_pbar
pixeltable/exec/exec_node.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Iterable, Iterator, List, Optional
+from typing import TYPE_CHECKING, Iterable, Iterator, Optional
 
 import pixeltable.exprs as exprs
 
@@ -43,7 +43,7 @@ class ExecNode(abc.ABC):
         if self.input is not None:
             self.input.set_ctx(ctx)
 
-    def set_stored_img_cols(self, stored_img_cols: List[exprs.ColumnSlotIdx]) -> None:
+    def set_stored_img_cols(self, stored_img_cols: list[exprs.ColumnSlotIdx]) -> None:
         self.stored_img_cols = stored_img_cols
         # propagate batch size to the source
         if self.input is not None: