pixeltable 0.2.15__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (52)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/column.py +3 -0
  3. pixeltable/catalog/dir.py +1 -1
  4. pixeltable/catalog/globals.py +15 -6
  5. pixeltable/catalog/insertable_table.py +23 -8
  6. pixeltable/catalog/named_function.py +1 -1
  7. pixeltable/catalog/path_dict.py +4 -4
  8. pixeltable/catalog/schema_object.py +30 -18
  9. pixeltable/catalog/table.py +84 -99
  10. pixeltable/catalog/table_version.py +35 -24
  11. pixeltable/catalog/table_version_path.py +2 -2
  12. pixeltable/catalog/view.py +15 -8
  13. pixeltable/dataframe.py +56 -56
  14. pixeltable/env.py +6 -5
  15. pixeltable/exec/__init__.py +3 -3
  16. pixeltable/exec/aggregation_node.py +3 -3
  17. pixeltable/exec/expr_eval_node.py +3 -3
  18. pixeltable/exec/in_memory_data_node.py +4 -4
  19. pixeltable/exec/sql_node.py +4 -1
  20. pixeltable/exprs/array_slice.py +3 -4
  21. pixeltable/exprs/column_ref.py +20 -4
  22. pixeltable/exprs/comparison.py +11 -6
  23. pixeltable/exprs/data_row.py +3 -0
  24. pixeltable/exprs/expr.py +51 -23
  25. pixeltable/exprs/function_call.py +8 -1
  26. pixeltable/exprs/inline_array.py +2 -2
  27. pixeltable/exprs/json_path.py +36 -20
  28. pixeltable/exprs/row_builder.py +4 -4
  29. pixeltable/exprs/rowid_ref.py +1 -1
  30. pixeltable/functions/__init__.py +1 -2
  31. pixeltable/functions/audio.py +32 -0
  32. pixeltable/functions/huggingface.py +4 -4
  33. pixeltable/functions/image.py +1 -1
  34. pixeltable/functions/video.py +5 -1
  35. pixeltable/functions/vision.py +2 -6
  36. pixeltable/globals.py +57 -28
  37. pixeltable/io/external_store.py +4 -4
  38. pixeltable/io/globals.py +12 -13
  39. pixeltable/io/label_studio.py +6 -6
  40. pixeltable/io/pandas.py +27 -12
  41. pixeltable/io/parquet.py +14 -14
  42. pixeltable/iterators/document.py +7 -7
  43. pixeltable/plan.py +58 -29
  44. pixeltable/store.py +32 -31
  45. pixeltable/tool/create_test_db_dump.py +12 -6
  46. pixeltable/type_system.py +89 -97
  47. pixeltable/utils/pytorch.py +12 -10
  48. {pixeltable-0.2.15.dist-info → pixeltable-0.2.16.dist-info}/METADATA +10 -10
  49. {pixeltable-0.2.15.dist-info → pixeltable-0.2.16.dist-info}/RECORD +52 -51
  50. {pixeltable-0.2.15.dist-info → pixeltable-0.2.16.dist-info}/LICENSE +0 -0
  51. {pixeltable-0.2.15.dist-info → pixeltable-0.2.16.dist-info}/WHEEL +0 -0
  52. {pixeltable-0.2.15.dist-info → pixeltable-0.2.16.dist-info}/entry_points.txt +0 -0
pixeltable/exprs/expr.py CHANGED
@@ -7,7 +7,7 @@ import inspect
7
7
  import json
8
8
  import sys
9
9
  import typing
10
- from typing import Union, Optional, List, Callable, Any, Dict, Tuple, Set, Generator, Type
10
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union, overload
11
11
  from uuid import UUID
12
12
 
13
13
  import sqlalchemy as sql
@@ -17,8 +17,9 @@ import pixeltable.catalog as catalog
17
17
  import pixeltable.exceptions as excs
18
18
  import pixeltable.func as func
19
19
  import pixeltable.type_system as ts
20
+
20
21
  from .data_row import DataRow
21
- from .globals import ComparisonOperator, LogicalOperator, LiteralPythonTypes, ArithmeticOperator
22
+ from .globals import ArithmeticOperator, ComparisonOperator, LiteralPythonTypes, LogicalOperator
22
23
 
23
24
 
24
25
  class ExprScope:
@@ -191,8 +192,8 @@ class Expr(abc.ABC):
191
192
  Recursively replace ColRefs to unstored computed columns with their value exprs.
192
193
  Also replaces references to stored computed columns in resolve_cols.
193
194
  """
194
- from .expr_set import ExprSet
195
195
  from .column_ref import ColumnRef
196
+ from .expr_set import ExprSet
196
197
  if resolve_cols is None:
197
198
  resolve_cols = set()
198
199
  result = self
@@ -240,28 +241,64 @@ class Expr(abc.ABC):
240
241
  return str(self)
241
242
 
242
243
  @classmethod
243
- def print_list(cls, expr_list: List[Expr]) -> str:
244
+ def print_list(cls, expr_list: list[Any]) -> str:
244
245
  if len(expr_list) == 1:
245
246
  return str(expr_list[0])
246
- return f'({", ".join([str(e) for e in expr_list])})'
247
+ return f'({", ".join(str(e) for e in expr_list)})'
248
+
249
+ # `subexprs` has two forms: one that takes an explicit subclass of `Expr` as an argument and returns only
250
+ # instances of that subclass; and another that returns all subexpressions that match the given filter.
251
+ # In order for type checking to behave correctly on both forms, we provide two overloaded signatures.
252
+
253
+ T = TypeVar('T', bound='Expr')
254
+
255
+ @overload
256
+ def subexprs(
257
+ self, *, filter: Optional[Callable[[Expr], bool]] = None, traverse_matches: bool = True
258
+ ) -> Iterator[Expr]: ...
247
259
 
260
+ @overload
248
261
  def subexprs(
249
- self, expr_class: Optional[Type[Expr]] = None, filter: Optional[Callable[[Expr], bool]] = None,
250
- traverse_matches: bool = True
251
- ) -> Generator[Expr, None, None]:
262
+ self, expr_class: type[T], filter: Optional[Callable[[Expr], bool]] = None,
263
+ traverse_matches: bool = True
264
+ ) -> Iterator[T]: ...
265
+
266
+ def subexprs(
267
+ self, expr_class: Optional[type[T]] = None, filter: Optional[Callable[[Expr], bool]] = None,
268
+ traverse_matches: bool = True
269
+ ) -> Iterator[T]:
252
270
  """
253
271
  Iterate over all subexprs, including self.
254
272
  """
255
- assert expr_class is None or filter is None # at most one of them
256
- if expr_class is not None:
257
- filter = lambda e: isinstance(e, expr_class)
258
273
  is_match = filter is None or filter(self)
274
+ if expr_class is not None:
275
+ is_match = is_match and isinstance(self, expr_class)
259
276
  if not is_match or traverse_matches:
260
277
  for c in self.components:
261
- yield from c.subexprs(filter=filter, traverse_matches=traverse_matches)
278
+ yield from c.subexprs(expr_class=expr_class, filter=filter, traverse_matches=traverse_matches)
262
279
  if is_match:
263
280
  yield self
264
281
 
282
+ @overload
283
+ def list_subexprs(
284
+ expr_list: list[Expr], *, filter: Optional[Callable[[Expr], bool]] = None, traverse_matches: bool = True
285
+ ) -> Iterator[Expr]: ...
286
+
287
+ @overload
288
+ def list_subexprs(
289
+ expr_list: list[Expr], expr_class: type[T], filter: Optional[Callable[[Expr], bool]] = None,
290
+ traverse_matches: bool = True
291
+ ) -> Iterator[T]: ...
292
+
293
+ @classmethod
294
+ def list_subexprs(
295
+ cls, expr_list: list[Expr], expr_class: Optional[type[T]] = None,
296
+ filter: Optional[Callable[[Expr], bool]] = None, traverse_matches: bool = True
297
+ ) -> Iterator[T]:
298
+ """Produce subexprs for all exprs in list. Can contain duplicates."""
299
+ for e in expr_list:
300
+ yield from e.subexprs(expr_class=expr_class, filter=filter, traverse_matches=traverse_matches)
301
+
265
302
  def _contains(self, cls: Optional[Type[Expr]] = None, filter: Optional[Callable[[Expr], bool]] = None) -> bool:
266
303
  """
267
304
  Returns True if any subexpr is an instance of cls.
@@ -275,15 +312,6 @@ class Expr(abc.ABC):
275
312
  except StopIteration:
276
313
  return False
277
314
 
278
- @classmethod
279
- def list_subexprs(
280
- cls, expr_list: List[Expr], expr_class: Optional[Type[Expr]] = None,
281
- filter: Optional[Callable[[Expr], bool]] = None, traverse_matches: bool = True
282
- ) -> Generator[Expr, None, None]:
283
- """Produce subexprs for all exprs in list. Can contain duplicates."""
284
- for e in expr_list:
285
- yield from e.subexprs(expr_class=expr_class, filter=filter, traverse_matches=traverse_matches)
286
-
287
315
  def tbl_ids(self) -> Set[UUID]:
288
316
  """Returns table ids referenced by this expr."""
289
317
  from .column_ref import ColumnRef
@@ -334,10 +362,10 @@ class Expr(abc.ABC):
334
362
  return None
335
363
 
336
364
  @abc.abstractmethod
337
- def sql_expr(self) -> Optional[sql.ClauseElement]:
365
+ def sql_expr(self) -> Optional[sql.ColumnElement]:
338
366
  """
339
367
  If this expr can be materialized directly in SQL:
340
- - returns a ClauseElement
368
+ - returns a ColumnElement
341
369
  - eval() will not be called (exception: Literal)
342
370
  Otherwise
343
371
  - returns None
@@ -204,7 +204,14 @@ class FunctionCall(Expr):
204
204
  # Check that the argument is consistent with the expected parameter type, with the allowance that
205
205
  # non-nullable parameters can still accept nullable arguments (since function calls with Nones
206
206
  # assigned to non-nullable parameters will always return None)
207
- if not param.col_type.is_supertype_of(arg.col_type, ignore_nullable=True):
207
+ if not (
208
+ param.col_type.is_supertype_of(arg.col_type, ignore_nullable=True)
209
+ # TODO: this is a hack to allow JSON columns to be passed to functions that accept scalar
210
+ # types. It's necessary to avoid littering notebooks with `apply(str)` calls or equivalent.
211
+ # (Previously, this wasn't necessary because `is_supertype_of()` was improperly implemented.)
212
+ # We need to think through the right way to handle this scenario.
213
+ or (arg.col_type.is_json_type() and param.col_type.is_scalar_type())
214
+ ):
208
215
  raise excs.Error(
209
216
  f'Parameter {param_name}: argument type {arg.col_type} does not match parameter type '
210
217
  f'{param.col_type}')
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import copy
4
- from typing import Optional, List, Any, Dict, Tuple
4
+ from typing import Iterable, Optional, List, Any, Dict, Tuple
5
5
 
6
6
  import numpy as np
7
7
  import sqlalchemy as sql
@@ -24,7 +24,7 @@ class InlineArray(Expr):
24
24
 
25
25
  elements: List[Tuple[Optional[int], Any]]
26
26
 
27
- def __init__(self, elements: Tuple, force_json: bool = False):
27
+ def __init__(self, elements: Iterable, force_json: bool = False):
28
28
  # we need to call this in order to populate self.components
29
29
  super().__init__(ts.ArrayType((len(elements),), ts.IntType()))
30
30
 
@@ -1,22 +1,29 @@
1
1
  from __future__ import annotations
2
- from typing import Optional, List, Any, Dict, Tuple, Union
2
+
3
+ from typing import Any, Optional, Union
3
4
 
4
5
  import jmespath
5
6
  import sqlalchemy as sql
6
7
 
7
- from .globals import print_slice
8
- from .expr import Expr
9
- from .json_mapper import JsonMapper
10
- from .data_row import DataRow
11
- from .row_builder import RowBuilder
12
8
  import pixeltable
13
- import pixeltable.exceptions as excs
14
9
  import pixeltable.catalog as catalog
10
+ import pixeltable.exceptions as excs
15
11
  import pixeltable.type_system as ts
16
12
 
13
+ from .data_row import DataRow
14
+ from .expr import Expr
15
+ from .globals import print_slice
16
+ from .json_mapper import JsonMapper
17
+ from .row_builder import RowBuilder
18
+
17
19
 
18
20
  class JsonPath(Expr):
19
- def __init__(self, anchor: Optional['pixeltable.exprs.ColumnRef'], path_elements: Optional[List[str]] = None, scope_idx: int = 0):
21
+ def __init__(
22
+ self,
23
+ anchor: Optional['pixeltable.exprs.ColumnRef'],
24
+ path_elements: Optional[list[Union[str, int, slice]]] = None,
25
+ scope_idx: int = 0
26
+ ) -> None:
20
27
  """
21
28
  anchor can be None, in which case this is a relative JsonPath and the anchor is set later via set_anchor().
22
29
  scope_idx: for relative paths, index of referenced JsonMapper
@@ -27,7 +34,7 @@ class JsonPath(Expr):
27
34
  super().__init__(ts.JsonType())
28
35
  if anchor is not None:
29
36
  self.components = [anchor]
30
- self.path_elements: List[Union[str, int]] = path_elements
37
+ self.path_elements: list[Union[str, int, slice]] = path_elements
31
38
  self.compiled_path = jmespath.compile(self._json_path()) if len(path_elements) > 0 else None
32
39
  self.scope_idx = scope_idx
33
40
  # NOTE: the _create_id() result will change if set_anchor() gets called;
@@ -39,16 +46,26 @@ class JsonPath(Expr):
39
46
  return (f'{str(self._anchor) if self._anchor is not None else "R"}'
40
47
  f'{"." if isinstance(self.path_elements[0], str) else ""}{self._json_path()}')
41
48
 
42
- def _as_dict(self) -> Dict:
43
- return {'path_elements': self.path_elements, 'scope_idx': self.scope_idx, **super()._as_dict()}
49
+ def _as_dict(self) -> dict:
50
+ path_elements = [
51
+ [el.start, el.stop, el.step] if isinstance(el, slice)
52
+ else el
53
+ for el in self.path_elements
54
+ ]
55
+ return {'path_elements': path_elements, 'scope_idx': self.scope_idx, **super()._as_dict()}
44
56
 
45
57
  @classmethod
46
- def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
58
+ def _from_dict(cls, d: dict, components: list[Expr]) -> Expr:
47
59
  assert 'path_elements' in d
48
60
  assert 'scope_idx' in d
49
61
  assert len(components) <= 1
50
62
  anchor = components[0] if len(components) == 1 else None
51
- return cls(anchor, d['path_elements'], d['scope_idx'])
63
+ path_elements = [
64
+ slice(el[0], el[1], el[2]) if isinstance(el, list)
65
+ else el
66
+ for el in d['path_elements']
67
+ ]
68
+ return cls(anchor, path_elements, d['scope_idx'])
52
69
 
53
70
  @property
54
71
  def _anchor(self) -> Optional[Expr]:
@@ -85,8 +102,7 @@ class JsonPath(Expr):
85
102
  if isinstance(index, str):
86
103
  if index != '*':
87
104
  raise excs.Error(f'Invalid json list index: {index}')
88
- else:
89
- if not isinstance(index, slice) and not isinstance(index, int):
105
+ elif not isinstance(index, (int, slice)):
90
106
  raise excs.Error(f'Invalid json list index: {index}')
91
107
  return JsonPath(self._anchor, self.path_elements + [index])
92
108
 
@@ -99,7 +115,7 @@ class JsonPath(Expr):
99
115
  def default_column_name(self) -> Optional[str]:
100
116
  anchor_name = self._anchor.default_column_name() if self._anchor is not None else ''
101
117
  ret_name = f'{anchor_name}.{self._json_path()}'
102
-
118
+
103
119
  def cleanup_char(s : str) -> str:
104
120
  if s == '.':
105
121
  return '_'
@@ -109,19 +125,19 @@ class JsonPath(Expr):
109
125
  return s
110
126
  else:
111
127
  return ''
112
-
128
+
113
129
  clean_name = ''.join(map(cleanup_char, ret_name))
114
130
  clean_name = clean_name.lstrip('_') # remove leading underscore
115
131
  if clean_name == '':
116
132
  clean_name = None
117
-
133
+
118
134
  assert clean_name is None or catalog.is_valid_identifier(clean_name)
119
135
  return clean_name
120
136
 
121
137
  def _equals(self, other: JsonPath) -> bool:
122
138
  return self.path_elements == other.path_elements
123
139
 
124
- def _id_attrs(self) -> List[Tuple[str, Any]]:
140
+ def _id_attrs(self) -> list[tuple[str, Any]]:
125
141
  return super()._id_attrs() + [('path_elements', self.path_elements)]
126
142
 
127
143
  def sql_expr(self) -> Optional[sql.ClauseElement]:
@@ -137,7 +153,7 @@ class JsonPath(Expr):
137
153
 
138
154
  def _json_path(self) -> str:
139
155
  assert len(self.path_elements) > 0
140
- result: List[str] = []
156
+ result: list[str] = []
141
157
  for element in self.path_elements:
142
158
  if element == '*':
143
159
  result.append('[*]')
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import sys
4
4
  import time
5
5
  from dataclasses import dataclass
6
- from typing import Optional, List, Any, Dict, Tuple, Set
6
+ from typing import Optional, List, Any, Dict, Sequence, Tuple, Set
7
7
 
8
8
  import sqlalchemy as sql
9
9
 
@@ -57,7 +57,7 @@ class RowBuilder:
57
57
  target_exprs: List[Expr] # exprs corresponding to target_slot_idxs
58
58
 
59
59
  def __init__(
60
- self, output_exprs: List[Expr], columns: List[catalog.Column], input_exprs: List[Expr]
60
+ self, output_exprs: Sequence[Expr], columns: Sequence[catalog.Column], input_exprs: Sequence[Expr]
61
61
  ):
62
62
  """
63
63
  Args:
@@ -227,10 +227,10 @@ class RowBuilder:
227
227
  # merge dependencies and convert to list
228
228
  return sorted(set().union(*[dependencies[i] for i in target_slot_idxs]))
229
229
 
230
- def substitute_exprs(self, expr_list: List[Expr], remove_duplicates: bool = True) -> None:
230
+ def substitute_exprs(self, expr_list: list, remove_duplicates: bool = True) -> None:
231
231
  """Substitutes exprs with their executable counterparts from unique_exprs and optionally removes duplicates"""
232
232
  i = 0
233
- unique_ids: Set[i] = set() # slot idxs within expr_list
233
+ unique_ids: set[int] = set() # slot idxs within expr_list
234
234
  while i < len(expr_list):
235
235
  unique_expr = self.unique_exprs[expr_list[i]]
236
236
  if unique_expr.slot_idx in unique_ids and remove_duplicates:
@@ -56,7 +56,7 @@ class RowidRef(Expr):
56
56
  # check if this is the pos column of a component view
57
57
  tbl = self.tbl if self.tbl is not None else catalog.Catalog.get().tbl_versions[(self.tbl_id, None)]
58
58
  if tbl.is_component_view() and self.rowid_component_idx == tbl.store_tbl.pos_col_idx:
59
- return catalog.globals.POS_COLUMN_NAME
59
+ return catalog.globals._POS_COLUMN_NAME
60
60
  return ''
61
61
 
62
62
  def set_tbl(self, tbl: catalog.TableVersionPath) -> None:
@@ -1,8 +1,7 @@
1
- from . import fireworks, huggingface, image, openai, string, together, video, timestamp, json, vision
1
+ from . import audio, fireworks, huggingface, image, json, openai, string, timestamp, together, video, vision
2
2
  from .globals import *
3
3
  from pixeltable.utils.code import local_public_names
4
4
 
5
-
6
5
  __all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)
7
6
 
8
7
 
@@ -0,0 +1,32 @@
1
+ """
2
+ Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs) for `AudioType`.
3
+
4
+ Example:
5
+ ```python
6
+ import pixeltable as pxt
7
+ import pixeltable.functions as pxtf
8
+
9
+ t = pxt.get_table(...)
10
+ t.select(pxtf.audio.get_metadata()).collect()
11
+ ```
12
+ """
13
+
14
+ import pixeltable.func as func
15
+ import pixeltable.type_system as ts
16
+ from pixeltable.utils.code import local_public_names
17
+
18
+
19
+ @func.udf(return_type=ts.JsonType(nullable=False), param_types=[ts.AudioType(nullable=False)], is_method=True)
20
+ def get_metadata(audio: str) -> dict:
21
+ """
22
+ Gets various metadata associated with an audio file and returns it as a dictionary.
23
+ """
24
+ import pixeltable.functions as pxtf
25
+ return pxtf.video._get_metadata(audio)
26
+
27
+
28
+ __all__ = local_public_names(__name__)
29
+
30
+
31
+ def __dir__():
32
+ return __all__
@@ -50,7 +50,7 @@ def sentence_transformer(
50
50
  >>> tbl['result'] = sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2')
51
51
  """
52
52
  env.Env.get().require_package('sentence_transformers')
53
- from sentence_transformers import SentenceTransformer
53
+ from sentence_transformers import SentenceTransformer # type: ignore
54
54
 
55
55
  model = _lookup_model(model_id, SentenceTransformer)
56
56
 
@@ -154,7 +154,7 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
154
154
  env.Env.get().require_package('transformers')
155
155
  device = resolve_torch_device('auto')
156
156
  import torch
157
- from transformers import CLIPModel, CLIPProcessor
157
+ from transformers import CLIPModel, CLIPProcessor # type: ignore
158
158
 
159
159
  model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
160
160
  processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
@@ -317,9 +317,9 @@ def _lookup_model(model_id: str, create: Callable[[str], T], device: Optional[st
317
317
  key = (model_id, create, device) # For safety, include the `create` callable in the cache key
318
318
  if key not in _model_cache:
319
319
  model = create(model_id)
320
- if device is not None:
321
- model.to(device)
322
320
  if isinstance(model, nn.Module):
321
+ if device is not None:
322
+ model.to(device)
323
323
  model.eval()
324
324
  _model_cache[key] = model
325
325
  return _model_cache[key]
@@ -144,7 +144,7 @@ def resize(self: PIL.Image.Image, size: tuple[int, int]) -> PIL.Image.Image:
144
144
  Equivalent to
145
145
  [`PIL.Image.Image.resize()`](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.resize)
146
146
  """
147
- return self.resize(size)
147
+ return self.resize(tuple(size))
148
148
 
149
149
 
150
150
  @resize.conditional_return_type
@@ -133,7 +133,11 @@ def get_metadata(video: str) -> dict:
133
133
  """
134
134
  Gets various metadata associated with a video file and returns it as a dictionary.
135
135
  """
136
- with av.open(video) as container:
136
+ return _get_metadata(video)
137
+
138
+
139
+ def _get_metadata(path: str) -> dict:
140
+ with av.open(path) as container:
137
141
  assert isinstance(container, av.container.InputContainer)
138
142
  streams_info = [__get_stream_metadata(stream) for stream in container.streams]
139
143
  result = {
@@ -274,7 +274,6 @@ def draw_bounding_boxes(
274
274
  boxes: list[list[int]],
275
275
  labels: Optional[list[Any]] = None,
276
276
  color: Optional[str] = None,
277
- label_colors: Optional[dict[Union[str, int], str]] = None,
278
277
  box_colors: Optional[list[str]] = None,
279
278
  fill: bool = False,
280
279
  width: int = 1,
@@ -297,7 +296,6 @@ def draw_bounding_boxes(
297
296
  boxes: List of bounding boxes, each represented as [xmin, ymin, xmax, ymax].
298
297
  labels: List of labels for each bounding box.
299
298
  color: Single color to be used for all bounding boxes and labels.
300
- label_colors: Dictionary mapping labels to colors.
301
299
  box_colors: List of colors, one per bounding box.
302
300
  fill: Whether to fill the bounding boxes with color.
303
301
  width: Width of the bounding box borders.
@@ -310,9 +308,9 @@ def draw_bounding_boxes(
310
308
  Returns:
311
309
  The image with bounding boxes drawn on it.
312
310
  """
313
- color_params = sum([color is not None, label_colors is not None, box_colors is not None])
311
+ color_params = sum([color is not None, box_colors is not None])
314
312
  if color_params > 1:
315
- raise ValueError("Only one of 'color', 'label_colors', or 'box_colors' can be set")
313
+ raise ValueError("Only one of 'color' or 'box_colors' can be set")
316
314
 
317
315
  # ensure the number of labels matches the number of boxes
318
316
  num_boxes = len(boxes)
@@ -328,8 +326,6 @@ def draw_bounding_boxes(
328
326
  else:
329
327
  if color is not None:
330
328
  box_colors = [color] * num_boxes
331
- elif label_colors is not None:
332
- box_colors = [label_colors.get(label, DEFAULT_COLOR) for label in labels]
333
329
  else:
334
330
  label_colors = _create_label_colors(labels)
335
331
  box_colors = [label_colors[label] for label in labels]