keras-hub-nightly 0.19.0.dev202502060348__py3-none-any.whl → 0.19.0.dev202502080344__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. keras_hub/api/__init__.py +0 -1
  2. keras_hub/api/layers/__init__.py +3 -1
  3. keras_hub/api/models/__init__.py +10 -4
  4. keras_hub/src/{models/retinanet → layers/modeling}/anchor_generator.py +11 -18
  5. keras_hub/src/{models/retinanet → layers/modeling}/box_matcher.py +17 -4
  6. keras_hub/src/{models/retinanet → layers/modeling}/non_max_supression.py +84 -32
  7. keras_hub/src/layers/preprocessing/image_converter.py +25 -3
  8. keras_hub/src/models/{image_object_detector.py → object_detector.py} +12 -7
  9. keras_hub/src/models/{image_object_detector_preprocessor.py → object_detector_preprocessor.py} +29 -13
  10. keras_hub/src/models/retinanet/retinanet_image_converter.py +8 -40
  11. keras_hub/src/models/retinanet/retinanet_label_encoder.py +18 -16
  12. keras_hub/src/models/retinanet/retinanet_object_detector.py +28 -28
  13. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +3 -3
  14. keras_hub/src/utils/tensor_utils.py +13 -0
  15. keras_hub/src/version_utils.py +1 -1
  16. {keras_hub_nightly-0.19.0.dev202502060348.dist-info → keras_hub_nightly-0.19.0.dev202502080344.dist-info}/METADATA +1 -1
  17. {keras_hub_nightly-0.19.0.dev202502060348.dist-info → keras_hub_nightly-0.19.0.dev202502080344.dist-info}/RECORD +19 -28
  18. keras_hub/api/bounding_box/__init__.py +0 -23
  19. keras_hub/src/bounding_box/__init__.py +0 -2
  20. keras_hub/src/bounding_box/converters.py +0 -606
  21. keras_hub/src/bounding_box/formats.py +0 -149
  22. keras_hub/src/bounding_box/iou.py +0 -251
  23. keras_hub/src/bounding_box/to_dense.py +0 -81
  24. keras_hub/src/bounding_box/to_ragged.py +0 -86
  25. keras_hub/src/bounding_box/utils.py +0 -181
  26. keras_hub/src/bounding_box/validate_format.py +0 -85
  27. {keras_hub_nightly-0.19.0.dev202502060348.dist-info → keras_hub_nightly-0.19.0.dev202502080344.dist-info}/WHEEL +0 -0
  28. {keras_hub_nightly-0.19.0.dev202502060348.dist-info → keras_hub_nightly-0.19.0.dev202502080344.dist-info}/top_level.txt +0 -0
@@ -1,606 +0,0 @@
1
- """Converter functions for working with bounding box formats."""
2
-
3
- import keras
4
- from keras import ops
5
-
6
- from keras_hub.src.api_export import keras_hub_export
7
-
8
- try:
9
- import tensorflow as tf
10
- except ImportError:
11
- tf = None
12
-
13
-
14
- # Internal exception to propagate the fact images was not passed to a converter
15
- # that needs it.
16
- class RequiresImagesException(Exception):
17
- pass
18
-
19
-
20
- ALL_AXES = 4
21
-
22
-
23
- def encode_box_to_deltas(
24
- anchors,
25
- boxes,
26
- anchor_format,
27
- box_format,
28
- encoding_format="center_yxhw",
29
- variance=None,
30
- image_shape=None,
31
- ):
32
- """Encodes bounding boxes relative to anchors as deltas.
33
-
34
- This function calculates the deltas that represent the difference between
35
- bounding boxes and provided anchors. Deltas encode the offsets and scaling
36
- factors to apply to anchors to obtain the target boxes.
37
-
38
- Boxes and anchors are first converted to the specified `encoding_format`
39
- (defaulting to `center_yxhw`) for consistent delta representation.
40
-
41
- Args:
42
- anchors: `Tensors`. Anchor boxes with shape of `(N, 4)` where N is the
43
- number of anchors.
44
- boxes: `Tensors` Bounding boxes to encode. Boxes can be of shape
45
- `(B, N, 4)` or `(N, 4)`.
46
- anchor_format: str. The format of the input `anchors`
47
- (e.g., "xyxy", "xywh", etc.).
48
- box_format: str. The format of the input `boxes`
49
- (e.g., "xyxy", "xywh", etc.).
50
- encoding_format: str. The intermediate format to which boxes and anchors
51
- are converted before delta calculation. Defaults to "center_yxhw".
52
- variance: `List[float]`. A 4-element array/tensor representing variance
53
- factors to scale the box deltas. If provided, the calculated deltas
54
- are divided by the variance. Defaults to None.
55
- image_shape: `Tuple[int]`. The shape of the image (height, width, 3).
56
- When using relative bounding box format for `box_format` the
57
- `image_shape` is used for normalization.
58
- Returns:
59
- Encoded box deltas. The return type matches the `encode_format`.
60
-
61
- Raises:
62
- ValueError: If `variance` is not None and its length is not 4.
63
- ValueError: If `encoding_format` is not `"center_xywh"` or
64
- `"center_yxhw"`.
65
-
66
- """
67
- if variance is not None:
68
- variance = ops.convert_to_tensor(variance, "float32")
69
- var_len = variance.shape[-1]
70
-
71
- if var_len != 4:
72
- raise ValueError(f"`variance` must be length 4, got {variance}")
73
-
74
- if encoding_format not in ["center_xywh", "center_yxhw"]:
75
- raise ValueError(
76
- "`encoding_format` should be one of 'center_xywh' or "
77
- f"'center_yxhw', got {encoding_format}"
78
- )
79
-
80
- encoded_anchors = convert_format(
81
- anchors,
82
- source=anchor_format,
83
- target=encoding_format,
84
- image_shape=image_shape,
85
- )
86
- boxes = convert_format(
87
- boxes,
88
- source=box_format,
89
- target=encoding_format,
90
- image_shape=image_shape,
91
- )
92
- anchor_dimensions = ops.maximum(
93
- encoded_anchors[..., 2:], keras.backend.epsilon()
94
- )
95
- box_dimensions = ops.maximum(boxes[..., 2:], keras.backend.epsilon())
96
- # anchors be unbatched, boxes can either be batched or unbatched.
97
- boxes_delta = ops.concatenate(
98
- [
99
- (boxes[..., :2] - encoded_anchors[..., :2]) / anchor_dimensions,
100
- ops.log(box_dimensions / anchor_dimensions),
101
- ],
102
- axis=-1,
103
- )
104
- if variance is not None:
105
- boxes_delta /= variance
106
- return boxes_delta
107
-
108
-
109
- def decode_deltas_to_boxes(
110
- anchors,
111
- boxes_delta,
112
- anchor_format,
113
- box_format,
114
- encoded_format="center_yxhw",
115
- variance=None,
116
- image_shape=None,
117
- ):
118
- """Converts bounding boxes from delta format to the specified `box_format`.
119
-
120
- This function decodes bounding box deltas relative to anchors to obtain the
121
- final bounding box coordinates. The boxes are encoded in a specific
122
- `encoded_format` (center_yxhw by default) during the decoding process.
123
- This allows flexibility in how the deltas are applied to the anchors.
124
-
125
- Args:
126
- anchors: Can be `Tensors` or `Dict[Tensors]` where keys are level
127
- indices and values are corresponding anchor boxes.
128
- The shape of the array/tensor should be `(N, 4)` where N is the
129
- number of anchors.
130
- boxes_delta Can be `Tensors` or `Dict[Tensors]` Bounding box deltas
131
- must have the same type and structure as `anchors`. The
132
- shape of the array/tensor can be `(N, 4)` or `(B, N, 4)` where N is
133
- the number of boxes.
134
- anchor_format: str. The format of the input `anchors`.
135
- (e.g., `"xyxy"`, `"xywh"`, etc.)
136
- box_format: str. The desired format for the output boxes.
137
- (e.g., `"xyxy"`, `"xywh"`, etc.)
138
- encoded_format: str. Raw output format from regression head. Defaults
139
- to `"center_yxhw"`.
140
- variance: `List[floats]`. A 4-element array/tensor representing
141
- variance factors to scale the box deltas. If provided, the deltas
142
- are multiplied by the variance before being applied to the anchors.
143
- Defaults to None.
144
- image_shape: The shape of the image (height, width). This is needed
145
- if normalization to image size is required when converting between
146
- formats. Defaults to None.
147
-
148
- Returns:
149
- Decoded box coordinates. The return type matches the `box_format`.
150
-
151
- Raises:
152
- ValueError: If `variance` is not None and its length is not 4.
153
- ValueError: If `encoded_format` is not `"center_xywh"` or
154
- `"center_yxhw"`.
155
-
156
- """
157
- if variance is not None:
158
- variance = ops.convert_to_tensor(variance, "float32")
159
- var_len = variance.shape[-1]
160
-
161
- if var_len != 4:
162
- raise ValueError(f"`variance` must be length 4, got {variance}")
163
-
164
- if encoded_format not in ["center_xywh", "center_yxhw"]:
165
- raise ValueError(
166
- f"`encoded_format` should be 'center_xywh' or 'center_yxhw', "
167
- f"but got '{encoded_format}'."
168
- )
169
-
170
- def decode_single_level(anchor, box_delta):
171
- encoded_anchor = convert_format(
172
- anchor,
173
- source=anchor_format,
174
- target=encoded_format,
175
- image_shape=image_shape,
176
- )
177
- if variance is not None:
178
- box_delta = box_delta * variance
179
- # anchors be unbatched, boxes can either be batched or unbatched.
180
- box = ops.concatenate(
181
- [
182
- box_delta[..., :2] * encoded_anchor[..., 2:]
183
- + encoded_anchor[..., :2],
184
- ops.exp(box_delta[..., 2:]) * encoded_anchor[..., 2:],
185
- ],
186
- axis=-1,
187
- )
188
- box = convert_format(
189
- box,
190
- source=encoded_format,
191
- target=box_format,
192
- image_shape=image_shape,
193
- )
194
- return box
195
-
196
- if isinstance(anchors, dict) and isinstance(boxes_delta, dict):
197
- boxes = {}
198
- for lvl, anchor in anchors.items():
199
- boxes[lvl] = decode_single_level(anchor, boxes_delta[lvl])
200
- return boxes
201
- else:
202
- return decode_single_level(anchors, boxes_delta)
203
-
204
-
205
- def _center_yxhw_to_xyxy(boxes, images=None, image_shape=None):
206
- y, x, height, width = ops.split(boxes, ALL_AXES, axis=-1)
207
- return ops.concatenate(
208
- [x - width / 2.0, y - height / 2.0, x + width / 2.0, y + height / 2.0],
209
- axis=-1,
210
- )
211
-
212
-
213
- def _center_xywh_to_xyxy(boxes, images=None, image_shape=None):
214
- x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1)
215
- return ops.concatenate(
216
- [x - width / 2.0, y - height / 2.0, x + width / 2.0, y + height / 2.0],
217
- axis=-1,
218
- )
219
-
220
-
221
- def _xywh_to_xyxy(boxes, images=None, image_shape=None):
222
- x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1)
223
- return ops.concatenate([x, y, x + width, y + height], axis=-1)
224
-
225
-
226
- def _xyxy_to_center_yxhw(boxes, images=None, image_shape=None):
227
- left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1)
228
- return ops.concatenate(
229
- [
230
- (top + bottom) / 2.0,
231
- (left + right) / 2.0,
232
- bottom - top,
233
- right - left,
234
- ],
235
- axis=-1,
236
- )
237
-
238
-
239
- def _rel_xywh_to_xyxy(boxes, images=None, image_shape=None):
240
- image_height, image_width = _image_shape(images, image_shape, boxes)
241
- x, y, width, height = ops.split(boxes, ALL_AXES, axis=-1)
242
- return ops.concatenate(
243
- [
244
- image_width * x,
245
- image_height * y,
246
- image_width * (x + width),
247
- image_height * (y + height),
248
- ],
249
- axis=-1,
250
- )
251
-
252
-
253
- def _xyxy_no_op(boxes, images=None, image_shape=None):
254
- return boxes
255
-
256
-
257
- def _xyxy_to_xywh(boxes, images=None, image_shape=None):
258
- left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1)
259
- return ops.concatenate(
260
- [left, top, right - left, bottom - top],
261
- axis=-1,
262
- )
263
-
264
-
265
- def _xyxy_to_rel_xywh(boxes, images=None, image_shape=None):
266
- image_height, image_width = _image_shape(images, image_shape, boxes)
267
- left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1)
268
- left, right = (
269
- left / image_width,
270
- right / image_width,
271
- )
272
- top, bottom = top / image_height, bottom / image_height
273
- return ops.concatenate(
274
- [left, top, right - left, bottom - top],
275
- axis=-1,
276
- )
277
-
278
-
279
- def _xyxy_to_center_xywh(boxes, images=None, image_shape=None):
280
- left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1)
281
- return ops.concatenate(
282
- [
283
- (left + right) / 2.0,
284
- (top + bottom) / 2.0,
285
- right - left,
286
- bottom - top,
287
- ],
288
- axis=-1,
289
- )
290
-
291
-
292
- def _rel_xyxy_to_xyxy(boxes, images=None, image_shape=None):
293
- image_height, image_width = _image_shape(images, image_shape, boxes)
294
- left, top, right, bottom = ops.split(
295
- boxes,
296
- ALL_AXES,
297
- axis=-1,
298
- )
299
- left, right = left * image_width, right * image_width
300
- top, bottom = top * image_height, bottom * image_height
301
- return ops.concatenate(
302
- [left, top, right, bottom],
303
- axis=-1,
304
- )
305
-
306
-
307
- def _xyxy_to_rel_xyxy(boxes, images=None, image_shape=None):
308
- image_height, image_width = _image_shape(images, image_shape, boxes)
309
- left, top, right, bottom = ops.split(
310
- boxes,
311
- ALL_AXES,
312
- axis=-1,
313
- )
314
- left, right = left / image_width, right / image_width
315
- top, bottom = top / image_height, bottom / image_height
316
- return ops.concatenate(
317
- [left, top, right, bottom],
318
- axis=-1,
319
- )
320
-
321
-
322
- def _yxyx_to_xyxy(boxes, images=None, image_shape=None):
323
- y1, x1, y2, x2 = ops.split(boxes, ALL_AXES, axis=-1)
324
- return ops.concatenate([x1, y1, x2, y2], axis=-1)
325
-
326
-
327
- def _rel_yxyx_to_xyxy(boxes, images=None, image_shape=None):
328
- image_height, image_width = _image_shape(images, image_shape, boxes)
329
- top, left, bottom, right = ops.split(
330
- boxes,
331
- ALL_AXES,
332
- axis=-1,
333
- )
334
- left, right = left * image_width, right * image_width
335
- top, bottom = top * image_height, bottom * image_height
336
- return ops.concatenate(
337
- [left, top, right, bottom],
338
- axis=-1,
339
- )
340
-
341
-
342
- def _xyxy_to_yxyx(boxes, images=None, image_shape=None):
343
- x1, y1, x2, y2 = ops.split(boxes, ALL_AXES, axis=-1)
344
- return ops.concatenate([y1, x1, y2, x2], axis=-1)
345
-
346
-
347
- def _xyxy_to_rel_yxyx(boxes, images=None, image_shape=None):
348
- image_height, image_width = _image_shape(images, image_shape, boxes)
349
- left, top, right, bottom = ops.split(boxes, ALL_AXES, axis=-1)
350
- left, right = left / image_width, right / image_width
351
- top, bottom = top / image_height, bottom / image_height
352
- return ops.concatenate(
353
- [top, left, bottom, right],
354
- axis=-1,
355
- )
356
-
357
-
358
- TO_XYXY_CONVERTERS = {
359
- "xywh": _xywh_to_xyxy,
360
- "center_xywh": _center_xywh_to_xyxy,
361
- "center_yxhw": _center_yxhw_to_xyxy,
362
- "rel_xywh": _rel_xywh_to_xyxy,
363
- "xyxy": _xyxy_no_op,
364
- "rel_xyxy": _rel_xyxy_to_xyxy,
365
- "yxyx": _yxyx_to_xyxy,
366
- "rel_yxyx": _rel_yxyx_to_xyxy,
367
- }
368
-
369
- FROM_XYXY_CONVERTERS = {
370
- "xywh": _xyxy_to_xywh,
371
- "center_xywh": _xyxy_to_center_xywh,
372
- "center_yxhw": _xyxy_to_center_yxhw,
373
- "rel_xywh": _xyxy_to_rel_xywh,
374
- "xyxy": _xyxy_no_op,
375
- "rel_xyxy": _xyxy_to_rel_xyxy,
376
- "yxyx": _xyxy_to_yxyx,
377
- "rel_yxyx": _xyxy_to_rel_yxyx,
378
- }
379
-
380
-
381
- @keras_hub_export("keras_hub.bounding_box.convert_format")
382
- def convert_format(
383
- boxes, source, target, images=None, image_shape=None, dtype="float32"
384
- ):
385
- f"""Converts bounding_boxes from one format to another.
386
-
387
- Supported formats are:
388
- - `"xyxy"`, also known as `corners` format. In this format the first four
389
- axes represent `[left, top, right, bottom]` in that order.
390
- - `"rel_xyxy"`. In this format, the axes are the same as `"xyxy"` but the x
391
- coordinates are normalized using the image width, and the y axes the
392
- image height. All values in `rel_xyxy` are in the range `(0, 1)`.
393
- - `"xywh"`. In this format the first four axes represent
394
- `[left, top, width, height]`.
395
- - `"rel_xywh". In this format the first four axes represent
396
- [left, top, width, height], just like `"xywh"`. Unlike `"xywh"`, the
397
- values are in the range (0, 1) instead of absolute pixel values.
398
- - `"center_xyWH"`. In this format the first two coordinates represent the x
399
- and y coordinates of the center of the bounding box, while the last two
400
- represent the width and height of the bounding box.
401
- - `"center_yxHW"`. In this format the first two coordinates represent the y
402
- and x coordinates of the center of the bounding box, while the last two
403
- represent the height and width of the bounding box.
404
- - `"yxyx"`. In this format the first four axes represent
405
- [top, left, bottom, right] in that order.
406
- - `"rel_yxyx"`. In this format, the axes are the same as `"yxyx"` but the x
407
- coordinates are normalized using the image width, and the y axes the
408
- image height. All values in `rel_yxyx` are in the range (0, 1).
409
- Formats are case insensitive. It is recommended that you capitalize width
410
- and height to maximize the visual difference between `"xyWH"` and `"xyxy"`.
411
-
412
- Relative formats, abbreviated `rel`, make use of the shapes of the `images`
413
- passed. In these formats, the coordinates, widths, and heights are all
414
- specified as percentages of the host image. `images` may be a ragged
415
- Tensor. Note that using a ragged Tensor for images may cause a substantial
416
- performance loss, as each image will need to be processed separately due to
417
- the mismatching image shapes.
418
-
419
- Example:
420
-
421
- ```python
422
- boxes = load_coco_dataset()
423
- boxes_in_xywh = keras_hub.bounding_box.convert_format(
424
- boxes,
425
- source='xyxy',
426
- target='xyWH'
427
- )
428
- ```
429
-
430
- Args:
431
- boxes: tensor representing bounding boxes in the format specified in
432
- the `source` parameter. `boxes` can optionally have extra
433
- dimensions stacked on the final axis to store metadata. boxes
434
- should be a 3D tensor, with the shape `[batch_size, num_boxes, 4]`.
435
- Alternatively, boxes can be a dictionary with key 'boxes' containing
436
- a tensor matching the aforementioned spec.
437
- source:One of {" ".join([f'"{f}"' for f in TO_XYXY_CONVERTERS.keys()])}.
438
- Used to specify the original format of the `boxes` parameter.
439
- target:One of {" ".join([f'"{f}"' for f in TO_XYXY_CONVERTERS.keys()])}.
440
- Used to specify the destination format of the `boxes` parameter.
441
- images: (Optional) a batch of images aligned with `boxes` on the first
442
- axis. Should be at least 3 dimensions, with the first 3 dimensions
443
- representing: `[batch_size, height, width]`. Used in some
444
- converters to compute relative pixel values of the bounding box
445
- dimensions. Required when transforming from a rel format to a
446
- non-rel format.
447
- dtype: the data type to use when transforming the boxes, defaults to
448
- `"float32"`.
449
- """
450
- if isinstance(boxes, dict):
451
- converted_boxes = boxes.copy()
452
- converted_boxes["boxes"] = convert_format(
453
- boxes["boxes"],
454
- source=source,
455
- target=target,
456
- images=images,
457
- image_shape=image_shape,
458
- dtype=dtype,
459
- )
460
- return converted_boxes
461
-
462
- if boxes.shape[-1] is not None and boxes.shape[-1] != 4:
463
- raise ValueError(
464
- "Expected `boxes` to be a Tensor with a final dimension of "
465
- f"`4`. Instead, got `boxes.shape={boxes.shape}`."
466
- )
467
- if images is not None and image_shape is not None:
468
- raise ValueError(
469
- "convert_format() expects either `images` or `image_shape`, but "
470
- f"not both. Received images={images} image_shape={image_shape}"
471
- )
472
-
473
- _validate_image_shape(image_shape)
474
-
475
- source = source.lower()
476
- target = target.lower()
477
- if source not in TO_XYXY_CONVERTERS:
478
- raise ValueError(
479
- "`convert_format()` received an unsupported format for the "
480
- "argument `source`. `source` should be one of "
481
- f"{TO_XYXY_CONVERTERS.keys()}. Got source={source}"
482
- )
483
- if target not in FROM_XYXY_CONVERTERS:
484
- raise ValueError(
485
- "`convert_format()` received an unsupported format for the "
486
- "argument `target`. `target` should be one of "
487
- f"{FROM_XYXY_CONVERTERS.keys()}. Got target={target}"
488
- )
489
-
490
- boxes = ops.cast(boxes, dtype)
491
- if source == target:
492
- return boxes
493
-
494
- # rel->rel conversions should not require images
495
- if source.startswith("rel") and target.startswith("rel"):
496
- source = source.replace("rel_", "", 1)
497
- target = target.replace("rel_", "", 1)
498
-
499
- boxes, images, squeeze = _format_inputs(boxes, images)
500
- to_xyxy_fn = TO_XYXY_CONVERTERS[source]
501
- from_xyxy_fn = FROM_XYXY_CONVERTERS[target]
502
-
503
- try:
504
- in_xyxy = to_xyxy_fn(boxes, images=images, image_shape=image_shape)
505
- result = from_xyxy_fn(in_xyxy, images=images, image_shape=image_shape)
506
- except RequiresImagesException:
507
- raise ValueError(
508
- "convert_format() must receive `images` or `image_shape` when "
509
- "transforming between relative and absolute formats."
510
- f"convert_format() received source=`{format}`, target=`{format}, "
511
- f"but images={images} and image_shape={image_shape}."
512
- )
513
-
514
- return _format_outputs(result, squeeze)
515
-
516
-
517
- def _format_inputs(boxes, images):
518
- boxes_rank = len(boxes.shape)
519
- if boxes_rank > 3:
520
- raise ValueError(
521
- "Expected len(boxes.shape)=2, or len(boxes.shape)=3, got "
522
- f"len(boxes.shape)={boxes_rank}"
523
- )
524
- boxes_includes_batch = boxes_rank == 3
525
- # Determine if images needs an expand_dims() call
526
- if images is not None:
527
- images_rank = len(images.shape)
528
- if images_rank > 4:
529
- raise ValueError(
530
- "Expected len(images.shape)=2, or len(images.shape)=3, got "
531
- f"len(images.shape)={images_rank}"
532
- )
533
- images_include_batch = images_rank == 4
534
- if boxes_includes_batch != images_include_batch:
535
- raise ValueError(
536
- "convert_format() expects both boxes and images to be batched, "
537
- "or both boxes and images to be unbatched. Received "
538
- f"len(boxes.shape)={boxes_rank}, "
539
- f"len(images.shape)={images_rank}. Expected either "
540
- "len(boxes.shape)=2 AND len(images.shape)=3, or "
541
- "len(boxes.shape)=3 AND len(images.shape)=4."
542
- )
543
- if not images_include_batch:
544
- images = ops.expand_dims(images, axis=0)
545
-
546
- if not boxes_includes_batch:
547
- return ops.expand_dims(boxes, axis=0), images, True
548
- return boxes, images, False
549
-
550
-
551
- def _validate_image_shape(image_shape):
552
- # Escape early if image_shape is None and skip validation.
553
- if image_shape is None:
554
- return
555
- # tuple/list
556
- if isinstance(image_shape, (tuple, list)):
557
- if len(image_shape) != 3:
558
- raise ValueError(
559
- "image_shape should be of length 3, but got "
560
- f"image_shape={image_shape}"
561
- )
562
- return
563
-
564
- # tensor
565
- if ops.is_tensor(image_shape):
566
- if len(image_shape.shape) > 1:
567
- raise ValueError(
568
- "image_shape.shape should be (3), but got "
569
- f"image_shape.shape={image_shape.shape}"
570
- )
571
- if image_shape.shape[0] != 3:
572
- raise ValueError(
573
- "image_shape.shape should be (3), but got "
574
- f"image_shape.shape={image_shape.shape}"
575
- )
576
- return
577
-
578
- # Warn about failure cases
579
- raise ValueError(
580
- "Expected image_shape to be either a tuple, list, Tensor. "
581
- f"Received image_shape={image_shape}"
582
- )
583
-
584
-
585
- def _format_outputs(boxes, squeeze):
586
- if squeeze:
587
- return ops.squeeze(boxes, axis=0)
588
- return boxes
589
-
590
-
591
- def _image_shape(images, image_shape, boxes):
592
- if images is None and image_shape is None:
593
- raise RequiresImagesException()
594
-
595
- if image_shape is None:
596
- if not isinstance(images, tf.RaggedTensor):
597
- image_shape = ops.shape(images)
598
- height, width = image_shape[1], image_shape[2]
599
- else:
600
- height = ops.reshape(images.row_lengths(), (-1, 1))
601
- width = ops.reshape(ops.max(images.row_lengths(axis=2), 1), (-1, 1))
602
- height = ops.expand_dims(height, axis=-1)
603
- width = ops.expand_dims(width, axis=-1)
604
- else:
605
- height, width = image_shape[0], image_shape[1]
606
- return ops.cast(height, boxes.dtype), ops.cast(width, boxes.dtype)