pixeltable 0.4.16 → 0.4.18 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -2,10 +2,11 @@
 Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs) for `VideoType`.
 """
 
+import glob
 import logging
 import pathlib
 import subprocess
-from typing import Literal, NoReturn
+from typing import Any, Literal, NoReturn
 
 import av
 import av.stream
@@ -305,7 +306,14 @@ def _handle_ffmpeg_error(e: subprocess.CalledProcessError) -> NoReturn:
 
 @pxt.udf(is_method=True)
 def clip(
-    video: pxt.Video, *, start_time: float, end_time: float | None = None, duration: float | None = None
+    video: pxt.Video,
+    *,
+    start_time: float,
+    end_time: float | None = None,
+    duration: float | None = None,
+    mode: Literal['fast', 'accurate'] = 'accurate',
+    video_encoder: str | None = None,
+    video_encoder_args: dict[str, Any] | None = None,
 ) -> pxt.Video | None:
     """
     Extract a clip from a video, specified by `start_time` and either `end_time` or `duration` (in seconds).
@@ -322,6 +330,14 @@ def clip(
         start_time: Start time in seconds
         end_time: End time in seconds
         duration: Duration of the clip in seconds
+        mode:
+
+            - `'fast'`: avoids re-encoding, but starts the clip at the nearest keyframe; as a result, the clip
+              duration will be slightly longer than requested
+            - `'accurate'`: extracts a frame-accurate clip, but requires re-encoding
+        video_encoder: Video encoder to use. If not specified, uses the default encoder for the current platform.
+            Only available for `mode='accurate'`.
+        video_encoder_args: Additional arguments to pass to the video encoder. Only available for `mode='accurate'`.
 
     Returns:
         New video containing only the specified time range, or None if start_time is beyond the end of the video.
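The new parameters compose as follows; this is a usage sketch, not part of the diff (the table `tbl` and its `video` column are assumptions carried over from the docstring examples, and `libx264` with a CRF of 23 is just one plausible encoder choice):

>>> # frame-accurate clip, re-encoded with explicit encoder settings
>>> tbl.select(
...     tbl.video.clip(
...         start_time=10.0,
...         duration=5.0,
...         mode='accurate',
...         video_encoder='libx264',
...         video_encoder_args={'crf': 23, 'preset': 'slow'},
...     )
... ).collect()
>>> # keyframe-aligned clip without re-encoding; may run slightly longer than 5s
>>> tbl.select(tbl.video.clip(start_time=10.0, duration=5.0, mode='fast')).collect()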
@@ -335,6 +351,11 @@ def clip(
         raise pxt.Error(f'duration must be positive, got {duration}')
     if end_time is not None and duration is not None:
         raise pxt.Error('end_time and duration cannot both be specified')
+    if mode == 'fast':
+        if video_encoder is not None:
+            raise pxt.Error("video_encoder is not supported for mode='fast'")
+        if video_encoder_args is not None:
+            raise pxt.Error("video_encoder_args is not supported for mode='fast'")
 
     video_duration = av_utils.get_video_duration(video)
     if video_duration is not None and start_time > video_duration:
@@ -344,7 +365,15 @@ def clip(
 
     if end_time is not None:
         duration = end_time - start_time
-    cmd = av_utils.ffmpeg_clip_cmd(str(video), output_path, start_time, duration)
+    cmd = av_utils.ffmpeg_clip_cmd(
+        str(video),
+        output_path,
+        start_time,
+        duration,
+        fast=(mode == 'fast'),
+        video_encoder=video_encoder,
+        video_encoder_args=video_encoder_args,
+    )
 
     try:
         result = subprocess.run(cmd, capture_output=True, text=True, check=True)
@@ -358,9 +387,17 @@ def clip(
 
 
 @pxt.udf(is_method=True)
-def segment_video(video: pxt.Video, *, duration: float) -> list[str]:
+def segment_video(
+    video: pxt.Video,
+    *,
+    duration: float | None = None,
+    segment_times: list[float] | None = None,
+    mode: Literal['fast', 'accurate'] = 'accurate',
+    video_encoder: str | None = None,
+    video_encoder_args: dict[str, Any] | None = None,
+) -> list[str]:
     """
-    Split a video into fixed-size segments.
+    Split a video into segments.
 
     __Requirements:__
 
@@ -368,7 +405,19 @@ def segment_video(video: pxt.Video, *, duration: float) -> list[str]:
 
     Args:
         video: Input video file to segment
-        duration: Approximate duration of each segment (in seconds).
+        duration: Duration of each segment (in seconds). For `mode='fast'`, this is approximate;
+            for `mode='accurate'`, segments will have exact durations. Cannot be specified together with
+            `segment_times`.
+        segment_times: List of timestamps (in seconds) in the video at which to split. Note that these are not
+            segment durations. If all segment times are less than the duration of the video, produces exactly
+            `len(segment_times) + 1` segments. Cannot be empty or be specified together with `duration`.
+        mode: Segmentation mode:
+
+            - `'fast'`: quick segmentation using stream copy (splits only at keyframes, approximate durations)
+            - `'accurate'`: precise segmentation with re-encoding (exact durations, slower)
+        video_encoder: Video encoder to use. If not specified, uses the default encoder for the current platform.
+            Only available for `mode='accurate'`.
+        video_encoder_args: Additional arguments to pass to the video encoder. Only available for `mode='accurate'`.
 
     Returns:
         List of file paths for the generated video segments.
@@ -377,45 +426,105 @@ def segment_video(video: pxt.Video, *, duration: float) -> list[str]:
         pxt.Error: If the video is missing timing information.
 
     Examples:
-        Split a video at 1-minute intervals
+        Split a video at 1-minute intervals using fast mode:
 
-        >>> tbl.select(segment_paths=tbl.video.segment_video(duration=60)).collect()
+        >>> tbl.select(segment_paths=tbl.video.segment_video(duration=60, mode='fast')).collect()
+
+        Split a video into exact 10-second segments with the default accurate mode, using the libx264 encoder
+        with a CRF of 23 and the slow preset (for smaller output files):
+
+        >>> tbl.select(
+        ...     segment_paths=tbl.video.segment_video(
+        ...         duration=10,
+        ...         video_encoder='libx264',
+        ...         video_encoder_args={'crf': 23, 'preset': 'slow'}
+        ...     )
+        ... ).collect()
 
         Split video into two parts at the midpoint:
 
         >>> duration = tbl.video.get_duration()
-        >>> tbl.select(segment_paths=tbl.video.segment_video(duration=duration / 2 + 1)).collect()
+        >>> tbl.select(segment_paths=tbl.video.segment_video(segment_times=[duration / 2])).collect()
     """
     Env.get().require_binary('ffmpeg')
-    if duration <= 0:
+    if duration is not None and segment_times is not None:
+        raise pxt.Error('duration and segment_times cannot both be specified')
+    if duration is not None and duration <= 0:
         raise pxt.Error(f'duration must be positive, got {duration}')
+    if segment_times is not None and len(segment_times) == 0:
+        raise pxt.Error('segment_times cannot be empty')
+    if mode == 'fast':
+        if video_encoder is not None:
+            raise pxt.Error("video_encoder is not supported for mode='fast'")
+        if video_encoder_args is not None:
+            raise pxt.Error("video_encoder_args is not supported for mode='fast'")
 
     base_path = TempStore.create_path(extension='')
 
-    # we extract consecutive clips instead of running ffmpeg -f segment, which is inexplicably much slower
-    start_time = 0.0
-    result: list[str] = []
-    try:
-        while True:
-            segment_path = f'{base_path}_segment_{len(result)}.mp4'
-            cmd = av_utils.ffmpeg_clip_cmd(str(video), segment_path, start_time, duration)
+    output_paths: list[str] = []
+    if mode == 'accurate':
+        # Use ffmpeg -f segment for accurate segmentation with re-encoding
+        output_pattern = f'{base_path}_segment_%04d.mp4'
+        cmd = av_utils.ffmpeg_segment_cmd(
+            str(video),
+            output_pattern,
+            segment_duration=duration,
+            segment_times=segment_times,
+            video_encoder=video_encoder,
+            video_encoder_args=video_encoder_args,
+        )
 
+        try:
             _ = subprocess.run(cmd, capture_output=True, text=True, check=True)
-            segment_duration = av_utils.get_video_duration(segment_path)
-            if segment_duration == 0.0:
-                # we're done
-                pathlib.Path(segment_path).unlink()
-                return result
-            result.append(segment_path)
-            start_time += segment_duration  # use the actual segment duration here, it won't match duration exactly
+            output_paths = sorted(glob.glob(f'{base_path}_segment_*.mp4'))
+            # TODO: is this actually an error?
+            # if len(output_paths) == 0:
+            #     stderr_output = result.stderr.strip() if result.stderr is not None else ''
+            #     raise pxt.Error(
+            #         f'ffmpeg failed to create output files for commandline: {" ".join(cmd)}\n{stderr_output}'
+            #     )
+            return output_paths
+
+        except subprocess.CalledProcessError as e:
+            _handle_ffmpeg_error(e)
 
-        return result
-
-    except subprocess.CalledProcessError as e:
-        # clean up partial results
-        for segment_path in result:
-            pathlib.Path(segment_path).unlink()
-        _handle_ffmpeg_error(e)
+    else:
+        # Fast mode: extract consecutive clips using stream copy (no re-encoding).
+        # This is faster but can only split at keyframes, leading to approximate durations.
+        start_time = 0.0
+        segment_idx = 0
+        try:
+            while True:
+                target_duration: float | None
+                if duration is not None:
+                    target_duration = duration
+                elif segment_idx < len(segment_times):
+                    target_duration = segment_times[segment_idx] - start_time
+                else:
+                    target_duration = None  # the rest
+                segment_path = f'{base_path}_segment_{len(output_paths)}.mp4'
+                cmd = av_utils.ffmpeg_clip_cmd(str(video), segment_path, start_time, target_duration)
 
+                _ = subprocess.run(cmd, capture_output=True, text=True, check=True)
+                segment_duration = av_utils.get_video_duration(segment_path)
+                if segment_duration == 0.0:
+                    # we're done
+                    pathlib.Path(segment_path).unlink()
+                    return output_paths
+                output_paths.append(segment_path)
+                start_time += segment_duration  # use the actual segment duration here, it won't match duration exactly
+
+                segment_idx += 1
+                if segment_times is not None and segment_idx > len(segment_times):
+                    break
+
+            return output_paths
+
+        except subprocess.CalledProcessError as e:
+            # clean up partial results
+            for segment_path in output_paths:
+                pathlib.Path(segment_path).unlink()
+            _handle_ffmpeg_error(e)
 
 
 @pxt.udf(is_method=True)
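A quick sketch of the new `segment_times` path (again assuming a table `tbl` with a `video` column, as in the docstring examples):

>>> # split at the 30s and 90s marks: up to three segments, stream-copied at keyframes
>>> tbl.select(tbl.video.segment_video(segment_times=[30.0, 90.0], mode='fast')).collect()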
pixeltable/globals.py CHANGED
@@ -487,12 +487,28 @@ def get_table(path: str, if_not_exists: Literal['error', 'ignore'] = 'error') ->
     return tbl
 
 
-def move(path: str, new_path: str) -> None:
+def move(
+    path: str,
+    new_path: str,
+    *,
+    if_exists: Literal['error', 'ignore'] = 'error',
+    if_not_exists: Literal['error', 'ignore'] = 'error',
+) -> None:
     """Move a schema object to a new directory and/or rename a schema object.
 
     Args:
         path: absolute path to the existing schema object.
         new_path: absolute new path for the schema object.
+        if_exists: Directive regarding how to handle if a schema object already exists at the new path.
+            Must be one of the following:
+
+            - `'error'`: raise an error
+            - `'ignore'`: do nothing and return
+        if_not_exists: Directive regarding how to handle if the source path does not exist.
+            Must be one of the following:
+
+            - `'error'`: raise an error
+            - `'ignore'`: do nothing and return
 
     Raises:
         Error: If path does not exist or new_path already exists.
@@ -506,13 +522,16 @@ def move(path: str, new_path: str) -> None:
 
         >>>> pxt.move('dir1.my_table', 'dir1.new_name')
     """
+    if_exists_ = catalog.IfExistsParam.validated(if_exists, 'if_exists')
+    if if_exists_ not in (catalog.IfExistsParam.ERROR, catalog.IfExistsParam.IGNORE):
+        raise excs.Error("`if_exists` must be one of 'error' or 'ignore'")
+    if_not_exists_ = catalog.IfNotExistsParam.validated(if_not_exists, 'if_not_exists')
    if path == new_path:
        raise excs.Error('move(): source and destination cannot be identical')
    path_obj, new_path_obj = catalog.Path.parse(path), catalog.Path.parse(new_path)
    if path_obj.is_ancestor(new_path_obj):
        raise excs.Error(f'move(): cannot move {path!r} into its own subdirectory')
-    cat = Catalog.get()
-    cat.move(path_obj, new_path_obj)
+    Catalog.get().move(path_obj, new_path_obj, if_exists_, if_not_exists_)
 
 
 def drop_table(
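At the call site, the new directives read as follows (a sketch; the paths are made up):

>>> # succeeds even if the source is already gone or the destination already exists
>>> pxt.move('dir1.old_name', 'dir1.new_name', if_exists='ignore', if_not_exists='ignore')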
@@ -660,7 +679,7 @@ def _list_tables(dir_path: str = '', recursive: bool = True, allow_system_paths:
 
 
 def create_dir(
-    path: str, if_exists: Literal['error', 'ignore', 'replace', 'replace_force'] = 'error', parents: bool = False
+    path: str, *, if_exists: Literal['error', 'ignore', 'replace', 'replace_force'] = 'error', parents: bool = False
 ) -> Optional[catalog.Dir]:
     """Create a directory.
 
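The added `*` makes `if_exists` and `parents` keyword-only, so callers must now spell them out (a sketch with a made-up path):

>>> pxt.create_dir('dir1.subdir', if_exists='ignore', parents=True)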
pixeltable/io/globals.py CHANGED
@@ -152,7 +152,7 @@ def export_images_as_fo_dataset(
     (or expression) containing image data, along with optional additional columns containing labels. Currently, only
     classification and detection labels are supported.
 
-    The [Working with Voxel51 in Pixeltable](https://docs.pixeltable.com/docs/working-with-voxel51) tutorial contains a
+    The [Working with Voxel51 in Pixeltable](https://docs.pixeltable.com/examples/vision/voxel51) tutorial contains a
     fully worked example showing how to export data from a Pixeltable table and load it into Voxel51.
 
     Images in the dataset that already exist on disk will be exported directly, in whatever format they
@@ -211,7 +211,7 @@ def export_images_as_fo_dataset(
         ...     classifications=tbl.classifications
         ... )
 
-    See the [Working with Voxel51 in Pixeltable](https://docs.pixeltable.com/docs/working-with-voxel51) tutorial
+    See the [Working with Voxel51 in Pixeltable](https://docs.pixeltable.com/examples/vision/voxel51) tutorial
     for a fully worked example.
     """
     Env.get().require_package('fiftyone')
pixeltable/io/parquet.py CHANGED
@@ -62,7 +62,7 @@ def export_parquet(
     with Catalog.get().begin_xact(for_write=False):
         for record_batch in to_record_batches(df, partition_size_bytes):
             output_path = temp_path / f'part-{batch_num:05d}.parquet'
-            arrow_tbl = pa.Table.from_batches([record_batch])  # type: ignore
+            arrow_tbl = pa.Table.from_batches([record_batch])
             pa.parquet.write_table(arrow_tbl, str(output_path))
             batch_num += 1
 
@@ -528,7 +528,7 @@ class ParquetTableDataConduit(TableDataConduit):
         from pixeltable.utils.arrow import iter_tuples2
 
         try:
-            for fragment in self.pq_ds.fragments:  # type: ignore[attr-defined]
+            for fragment in self.pq_ds.fragments:
                 for batch in fragment.to_batches():
                     dict_batch = list(iter_tuples2(batch, self.source_column_map, self.pxt_schema))
                     self.total_rows += len(dict_batch)
@@ -1,13 +1,17 @@
 import dataclasses
 import enum
+import io
 import logging
-from typing import Any, ClassVar, Iterable, Iterator, Optional
+from typing import Any, ClassVar, Iterable, Iterator, Literal
 
+import fitz  # type: ignore[import-untyped]
 import ftfy
+import PIL.Image
+from bs4.element import NavigableString, Tag
 
 from pixeltable.env import Env
 from pixeltable.exceptions import Error
-from pixeltable.type_system import ColumnType, DocumentType, IntType, JsonType, StringType
+from pixeltable.type_system import ColumnType, DocumentType, ImageType, IntType, JsonType, StringType
 from pixeltable.utils.documents import get_document_handle
 
 from .base import ComponentIterator
@@ -15,6 +19,11 @@ from .base import ComponentIterator
 
 _logger = logging.getLogger('pixeltable')
 
+class Element(enum.Enum):
+    TEXT = 1
+    IMAGE = 2
+
+
 class ChunkMetadata(enum.Enum):
     TITLE = 1
     HEADING = 2
@@ -37,27 +46,28 @@ class DocumentSectionMetadata:
     """Metadata for a subsection of a document (ie, a structural element like a heading or paragraph)"""
 
     # html and markdown metadata
-    sourceline: Optional[int] = None
+    sourceline: int | None = None
     # the stack of headings up to the most recently observed one;
     # eg, if the most recent one was an h2, 'headings' would contain keys 1 and 2, but nothing below that
-    heading: Optional[dict[str, str]] = None
+    heading: dict[str, str] | None = None
 
     # pdf-specific metadata
-    page: Optional[int] = None
+    page: int | None = None
     # bounding box as an {x1, y1, x2, y2} dictionary
-    bounding_box: Optional[dict[str, float]] = None
+    bounding_box: dict[str, float] | None = None
 
 
 @dataclasses.dataclass
 class DocumentSection:
     """A single document chunk, according to some of the splitting criteria"""
 
-    text: Optional[str]
-    metadata: Optional[DocumentSectionMetadata]
+    text: str | None = None
+    image: PIL.Image.Image | None = None
+    metadata: DocumentSectionMetadata | None = None
 
 
 def _parse_separators(separators: str) -> list[Separator]:
-    ret = []
+    ret: list[Separator] = []
     for s in separators.split(','):
         clean_s = s.strip().upper()
         if not clean_s:
@@ -71,7 +81,7 @@ def _parse_separators(separators: str) -> list[Separator]:
 
 
 def _parse_metadata(metadata: str) -> list[ChunkMetadata]:
-    ret = []
+    ret: list[ChunkMetadata] = []
     for m in metadata.split(','):
         clean_m = m.strip().upper()
         if not clean_m:
@@ -84,6 +94,18 @@ def _parse_metadata(metadata: str) -> list[ChunkMetadata]:
     return ret
 
 
+def _parse_elements(elements: list[Literal['text', 'image']]) -> list[Element]:
+    result: list[Element] = []
+    for e in elements:
+        clean_e = e.strip().upper()
+        if clean_e not in Element.__members__:
+            raise Error(f'Invalid element: `{e}`. Valid elements are: {", ".join(Element.__members__).lower()}')
+        result.append(Element[clean_e])
+    if len(result) == 0:
+        raise Error('elements cannot be empty')
+    return result
+
+
 _HTML_HEADINGS = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'}
 
 
@@ -95,15 +117,22 @@ class DocumentSplitter(ComponentIterator):
 
     Chunked text will be cleaned with `ftfy.fix_text` to fix up common problems with unicode sequences.
 
+    How to init the `DocumentSplitter` class?
+
     Args:
         separators: separators to use to chunk the document. Options are:
             `'heading'`, `'paragraph'`, `'sentence'`, `'token_limit'`, `'char_limit'`, `'page'`.
             This may be a comma-separated string, e.g., `'heading,token_limit'`.
+        elements: list of elements to extract from the document. Options are:
+            `'text'`, `'image'`. Defaults to `['text']` if not specified. The `'image'` element is only supported
+            for the `'page'` separator on PDF documents.
         limit: the maximum number of tokens or characters in each chunk, if `'token_limit'`
             or `'char_limit'` is specified.
         metadata: additional metadata fields to include in the output. Options are:
             `'title'`, `'heading'` (HTML and Markdown), `'sourceline'` (HTML), `'page'` (PDF), `'bounding_box'`
             (PDF). The input may be a comma-separated string, e.g., `'title,heading,sourceline'`.
+        image_dpi: DPI to use when extracting images from PDFs. Defaults to 300.
+        image_format: format to use when extracting images from PDFs. Defaults to 'png'.
     """
 
     METADATA_COLUMN_TYPES: ClassVar[dict[ChunkMetadata, ColumnType]] = {
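A usage sketch of the new image extraction (the table `docs` and its `document` column are assumptions, and `pxt.create_view` with `iterator=DocumentSplitter.create(...)` follows the usual component-iterator pattern rather than anything shown in this diff):

>>> chunks = pxt.create_view(
...     'doc_chunks',
...     docs,
...     iterator=DocumentSplitter.create(
...         document=docs.document,
...         separators='page',
...         elements=['text', 'image'],
...         metadata='page',
...         image_dpi=150,
...     ),
... )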
@@ -114,24 +143,41 @@ class DocumentSplitter(ComponentIterator):
         ChunkMetadata.BOUNDING_BOX: JsonType(nullable=True),
     }
 
+    _doc_handle: Any
+    _separators: list[Separator]
+    _elements: list[Element]
+    _metadata_fields: list[ChunkMetadata]
+    _doc_title: str
+    _limit: int
+    _skip_tags: list[str]
+    _overlap: int
+    _tiktoken_encoding: str | None
+    _tiktoken_target_model: str | None
+    _image_dpi: int
+    _image_format: str
+
+    _sections: Iterator[DocumentSection]
+
     def __init__(
         self,
         document: str,
         *,
         separators: str,
-        limit: Optional[int] = None,
-        overlap: Optional[int] = None,
+        elements: list[Literal['text', 'image']] | None = None,
+        limit: int | None = None,
+        overlap: int | None = None,
         metadata: str = '',
-        html_skip_tags: Optional[list[str]] = None,
-        tiktoken_encoding: Optional[str] = 'cl100k_base',
-        tiktoken_target_model: Optional[str] = None,
+        html_skip_tags: list[str] | None = None,
+        tiktoken_encoding: str | None = 'cl100k_base',
+        tiktoken_target_model: str | None = None,
+        image_dpi: int = 300,
+        image_format: str = 'png',
     ):
         if html_skip_tags is None:
             html_skip_tags = ['nav']
         self._doc_handle = get_document_handle(document)
+        self._elements = _parse_elements(elements.copy()) if elements is not None else [Element.TEXT]
         assert self._doc_handle is not None
-        # calling the output_schema method to validate the input arguments
-        self.output_schema(separators=separators, metadata=metadata, limit=limit, overlap=overlap)
         self._separators = _parse_separators(separators)
         self._metadata_fields = _parse_metadata(metadata)
         if self._doc_handle.bs_doc is not None:
@@ -147,6 +193,8 @@ class DocumentSplitter(ComponentIterator):
         self._overlap = 0 if overlap is None else overlap
         self._tiktoken_encoding = tiktoken_encoding
         self._tiktoken_target_model = tiktoken_target_model
+        self._image_dpi = image_dpi
+        self._image_format = image_format
 
         # set up processing pipeline
         if self._doc_handle.format == DocumentType.DocumentFormat.HTML:
@@ -176,19 +224,28 @@ class DocumentSplitter(ComponentIterator):
         return {
             'document': DocumentType(nullable=False),
             'separators': StringType(nullable=False),
+            'elements': JsonType(nullable=False),
             'metadata': StringType(nullable=False),
             'limit': IntType(nullable=True),
             'overlap': IntType(nullable=True),
             'skip_tags': StringType(nullable=True),
             'tiktoken_encoding': StringType(nullable=True),
             'tiktoken_target_model': StringType(nullable=True),
+            'image_dpi': IntType(nullable=True),
+            'image_format': StringType(nullable=True),
         }
 
     @classmethod
     def output_schema(cls, *args: Any, **kwargs: Any) -> tuple[dict[str, ColumnType], list[str]]:
-        schema: dict[str, ColumnType] = {'text': StringType()}
-        md_fields = _parse_metadata(kwargs['metadata']) if 'metadata' in kwargs else []
-
+        schema: dict[str, ColumnType] = {}
+        elements = _parse_elements(kwargs.get('elements', ['text']))
+        for element in elements:
+            if element == Element.TEXT:
+                schema['text'] = StringType(nullable=False)
+            elif element == Element.IMAGE:
+                schema['image'] = ImageType(nullable=False)
+
+        md_fields = _parse_metadata(kwargs.get('metadata', ''))
         for md_field in md_fields:
             schema[md_field.name.lower()] = cls.METADATA_COLUMN_TYPES[md_field]
 
@@ -198,6 +255,8 @@ class DocumentSplitter(ComponentIterator):
         limit = kwargs.get('limit')
         overlap = kwargs.get('overlap')
 
+        if Element.IMAGE in elements and separators != [Separator.PAGE]:
+            raise Error('Image elements are only supported for the "page" separator on PDF documents')
         if limit is not None or overlap is not None:
             if Separator.TOKEN_LIMIT not in separators and Separator.CHAR_LIMIT not in separators:
                 raise Error('limit/overlap requires the "token_limit" or "char_limit" separator')
@@ -211,14 +270,25 @@ class DocumentSplitter(ComponentIterator):
             if kwargs.get('limit') is None:
                 raise Error('limit is required with "token_limit"/"char_limit" separators')
 
+        if Separator.SENTENCE in separators:
+            _ = Env.get().spacy_nlp
+        if Separator.TOKEN_LIMIT in separators:
+            Env.get().require_package('tiktoken')
+
         return schema, []
 
     def __next__(self) -> dict[str, Any]:
         while True:
             section = next(self._sections)
-            if section.text is None:
+            if section.text is None and section.image is None:
                 continue
-            result: dict[str, Any] = {'text': section.text}
+            result: dict[str, Any] = {}
+            for element in self._elements:
+                if element == Element.TEXT:
+                    result['text'] = section.text
+                elif element == Element.IMAGE:
+                    result['image'] = section.image
+
             for md_field in self._metadata_fields:
                 if md_field == ChunkMetadata.TITLE:
                     result[md_field.name.lower()] = self._doc_title
@@ -230,6 +300,7 @@ class DocumentSplitter(ComponentIterator):
                     result[md_field.name.lower()] = section.metadata.page
                 elif md_field == ChunkMetadata.BOUNDING_BOX:
                     result[md_field.name.lower()] = section.metadata.bounding_box
+
             return result
 
     def _html_sections(self) -> Iterator[DocumentSection]:
@@ -265,7 +336,7 @@ class DocumentSplitter(ComponentIterator):
                 yield DocumentSection(text=full_text, metadata=md)
                 accumulated_text = []
 
-        def process_element(el: bs4.element.Tag | bs4.NavigableString) -> Iterator[DocumentSection]:
+        def process_element(el: Tag | NavigableString) -> Iterator[DocumentSection]:
             # process the element and emit sections as necessary
             nonlocal accumulated_text, headings, sourceline, emit_on_heading, emit_on_paragraph
 
@@ -353,43 +424,41 @@ class DocumentSplitter(ComponentIterator):
             yield from emit()
 
     def _pdf_sections(self) -> Iterator[DocumentSection]:
-        """Create DocumentSections reflecting the pdf-specific separators"""
-        import fitz  # type: ignore[import-untyped]
-
         doc: fitz.Document = self._doc_handle.pdf_doc
         assert doc is not None
 
         emit_on_paragraph = Separator.PARAGRAPH in self._separators or Separator.SENTENCE in self._separators
         emit_on_page = Separator.PAGE in self._separators or emit_on_paragraph
 
-        accumulated_text = []  # invariant: all elements are ftfy clean and non-empty
+        accumulated_text: list[str] = []
 
-        def _add_cleaned_text(raw_text: str) -> None:
-            fixed = ftfy.fix_text(raw_text)
+        def _add_cleaned(raw: str) -> None:
+            fixed = ftfy.fix_text(raw)
             if fixed:
                 accumulated_text.append(fixed)
 
         def _emit_text() -> str:
-            full_text = ''.join(accumulated_text)
+            txt = ''.join(accumulated_text)
             accumulated_text.clear()
-            return full_text
+            return txt
+
+        for page_idx, page in enumerate(doc.pages()):
+            img: PIL.Image.Image | None = None
+            if Element.IMAGE in self._elements:
+                pix = page.get_pixmap(dpi=self._image_dpi)
+                img = PIL.Image.open(io.BytesIO(pix.tobytes(self._image_format)))
 
-        for page_number, page in enumerate(doc.pages()):
             for block in page.get_text('blocks'):
-                # there is no concept of paragraph in pdf, block is the closest thing
-                # we can get (eg a paragraph in text may cut across pages)
-                # see pymupdf docs https://pymupdf.readthedocs.io/en/latest/app1.html
-                # other libraries like pdfminer also lack an explicit paragraph concept
-                x1, y1, x2, y2, text, _, _ = block
-                _add_cleaned_text(text)
+                x1, y1, x2, y2, text, *_ = block
+                _add_cleaned(text)
                 if accumulated_text and emit_on_paragraph:
                     bbox = {'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2}
-                    metadata = DocumentSectionMetadata(page=page_number, bounding_box=bbox)
-                    yield DocumentSection(text=_emit_text(), metadata=metadata)
+                    md = DocumentSectionMetadata(page=page_idx, bounding_box=bbox)
+                    yield DocumentSection(text=_emit_text(), metadata=md)
 
             if accumulated_text and emit_on_page and not emit_on_paragraph:
-                yield DocumentSection(text=_emit_text(), metadata=DocumentSectionMetadata(page=page_number))
-            accumulated_text = []
+                md = DocumentSectionMetadata(page=page_idx)
+                yield DocumentSection(text=_emit_text(), image=img, metadata=md)
 
         if accumulated_text and not emit_on_page:
             yield DocumentSection(text=_emit_text(), metadata=DocumentSectionMetadata())
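For reference, the page-to-image conversion that the new `_pdf_sections` body relies on, as a standalone sketch (the input path is hypothetical; `get_pixmap` and `tobytes` are standard PyMuPDF calls):

>>> import io
>>> import fitz  # PyMuPDF
>>> import PIL.Image
>>> doc = fitz.open('sample.pdf')  # hypothetical input file
>>> page = next(doc.pages())
>>> pix = page.get_pixmap(dpi=300)  # render the page to a raster image at 300 DPI
>>> img = PIL.Image.open(io.BytesIO(pix.tobytes('png')))  # PNG-encode, then decode into a PIL image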