pixeltable 0.4.0rc1__py3-none-any.whl → 0.4.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

pixeltable/plan.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import dataclasses
4
4
  import enum
5
5
  from textwrap import dedent
6
- from typing import Any, Iterable, Literal, Optional, Sequence
6
+ from typing import Any, Iterable, Literal, NamedTuple, Optional, Sequence
7
7
  from uuid import UUID
8
8
 
9
9
  import sqlalchemy as sql
@@ -12,6 +12,7 @@ import pixeltable as pxt
12
12
  from pixeltable import catalog, exceptions as excs, exec, exprs
13
13
  from pixeltable.catalog import Column, TableVersionHandle
14
14
  from pixeltable.exec.sql_node import OrderByClause, OrderByItem, combine_order_by_clauses, print_order_by_clause
15
+ from pixeltable.utils.sample import sample_key
15
16
 
16
17
 
17
18
  def _is_agg_fn_call(e: exprs.Expr) -> bool:
@@ -75,6 +76,98 @@ class FromClause:
75
76
  tbls: list[catalog.TableVersionPath]
76
77
  join_clauses: list[JoinClause] = dataclasses.field(default_factory=list)
77
78
 
79
@property
def _first_tbl(self) -> catalog.TableVersionPath:
    """Return the only table of this clause; asserts that exactly one table is present."""
    tbl_paths = self.tbls
    assert len(tbl_paths) == 1
    return tbl_paths[0]
83
+
84
+
85
@dataclasses.dataclass
class SampleClause:
    """Defines a sampling clause for a table.

    `version` and `seed` may be passed as None; __post_init__ then replaces them with
    CURRENT_VERSION and DEFAULT_SEED respectively.
    """

    version: Optional[int]
    n: Optional[int]
    n_per_stratum: Optional[int]
    fraction: Optional[float]
    seed: Optional[int]
    stratify_exprs: Optional[list[exprs.Expr]]

    # This seed value is used if one is not supplied
    DEFAULT_SEED = 0

    # The version of the hashing algorithm used for ordering and fractional sampling.
    CURRENT_VERSION = 1

    def __post_init__(self) -> None:
        """If no version or seed was provided, fill in the defaults"""
        if self.version is None:
            self.version = self.CURRENT_VERSION
        if self.seed is None:
            self.seed = self.DEFAULT_SEED

    @property
    def is_stratified(self) -> bool:
        """Check if the sampling is stratified (i.e., at least one stratification expression is present)"""
        return self.stratify_exprs is not None and len(self.stratify_exprs) > 0

    @property
    def is_repeatable(self) -> bool:
        """Return true if the same rows will continue to be sampled if source rows are added or deleted."""
        return not self.is_stratified and self.fraction is not None

    def display_str(self, inline: bool = False) -> str:
        """Return the string form of this clause; `inline` is accepted but not used."""
        return str(self)

    def as_dict(self) -> dict:
        """Return a dictionary representation of the object"""
        d = dataclasses.asdict(self)
        d['_classname'] = self.__class__.__name__
        if self.is_stratified:
            # replace the copied expression objects with their own dictionary representations
            d['stratify_exprs'] = [e.as_dict() for e in self.stratify_exprs]
        return d

    @classmethod
    def from_dict(cls, d: dict) -> SampleClause:
        """Create a SampleClause from a dictionary representation"""
        d_cleaned = {key: value for key, value in d.items() if key != '_classname'}
        s = cls(**d_cleaned)
        if s.is_stratified:
            # the serialized form holds expression dicts; turn them back into expressions
            s.stratify_exprs = [exprs.Expr.from_dict(e) for e in d_cleaned.get('stratify_exprs', [])]
        return s

    def __repr__(self) -> str:
        # stratify_exprs is Optional: treat None like an empty list so that repr()/str()/display_str()
        # also work for non-stratified clauses
        s = ','.join(e.display_str(inline=True) for e in (self.stratify_exprs or []))
        return (
            f'sample_{self.version}(n={self.n}, n_per_stratum={self.n_per_stratum}, '
            f'fraction={self.fraction}, seed={self.seed}, [{s}])'
        )

    @classmethod
    def fraction_to_md5_hex(cls, fraction: float) -> str:
        """Return the string representation of an approximation (to ~1e-9) of a fraction of the total space
        of md5 hash values.
        This is used for fractional sampling.
        """
        # Largest value representable in the upper 32 bits of an MD5 digest: 2^32 - 1
        max_md5_value = (2**32) - 1

        # Calculate the fraction of this value (integer arithmetic; the fraction is truncated to 9 decimal places)
        threshold_int = max_md5_value * int(1_000_000_000 * fraction) // 1_000_000_000

        # Convert to an 8-digit hexadecimal string and pad the remaining 24 hex digits with 'f'
        return format(threshold_int, '08x') + 'ffffffffffffffffffffffff'
160
+
161
+
162
class SamplingClauses(NamedTuple):
    """Clauses provided when rewriting a SampleClause"""

    # Optional: Planner.create_sample_clauses() passes the caller's where clause through, which may be None
    where: Optional[exprs.Expr]
    group_by_clause: Optional[list[exprs.Expr]]
    order_by_clause: Optional[list[tuple[exprs.Expr, bool]]]
    limit: Optional[exprs.Expr]
    sample_clause: Optional[SampleClause]
170
+
78
171
 
79
172
  class Analyzer:
80
173
  """
@@ -260,7 +353,7 @@ class Planner:
260
353
  # TODO: create an exec.CountNode and change this to create_count_plan()
261
354
  @classmethod
262
355
  def create_count_stmt(cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Expr] = None) -> sql.Select:
263
- stmt = sql.select(sql.func.count())
356
+ stmt = sql.select(sql.func.count().label('all_count'))
264
357
  refd_tbl_ids: set[UUID] = set()
265
358
  if where_clause is not None:
266
359
  analyzer = cls.analyze(tbl, where_clause)
@@ -322,6 +415,13 @@ class Planner:
322
415
  )
323
416
  return plan
324
417
 
418
@classmethod
def rowid_columns(cls, target: TableVersionHandle, num_rowid_cols: Optional[int] = None) -> list[exprs.Expr]:
    """Build one RowidRef per rowid column of `target`.

    When `num_rowid_cols` is not given, the count is taken from target.get().num_rowid_columns().
    """
    count = target.get().num_rowid_columns() if num_rowid_cols is None else num_rowid_cols
    refs: list[exprs.Expr] = []
    for idx in range(count):
        refs.append(exprs.RowidRef(target, idx))
    return refs
424
+
325
425
  @classmethod
326
426
  def create_df_insert_plan(
327
427
  cls, tbl: catalog.TableVersion, df: 'pxt.DataFrame', ignore_errors: bool
@@ -591,7 +691,24 @@ class Planner:
591
691
  # 2. for component views: iterator args
592
692
  iterator_args = [target.iterator_args] if target.iterator_args is not None else []
593
693
 
594
- row_builder = exprs.RowBuilder(iterator_args, stored_cols, [])
694
+ # If this contains a sample specification, modify / create where, group_by, order_by, and limit clauses
695
+ from_clause = FromClause(tbls=[view.base])
696
+ where, group_by_clause, order_by_clause, limit, sample_clause = cls.create_sample_clauses(
697
+ from_clause, target.sample_clause, target.predicate, None, [], None
698
+ )
699
+
700
+ # if we're propagating an insert, we only want to see those base rows that were created for the current version
701
+ base_analyzer = Analyzer(
702
+ from_clause,
703
+ iterator_args,
704
+ where_clause=where,
705
+ group_by_clause=group_by_clause,
706
+ order_by_clause=order_by_clause,
707
+ )
708
+ row_builder = exprs.RowBuilder(base_analyzer.all_exprs, stored_cols, [])
709
+
710
+ if target.sample_clause is not None and base_analyzer.filter is not None:
711
+ raise excs.Error(f'Filter {base_analyzer.filter} not expressible in SQL')
595
712
 
596
713
  # execution plan:
597
714
  # 1. materialize exprs computed from the base that are needed for stored view columns
@@ -603,13 +720,22 @@ class Planner:
603
720
  for e in row_builder.default_eval_ctx.target_exprs
604
721
  if e.is_bound_by([view]) and not e.is_bound_by([view.base])
605
722
  ]
606
- # if we're propagating an insert, we only want to see those base rows that were created for the current version
607
- base_analyzer = Analyzer(FromClause(tbls=[view.base]), base_output_exprs, where_clause=target.predicate)
723
+
724
+ # Create a new analyzer reflecting exactly what is required from the base table
725
+ base_analyzer = Analyzer(
726
+ from_clause,
727
+ base_output_exprs,
728
+ where_clause=where,
729
+ group_by_clause=group_by_clause,
730
+ order_by_clause=order_by_clause,
731
+ )
608
732
  base_eval_ctx = row_builder.create_eval_ctx(base_analyzer.all_exprs)
609
733
  plan = cls._create_query_plan(
610
734
  row_builder=row_builder,
611
735
  analyzer=base_analyzer,
612
736
  eval_ctx=base_eval_ctx,
737
+ limit=limit,
738
+ sample_clause=sample_clause,
613
739
  with_pk=True,
614
740
  exact_version_only=view.get_bases() if propagates_insert else [],
615
741
  )
@@ -692,6 +818,62 @@ class Planner:
692
818
  prefetch_node = exec.CachePrefetchNode(tbl_id, file_col_info, input_node)
693
819
  return prefetch_node
694
820
 
821
@classmethod
def create_sample_clauses(
    cls,
    from_clause: FromClause,
    sample_clause: Optional[SampleClause],
    where_clause: Optional[exprs.Expr],
    group_by_clause: Optional[list[exprs.Expr]],
    order_by_clause: Optional[list[tuple[exprs.Expr, bool]]],
    limit: Optional[exprs.Expr],
) -> SamplingClauses:
    """Construct clauses required for sampling under various conditions.
    If there is no sampling, then return the original clauses.
    If the sample is stratified, then return only the group by clause. The rest of the
    mechanism for stratified sampling is provided by the SampleSqlNode.
    If the sample is non-stratified, then rewrite the query to accommodate the supplied where clause,
    and provide the other clauses required for sampling

    Returns:
        A SamplingClauses tuple (where, group_by_clause, order_by_clause, limit, sample_clause).
        The sample_clause element is only set for stratified sampling.
    """
    # If no sample clause, return the original clauses
    if sample_clause is None:
        return SamplingClauses(where_clause, group_by_clause, order_by_clause, limit, None)

    # If the sample clause is stratified, create a group by clause
    if sample_clause.is_stratified:
        group_by = sample_clause.stratify_exprs
        # Note that limit is not possible here
        return SamplingClauses(where_clause, group_by, order_by_clause, None, sample_clause)

    # If non-stratified sampling, construct a where clause, order_by, and limit clauses
    # Construct an expression for sorting rows and limiting row counts
    s_key = sample_key(exprs.Literal(sample_clause.seed), *cls.rowid_columns(from_clause._first_tbl.tbl_version))

    # Construct a suitable where clause: rows whose sample key falls below the fraction threshold,
    # combined with the caller's where clause if there is one
    where = where_clause
    if sample_clause.fraction is not None:
        fraction_md5_hex = exprs.Expr.from_object(sample_clause.fraction_to_md5_hex(float(sample_clause.fraction)))
        f_where = s_key < fraction_md5_hex
        where = where & f_where if where is not None else f_where

    # The incoming order_by_clause and limit are replaced by the sampling order and the sample size
    order_by: list[tuple[exprs.Expr, bool]] = [(s_key, True)]
    limit = exprs.Literal(sample_clause.n)
    # Note that group_by is not possible here
    return SamplingClauses(where, None, order_by, limit, None)
876
+
695
877
  @classmethod
696
878
  def create_query_plan(
697
879
  cls,
@@ -701,6 +883,7 @@ class Planner:
701
883
  group_by_clause: Optional[list[exprs.Expr]] = None,
702
884
  order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None,
703
885
  limit: Optional[exprs.Expr] = None,
886
+ sample_clause: Optional[SampleClause] = None,
704
887
  ignore_errors: bool = False,
705
888
  exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
706
889
  ) -> exec.ExecNode:
@@ -714,14 +897,22 @@ class Planner:
714
897
  order_by_clause = []
715
898
  if exact_version_only is None:
716
899
  exact_version_only = []
900
+
901
+ # Modify clauses to include sample clause
902
+ where, group_by_clause, order_by_clause, limit, sample = cls.create_sample_clauses(
903
+ from_clause, sample_clause, where_clause, group_by_clause, order_by_clause, limit
904
+ )
905
+
717
906
  analyzer = Analyzer(
718
907
  from_clause,
719
908
  select_list,
720
- where_clause=where_clause,
909
+ where_clause=where,
721
910
  group_by_clause=group_by_clause,
722
911
  order_by_clause=order_by_clause,
723
912
  )
724
913
  row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [])
914
+ if sample_clause is not None and analyzer.filter is not None:
915
+ raise excs.Error(f'Filter {analyzer.filter} not expressible in SQL')
725
916
 
726
917
  analyzer.finalize(row_builder)
727
918
  # select_list: we need to materialize everything that's been collected
@@ -732,6 +923,7 @@ class Planner:
732
923
  analyzer=analyzer,
733
924
  eval_ctx=eval_ctx,
734
925
  limit=limit,
926
+ sample_clause=sample,
735
927
  with_pk=True,
736
928
  exact_version_only=exact_version_only,
737
929
  )
@@ -747,6 +939,7 @@ class Planner:
747
939
  analyzer: Analyzer,
748
940
  eval_ctx: exprs.RowBuilder.EvalCtx,
749
941
  limit: Optional[exprs.Expr] = None,
942
+ sample_clause: Optional[SampleClause] = None,
750
943
  with_pk: bool = False,
751
944
  exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
752
945
  ) -> exec.ExecNode:
@@ -857,12 +1050,26 @@ class Planner:
857
1050
  sql_elements.contains_all(analyzer.select_list)
858
1051
  and sql_elements.contains_all(analyzer.grouping_exprs)
859
1052
  and isinstance(plan, exec.SqlNode)
860
- and plan.to_cte() is not None
1053
+ and plan.to_cte(keep_pk=(sample_clause is not None)) is not None
861
1054
  ):
862
- plan = exec.SqlAggregationNode(
863
- row_builder, input=plan, select_list=analyzer.select_list, group_by_items=analyzer.group_by_clause
864
- )
1055
+ if sample_clause is not None:
1056
+ plan = exec.SqlSampleNode(
1057
+ row_builder,
1058
+ input=plan,
1059
+ select_list=analyzer.select_list,
1060
+ stratify_exprs=analyzer.group_by_clause,
1061
+ sample_clause=sample_clause,
1062
+ )
1063
+ else:
1064
+ plan = exec.SqlAggregationNode(
1065
+ row_builder,
1066
+ input=plan,
1067
+ select_list=analyzer.select_list,
1068
+ group_by_items=analyzer.group_by_clause,
1069
+ )
865
1070
  else:
1071
+ if sample_clause is not None:
1072
+ raise excs.Error('Sample clause not supported with Python aggregation')
866
1073
  input_sql_node = plan.get_node(exec.SqlNode)
867
1074
  assert combined_ordering is not None
868
1075
  input_sql_node.set_order_by(combined_ordering)
@@ -1,4 +1,7 @@
1
+ import base64
1
2
  import datetime
3
+ import io
4
+ import itertools
2
5
  import json
3
6
  import logging
4
7
  import tarfile
@@ -10,15 +13,18 @@ from typing import Any, Iterator, Optional
10
13
  from uuid import UUID
11
14
 
12
15
  import more_itertools
16
+ import numpy as np
17
+ import PIL.Image
13
18
  import pyarrow as pa
14
19
  import pyarrow.parquet as pq
15
20
  import sqlalchemy as sql
16
21
 
17
22
  import pixeltable as pxt
18
- from pixeltable import catalog, exceptions as excs, metadata
23
+ from pixeltable import catalog, exceptions as excs, metadata, type_system as ts
19
24
  from pixeltable.env import Env
20
25
  from pixeltable.metadata import schema
21
26
  from pixeltable.utils import sha256sum
27
+ from pixeltable.utils.formatter import Formatter
22
28
  from pixeltable.utils.media_store import MediaStore
23
29
 
24
30
  _logger = logging.getLogger('pixeltable')
@@ -46,6 +52,10 @@ class TablePackager:
46
52
  media_files: dict[Path, str] # Mapping from local media file paths to their tarball names
47
53
  md: dict[str, Any]
48
54
 
55
+ bundle_path: Path
56
+ preview_header: dict[str, str]
57
+ preview: list[list[Any]]
58
+
49
59
  def __init__(self, table: catalog.Table, additional_md: Optional[dict[str, Any]] = None) -> None:
50
60
  self.table = table
51
61
  self.tmp_dir = Path(Env.get().create_tmp_path())
@@ -67,7 +77,8 @@ class TablePackager:
67
77
  Export the table to a tarball containing Parquet tables and media files.
68
78
  """
69
79
  assert not self.tmp_dir.exists() # Packaging can only be done once per TablePackager instance
70
- _logger.info(f"Packaging table '{self.table._path()}' and its ancestors in: {self.tmp_dir}")
80
+
81
+ _logger.info(f'Packaging table {self.table._path()!r} and its ancestors in: {self.tmp_dir}')
71
82
  self.tmp_dir.mkdir()
72
83
  with open(self.tmp_dir / 'metadata.json', 'w', encoding='utf8') as fp:
73
84
  json.dump(self.md, fp)
@@ -75,12 +86,20 @@ class TablePackager:
75
86
  self.tables_dir.mkdir()
76
87
  with catalog.Catalog.get().begin_xact(for_write=False):
77
88
  for tv in self.table._tbl_version_path.get_tbl_versions():
78
- _logger.info(f"Exporting table '{tv.get().versioned_name}'.")
89
+ _logger.info(f'Exporting table {tv.get().versioned_name!r}.')
79
90
  self.__export_table(tv.get())
91
+
80
92
  _logger.info('Building archive.')
81
- bundle_path = self.__build_tarball()
82
- _logger.info(f'Packaging complete: {bundle_path}')
83
- return bundle_path
93
+ self.bundle_path = self.__build_tarball()
94
+
95
+ _logger.info('Extracting preview data.')
96
+ self.md['count'] = self.table.count()
97
+ preview_header, preview = self.__extract_preview_data()
98
+ self.md['preview_header'] = preview_header
99
+ self.md['preview'] = preview
100
+
101
+ _logger.info(f'Packaging complete: {self.bundle_path}')
102
+ return self.bundle_path
84
103
 
85
104
  def __export_table(self, tv: catalog.TableVersion) -> None:
86
105
  """
@@ -207,6 +226,96 @@ class TablePackager:
207
226
  tf.add(src_file, arcname=f'media/{dest_name}')
208
227
  return bundle_path
209
228
 
229
def __extract_preview_data(self) -> tuple[dict[str, str], list[list[Any]]]:
    """
    Extract a preview of the table data for display in the UI.

    In order to bound the size of the output data, all "unbounded" data types are resized:
    - Strings are abbreviated as per Formatter.abbreviate()
    - Arrays and JSON are shortened and formatted as strings
    - Images are resized to thumbnail size as a base64-encoded webp
    - Videos are replaced by their first frame and resized as above
    - Documents are replaced by a thumbnail as a base64-encoded webp

    Returns:
        A (preview_header, preview) pair: preview_header maps each previewed column name to the string form
        of its type; preview holds one list per row, with one encoded value per previewed column.
    """
    # First 8 columns
    preview_cols = dict(itertools.islice(self.table._schema.items(), 0, 8))
    select_list = [self.table[col_name] for col_name in preview_cols]
    # First 5 rows
    rows = list(self.table.select(*select_list).head(n=5))

    preview_header = {col_name: str(col_type._type) for col_name, col_type in preview_cols.items()}
    col_types = list(preview_cols.values())
    # Keep the row structure: the outer list iterates over rows, the inner list over the cells of a row
    preview = [
        [self.__encode_preview_data(val, col_type) for val, col_type in zip(row.values(), col_types)]
        for row in rows
    ]

    return preview_header, preview
254
+
255
def __encode_preview_data(self, val: Any, col_type: ts.ColumnType) -> Any:
    """Encode a single cell value for the preview, dispatching on the column's type.

    None is passed through unchanged; an unrecognized column type raises AssertionError.
    """
    if val is None:
        return None

    type_enum = ts.ColumnType.Type
    kind = col_type._type

    if kind == type_enum.STRING:
        assert isinstance(val, str)
        return Formatter.abbreviate(val)

    if kind in (type_enum.INT, type_enum.FLOAT, type_enum.BOOL):
        return val

    if kind in (type_enum.TIMESTAMP, type_enum.DATE):
        return str(val)

    if kind == type_enum.ARRAY:
        assert isinstance(val, np.ndarray)
        return Formatter.format_array(val)

    if kind == type_enum.JSON:
        # We need to escape the JSON string server-side for security reasons.
        # Therefore we don't escape it here, in order to avoid double-escaping.
        return Formatter.format_json(val, escape_strings=False)

    if kind == type_enum.IMAGE:
        # Rescale the image to minimize data transfer size
        assert isinstance(val, PIL.Image.Image)
        return self.__encode_image(val)

    if kind == type_enum.VIDEO:
        assert isinstance(val, str)
        return self.__encode_video(val)

    if kind == type_enum.AUDIO:
        # no preview is produced for audio
        return None

    if kind == type_enum.DOCUMENT:
        assert isinstance(val, str)
        return self.__encode_document(val)

    raise AssertionError(f'Unrecognized column type: {col_type._type}')
297
+
298
def __encode_image(self, img: PIL.Image.Image) -> str:
    """Scale `img` to thumbnail size and return it as a base64-encoded webp string."""
    # Heuristic for thumbnail sizing:
    # Standardize on a width of 240 pixels (to most efficiently utilize the columnar display).
    # But, if the aspect ratio is below 2:3, bound the height at 360 pixels (to avoid unboundedly tall thumbnails
    # in the case of highly oblong images).
    # The derived dimension is clamped to at least 1 pixel: for extreme aspect ratios the integer division
    # would otherwise yield 0, which is not a valid size for resize().
    if img.height > img.width * 1.5:
        scaled_img = img.resize((max(1, img.width * 360 // img.height), 360))
    else:
        scaled_img = img.resize((240, max(1, img.height * 240 // img.width)))
    with io.BytesIO() as buffer:
        scaled_img.save(buffer, 'webp')
        return base64.b64encode(buffer.getvalue()).decode()
310
+
311
def __encode_video(self, video_path: str) -> Optional[str]:
    """Encode the first frame of the video as a thumbnail; None if no frame could be extracted."""
    first_frame = Formatter.extract_first_video_frame(video_path)
    if first_frame is None:
        return None
    return self.__encode_image(first_frame)
314
+
315
def __encode_document(self, doc_path: str) -> Optional[str]:
    """Encode a thumbnail of the document; None if no thumbnail could be produced."""
    doc_thumb = Formatter.make_document_thumbnail(doc_path)
    if doc_thumb is None:
        return None
    return self.__encode_image(doc_thumb)
318
+
210
319
 
211
320
  class TableRestorer:
212
321
  """
@@ -63,10 +63,10 @@ class Formatter:
63
63
  """
64
64
  Escapes special characters in `val`, and abbreviates `val` if its length exceeds `_STRING_MAX_LEN`.
65
65
  """
66
- return cls.__escape(cls.__abbreviate(val, cls.__STRING_MAX_LEN))
66
+ return cls.__escape(cls.abbreviate(val))
67
67
 
68
68
  @classmethod
69
- def __abbreviate(cls, val: str, max_len: int) -> str:
69
+ def abbreviate(cls, val: str, max_len: int = __STRING_MAX_LEN) -> str:
70
70
  if len(val) > max_len:
71
71
  edgeitems = (max_len - len(cls.__STRING_SEP)) // 2
72
72
  return f'{val[:edgeitems]}{cls.__STRING_SEP}{val[-edgeitems:]}'
@@ -94,41 +94,45 @@ class Formatter:
94
94
  )
95
95
 
96
96
  @classmethod
97
- def format_json(cls, val: Any) -> str:
97
+ def format_json(cls, val: Any, escape_strings: bool = True) -> str:
98
98
  if isinstance(val, str):
99
99
  # JSON-like formatting will be applied to strings that appear nested within a list or dict
100
100
  # (quote the string; escape any quotes inside the string; shorter abbreviations).
101
101
  # However, if the string appears in top-level position (i.e., the entire JSON value is a
102
102
  # string), then we format it like an ordinary string.
103
- return cls.format_string(val)
103
+ return cls.format_string(val) if escape_strings else cls.abbreviate(val)
104
104
  # In all other cases, dump the JSON struct recursively.
105
- return cls.__format_json_rec(val)
105
+ return cls.__format_json_rec(val, escape_strings)
106
106
 
107
107
  @classmethod
108
- def __format_json_rec(cls, val: Any) -> str:
108
+ def __format_json_rec(cls, val: Any, escape_strings: bool) -> str:
109
109
  if isinstance(val, str):
110
- return cls.__escape(json.dumps(cls.__abbreviate(val, cls.__NESTED_STRING_MAX_LEN)))
110
+ formatted = json.dumps(cls.abbreviate(val, cls.__NESTED_STRING_MAX_LEN))
111
+ return cls.__escape(formatted) if escape_strings else formatted
111
112
  if isinstance(val, float):
112
113
  return cls.format_float(val)
113
114
  if isinstance(val, np.ndarray):
114
115
  return cls.format_array(val)
115
116
  if isinstance(val, list):
116
117
  if len(val) < cls.__LIST_THRESHOLD:
117
- components = [cls.__format_json_rec(x) for x in val]
118
+ components = [cls.__format_json_rec(x, escape_strings) for x in val]
118
119
  else:
119
- components = [cls.__format_json_rec(x) for x in val[: cls.__LIST_EDGEITEMS]]
120
+ components = [cls.__format_json_rec(x, escape_strings) for x in val[: cls.__LIST_EDGEITEMS]]
120
121
  components.append('...')
121
- components.extend(cls.__format_json_rec(x) for x in val[-cls.__LIST_EDGEITEMS :])
122
+ components.extend(cls.__format_json_rec(x, escape_strings) for x in val[-cls.__LIST_EDGEITEMS :])
122
123
  return '[' + ', '.join(components) + ']'
123
124
  if isinstance(val, dict):
124
- kv_pairs = (f'{cls.__format_json_rec(k)}: {cls.__format_json_rec(v)}' for k, v in val.items())
125
+ kv_pairs = (
126
+ f'{cls.__format_json_rec(k, escape_strings)}: {cls.__format_json_rec(v, escape_strings)}'
127
+ for k, v in val.items()
128
+ )
125
129
  return '{' + ', '.join(kv_pairs) + '}'
126
130
 
127
131
  # Everything else
128
132
  try:
129
133
  return json.dumps(val)
130
134
  except TypeError: # Not JSON serializable
131
- return str(val)
135
+ return cls.__escape(str(val))
132
136
 
133
137
  def format_img(self, img: Image.Image) -> str:
134
138
  """
@@ -152,22 +156,19 @@ class Formatter:
152
156
  """
153
157
 
154
158
  def format_video(self, file_path: str) -> str:
155
- thumb_tag = ''
156
159
  # Attempt to extract the first frame of the video to use as a thumbnail,
157
160
  # so that the notebook can be exported as HTML and viewed in contexts where
158
161
  # the video itself is not accessible.
159
162
  # TODO(aaron-siegel): If the video is backed by a concrete external URL,
160
163
  # should we link to that instead?
161
- with av.open(file_path) as container:
162
- try:
163
- thumb = next(container.decode(video=0)).to_image()
164
- assert isinstance(thumb, Image.Image)
165
- with io.BytesIO() as buffer:
166
- thumb.save(buffer, 'jpeg')
167
- thumb_base64 = base64.b64encode(buffer.getvalue()).decode()
168
- thumb_tag = f'poster="data:image/jpeg;base64,{thumb_base64}"'
169
- except Exception:
170
- pass
164
+ thumb = self.extract_first_video_frame(file_path)
165
+ if thumb is None:
166
+ thumb_tag = ''
167
+ else:
168
+ with io.BytesIO() as buffer:
169
+ thumb.save(buffer, 'jpeg')
170
+ thumb_base64 = base64.b64encode(buffer.getvalue()).decode()
171
+ thumb_tag = f'poster="data:image/jpeg;base64,{thumb_base64}"'
171
172
  if self.__num_rows > 1:
172
173
  width = 320
173
174
  elif self.__num_cols > 1:
@@ -182,6 +183,16 @@ class Formatter:
182
183
  </div>
183
184
  """
184
185
 
186
@classmethod
def extract_first_video_frame(cls, file_path: str) -> Optional[Image.Image]:
    """Decode the first frame of the video at `file_path` and return it as an image.

    Returns None if decoding or conversion fails. Errors raised while opening the file are not caught here.
    """
    with av.open(file_path) as container:
        frame_img: Optional[Image.Image]
        try:
            first_frame = next(container.decode(video=0))
            candidate = first_frame.to_image()
            assert isinstance(candidate, Image.Image)
            frame_img = candidate
        except Exception:
            frame_img = None
        return frame_img
195
+
185
196
  def format_audio(self, file_path: str) -> str:
186
197
  return f"""
187
198
  <div class="pxt_audio">
@@ -191,29 +202,18 @@ class Formatter:
191
202
  </div>
192
203
  """
193
204
 
194
- def format_document(self, file_path: str) -> str:
195
- max_width = max_height = 320
205
+ def format_document(self, file_path: str, max_width: int = 320, max_height: int = 320) -> str:
196
206
  # by default, file path will be shown as a link
197
207
  inner_element = file_path
198
208
  inner_element = html.escape(inner_element)
199
- # try generating a thumbnail for different types and use that if successful
200
- if file_path.lower().endswith('.pdf'):
201
- try:
202
- import fitz # type: ignore[import-untyped]
203
209
 
204
- doc = fitz.open(file_path)
205
- p = doc.get_page_pixmap(0)
206
- while p.width > max_width or p.height > max_height:
207
- # shrink(1) will halve each dimension
208
- p.shrink(1)
209
- data = p.tobytes(output='jpeg')
210
- thumb_base64 = base64.b64encode(data).decode()
211
- img_src = f'data:image/jpeg;base64,{thumb_base64}'
212
- inner_element = f"""
213
- <img style="object-fit: contain; border: 1px solid black;" src="{img_src}" />
214
- """
215
- except Exception:
216
- logging.warning(f'Failed to produce PDF thumbnail {file_path}. Make sure you have PyMuPDF installed.')
210
+ thumb = self.make_document_thumbnail(file_path, max_width, max_height)
211
+ if thumb is not None:
212
+ with io.BytesIO() as buffer:
213
+ thumb.save(buffer, 'webp')
214
+ thumb_base64 = base64.b64encode(buffer.getvalue()).decode()
215
+ thumb_tag = f'data:image/webp;base64,{thumb_base64}'
216
+ inner_element = f'<img style="object-fit: contain; border: 1px solid black;" src="{thumb_tag}" />'
217
217
 
218
218
  return f"""
219
219
  <div class="pxt_document" style="width:{max_width}px;">
@@ -223,6 +223,28 @@ class Formatter:
223
223
  </div>
224
224
  """
225
225
 
226
@classmethod
def make_document_thumbnail(
    cls, file_path: str, max_width: int = 320, max_height: int = 320
) -> Optional[Image.Image]:
    """
    Returns a thumbnail image of a document.

    Only paths ending in '.pdf' (case-insensitive) are handled: the first page is rendered and halved in
    size until it fits within max_width x max_height. Returns None for any other file, or if rendering fails.
    """
    if file_path.lower().endswith('.pdf'):
        try:
            import fitz  # type: ignore[import-untyped]

            # Open the document as a context manager so that it is closed again, even if rendering fails
            with fitz.open(file_path) as doc:
                pixmap = doc.get_page_pixmap(0)
                while pixmap.width > max_width or pixmap.height > max_height:
                    # shrink(1) will halve each dimension
                    pixmap.shrink(1)
                return pixmap.pil_image()
        except Exception:
            logging.warning(f'Failed to produce PDF thumbnail {file_path}. Make sure you have PyMuPDF installed.')

    return None
247
+
226
248
  @classmethod
227
249
  def __create_source_tag(cls, http_address: str, file_path: str) -> str:
228
250
  src_url = get_file_uri(http_address, file_path)