tf-models-nightly 2.17.0.dev20240312__py2.py3-none-any.whl → 2.17.0.dev20240314__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,6 +16,7 @@
16
16
 
17
17
  import math
18
18
  from typing import Optional, Sequence, Tuple, Union
19
+
19
20
  from six.moves import range
20
21
  import tensorflow as tf, tf_keras
21
22
 
@@ -64,8 +65,9 @@ def clip_or_pad_to_fixed_size(input_tensor, size, constant_values=0):
64
65
  padding_shape.append(tf.shape(input_tensor)[i])
65
66
 
66
67
  # Pads input tensor to the fixed first dimension.
67
- paddings = tf.cast(constant_values * tf.ones(padding_shape),
68
- input_tensor.dtype)
68
+ paddings = tf.cast(
69
+ constant_values * tf.ones(padding_shape), input_tensor.dtype
70
+ )
69
71
  padded_tensor = tf.concat([input_tensor, paddings], axis=0)
70
72
  output_shape = input_shape
71
73
  output_shape[0] = size
@@ -73,9 +75,11 @@ def clip_or_pad_to_fixed_size(input_tensor, size, constant_values=0):
73
75
  return padded_tensor
74
76
 
75
77
 
76
- def normalize_image(image: tf.Tensor,
77
- offset: Sequence[float] = MEAN_NORM,
78
- scale: Sequence[float] = STDDEV_NORM) -> tf.Tensor:
78
+ def normalize_image(
79
+ image: tf.Tensor,
80
+ offset: Sequence[float] = MEAN_NORM,
81
+ scale: Sequence[float] = STDDEV_NORM,
82
+ ) -> tf.Tensor:
79
83
  """Normalizes the image to zero mean and unit variance.
80
84
 
81
85
  If the input image dtype is float, it is expected to either have values in
@@ -96,9 +100,11 @@ def normalize_image(image: tf.Tensor,
96
100
  return normalize_scaled_float_image(image, offset, scale)
97
101
 
98
102
 
99
- def normalize_scaled_float_image(image: tf.Tensor,
100
- offset: Sequence[float] = MEAN_NORM,
101
- scale: Sequence[float] = STDDEV_NORM):
103
+ def normalize_scaled_float_image(
104
+ image: tf.Tensor,
105
+ offset: Sequence[float] = MEAN_NORM,
106
+ scale: Sequence[float] = STDDEV_NORM,
107
+ ):
102
108
  """Normalizes a scaled float image to zero mean and unit variance.
103
109
 
104
110
  It assumes the input image is float dtype with values in [0, 1) if offset is
@@ -142,24 +148,28 @@ def compute_padded_size(desired_size, stride):
142
148
  [height, width] of the padded output image size.
143
149
  """
144
150
  if isinstance(desired_size, list) or isinstance(desired_size, tuple):
145
- padded_size = [int(math.ceil(d * 1.0 / stride) * stride)
146
- for d in desired_size]
151
+ padded_size = [
152
+ int(math.ceil(d * 1.0 / stride) * stride) for d in desired_size
153
+ ]
147
154
  else:
148
155
  padded_size = tf.cast(
149
- tf.math.ceil(
150
- tf.cast(desired_size, dtype=tf.float32) / stride) * stride,
151
- tf.int32)
156
+ tf.math.ceil(tf.cast(desired_size, dtype=tf.float32) / stride) * stride,
157
+ tf.int32,
158
+ )
152
159
  return padded_size
153
160
 
154
161
 
155
- def resize_and_crop_image(image,
156
- desired_size,
157
- padded_size,
158
- aug_scale_min=1.0,
159
- aug_scale_max=1.0,
160
- seed=1,
161
- method=tf.image.ResizeMethod.BILINEAR,
162
- keep_aspect_ratio=True):
162
+ def resize_and_crop_image(
163
+ image,
164
+ desired_size,
165
+ padded_size,
166
+ aug_scale_min=1.0,
167
+ aug_scale_max=1.0,
168
+ seed=1,
169
+ method=tf.image.ResizeMethod.BILINEAR,
170
+ keep_aspect_ratio=True,
171
+ centered_crop=False,
172
+ ):
163
173
  """Resizes the input image to output size (RetinaNet style).
164
174
 
165
175
  Resize and pad images given the desired output size of the image and
@@ -186,6 +196,9 @@ def resize_and_crop_image(image,
186
196
  seed: seed for random scale jittering.
187
197
  method: function to resize input image to scaled image.
188
198
  keep_aspect_ratio: whether or not to keep the aspect ratio when resizing.
199
+ centered_crop: If `centered_crop` is set to True, then resized crop (if
200
+ smaller than padded size) is place in the center of the image. Default
201
+ behaviour is to place it at left top corner.
189
202
 
190
203
  Returns:
191
204
  output_image: `Tensor` of shape [height, width, 3] where [height, width]
@@ -210,14 +223,16 @@ def resize_and_crop_image(image,
210
223
 
211
224
  if random_jittering:
212
225
  random_scale = tf.random.uniform(
213
- [], aug_scale_min, aug_scale_max, seed=seed)
226
+ [], aug_scale_min, aug_scale_max, seed=seed
227
+ )
214
228
  scaled_size = tf.round(random_scale * tf.cast(desired_size, tf.float32))
215
229
  else:
216
230
  scaled_size = tf.cast(desired_size, tf.float32)
217
231
 
218
232
  if keep_aspect_ratio:
219
233
  scale = tf.minimum(
220
- scaled_size[0] / image_size[0], scaled_size[1] / image_size[1])
234
+ scaled_size[0] / image_size[0], scaled_size[1] / image_size[1]
235
+ )
221
236
  scaled_size = tf.round(image_size * scale)
222
237
 
223
238
  # Computes 2D image_scale.
@@ -228,41 +243,66 @@ def resize_and_crop_image(image,
228
243
  if random_jittering:
229
244
  max_offset = scaled_size - tf.cast(desired_size, tf.float32)
230
245
  max_offset = tf.where(
231
- tf.less(max_offset, 0), tf.zeros_like(max_offset), max_offset)
232
- offset = max_offset * tf.random.uniform([2,], 0, 1, seed=seed)
246
+ tf.less(max_offset, 0), tf.zeros_like(max_offset), max_offset
247
+ )
248
+ offset = max_offset * tf.random.uniform(
249
+ [
250
+ 2,
251
+ ],
252
+ 0,
253
+ 1,
254
+ seed=seed,
255
+ )
233
256
  offset = tf.cast(offset, tf.int32)
234
257
  else:
235
258
  offset = tf.zeros((2,), tf.int32)
236
259
 
237
260
  scaled_image = tf.image.resize(
238
- image, tf.cast(scaled_size, tf.int32), method=method)
261
+ image, tf.cast(scaled_size, tf.int32), method=method
262
+ )
239
263
 
240
264
  if random_jittering:
241
265
  scaled_image = scaled_image[
242
- offset[0]:offset[0] + desired_size[0],
243
- offset[1]:offset[1] + desired_size[1], :]
266
+ offset[0] : offset[0] + desired_size[0],
267
+ offset[1] : offset[1] + desired_size[1],
268
+ :,
269
+ ]
244
270
 
245
271
  output_image = scaled_image
246
272
  if padded_size is not None:
247
- output_image = tf.image.pad_to_bounding_box(
248
- scaled_image, 0, 0, padded_size[0], padded_size[1])
273
+ if centered_crop:
274
+ scaled_image_size = tf.cast(tf.shape(scaled_image)[0:2], tf.int32)
275
+ output_image = tf.image.pad_to_bounding_box(
276
+ scaled_image,
277
+ tf.maximum((padded_size[0] - scaled_image_size[0]) // 2, 0),
278
+ tf.maximum((padded_size[1] - scaled_image_size[1]) // 2, 0),
279
+ padded_size[0],
280
+ padded_size[1],
281
+ )
282
+ else:
283
+ output_image = tf.image.pad_to_bounding_box(
284
+ scaled_image, 0, 0, padded_size[0], padded_size[1]
285
+ )
249
286
 
250
287
  image_info = tf.stack([
251
288
  image_size,
252
289
  tf.cast(desired_size, dtype=tf.float32),
253
290
  image_scale,
254
- tf.cast(offset, tf.float32)])
291
+ tf.cast(offset, tf.float32),
292
+ ])
255
293
  return output_image, image_info
256
294
 
257
295
 
258
- def resize_and_crop_image_v2(image,
259
- short_side,
260
- long_side,
261
- padded_size,
262
- aug_scale_min=1.0,
263
- aug_scale_max=1.0,
264
- seed=1,
265
- method=tf.image.ResizeMethod.BILINEAR):
296
+ def resize_and_crop_image_v2(
297
+ image,
298
+ short_side,
299
+ long_side,
300
+ padded_size,
301
+ aug_scale_min=1.0,
302
+ aug_scale_max=1.0,
303
+ seed=1,
304
+ method=tf.image.ResizeMethod.BILINEAR,
305
+ ):
266
306
  """Resizes the input image to output size (Faster R-CNN style).
267
307
 
268
308
  Resize and pad images given the specified short / long side length and the
@@ -306,17 +346,21 @@ def resize_and_crop_image_v2(image,
306
346
  with tf.name_scope('resize_and_crop_image_v2'):
307
347
  image_size = tf.cast(tf.shape(image)[0:2], tf.float32)
308
348
 
309
- scale_using_short_side = (
310
- short_side / tf.math.minimum(image_size[0], image_size[1]))
311
- scale_using_long_side = (
312
- long_side / tf.math.maximum(image_size[0], image_size[1]))
349
+ scale_using_short_side = short_side / tf.math.minimum(
350
+ image_size[0], image_size[1]
351
+ )
352
+ scale_using_long_side = long_side / tf.math.maximum(
353
+ image_size[0], image_size[1]
354
+ )
313
355
 
314
356
  scaled_size = tf.math.round(image_size * scale_using_short_side)
315
357
  scaled_size = tf.where(
316
358
  tf.math.greater(
317
- tf.math.maximum(scaled_size[0], scaled_size[1]), long_side),
359
+ tf.math.maximum(scaled_size[0], scaled_size[1]), long_side
360
+ ),
318
361
  tf.math.round(image_size * scale_using_long_side),
319
- scaled_size)
362
+ scaled_size,
363
+ )
320
364
  desired_size = scaled_size
321
365
 
322
366
  random_jittering = (
@@ -328,7 +372,8 @@ def resize_and_crop_image_v2(image,
328
372
 
329
373
  if random_jittering:
330
374
  random_scale = tf.random.uniform(
331
- [], aug_scale_min, aug_scale_max, seed=seed)
375
+ [], aug_scale_min, aug_scale_max, seed=seed
376
+ )
332
377
  scaled_size = tf.math.round(random_scale * scaled_size)
333
378
 
334
379
  # Computes 2D image_scale.
@@ -339,28 +384,41 @@ def resize_and_crop_image_v2(image,
339
384
  if random_jittering:
340
385
  max_offset = scaled_size - desired_size
341
386
  max_offset = tf.where(
342
- tf.math.less(max_offset, 0), tf.zeros_like(max_offset), max_offset)
343
- offset = max_offset * tf.random.uniform([2,], 0, 1, seed=seed)
387
+ tf.math.less(max_offset, 0), tf.zeros_like(max_offset), max_offset
388
+ )
389
+ offset = max_offset * tf.random.uniform(
390
+ [
391
+ 2,
392
+ ],
393
+ 0,
394
+ 1,
395
+ seed=seed,
396
+ )
344
397
  offset = tf.cast(offset, tf.int32)
345
398
  else:
346
399
  offset = tf.zeros((2,), tf.int32)
347
400
 
348
401
  scaled_image = tf.image.resize(
349
- image, tf.cast(scaled_size, tf.int32), method=method)
402
+ image, tf.cast(scaled_size, tf.int32), method=method
403
+ )
350
404
 
351
405
  if random_jittering:
352
406
  scaled_image = scaled_image[
353
- offset[0]:offset[0] + desired_size[0],
354
- offset[1]:offset[1] + desired_size[1], :]
407
+ offset[0] : offset[0] + desired_size[0],
408
+ offset[1] : offset[1] + desired_size[1],
409
+ :,
410
+ ]
355
411
 
356
412
  output_image = tf.image.pad_to_bounding_box(
357
- scaled_image, 0, 0, padded_size[0], padded_size[1])
413
+ scaled_image, 0, 0, padded_size[0], padded_size[1]
414
+ )
358
415
 
359
416
  image_info = tf.stack([
360
417
  image_size,
361
418
  tf.cast(desired_size, dtype=tf.float32),
362
419
  image_scale,
363
- tf.cast(offset, tf.float32)])
420
+ tf.cast(offset, tf.float32),
421
+ ])
364
422
  return output_image, image_info
365
423
 
366
424
 
@@ -368,13 +426,14 @@ def resize_image(
368
426
  image: tf.Tensor,
369
427
  size: Union[Tuple[int, int], int],
370
428
  max_size: Optional[int] = None,
371
- method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR):
429
+ method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
430
+ ):
372
431
  """Resize image with size and max_size.
373
432
 
374
433
  Args:
375
434
  image: the image to be resized.
376
- size: if list to tuple, resize to it. If scalar, we keep the same
377
- aspect ratio and resize the short side to the value.
435
+ size: if list to tuple, resize to it. If scalar, we keep the same aspect
436
+ ratio and resize the short side to the value.
378
437
  max_size: only used when size is a scalar. When the larger side is larger
379
438
  than max_size after resized with size we used max_size to keep the aspect
380
439
  ratio instead.
@@ -400,7 +459,8 @@ def resize_image(
400
459
  if max_original_size / min_original_size * size > max_size:
401
460
  size = tf.cast(
402
461
  tf.math.floor(max_size * min_original_size / max_original_size),
403
- dtype=tf.int32)
462
+ dtype=tf.int32,
463
+ )
404
464
  else:
405
465
  size = tf.cast(size, tf.int32)
406
466
 
@@ -412,15 +472,23 @@ def resize_image(
412
472
  if w < h:
413
473
  ow = size
414
474
  oh = tf.cast(
415
- (tf.cast(size, dtype=tf.float32) * tf.cast(h, dtype=tf.float32) /
416
- tf.cast(w, dtype=tf.float32)),
417
- dtype=tf.int32)
475
+ (
476
+ tf.cast(size, dtype=tf.float32)
477
+ * tf.cast(h, dtype=tf.float32)
478
+ / tf.cast(w, dtype=tf.float32)
479
+ ),
480
+ dtype=tf.int32,
481
+ )
418
482
  else:
419
483
  oh = size
420
484
  ow = tf.cast(
421
- (tf.cast(size, dtype=tf.float32) * tf.cast(w, dtype=tf.float32) /
422
- tf.cast(h, dtype=tf.float32)),
423
- dtype=tf.int32)
485
+ (
486
+ tf.cast(size, dtype=tf.float32)
487
+ * tf.cast(w, dtype=tf.float32)
488
+ / tf.cast(h, dtype=tf.float32)
489
+ ),
490
+ dtype=tf.int32,
491
+ )
424
492
 
425
493
  return tf.stack([oh, ow])
426
494
 
@@ -433,19 +501,21 @@ def resize_image(
433
501
  orignal_size = tf.shape(image)[0:2]
434
502
  size = get_size(orignal_size, size, max_size)
435
503
  rescaled_image = tf.image.resize(
436
- image, tf.cast(size, tf.int32), method=method)
504
+ image, tf.cast(size, tf.int32), method=method
505
+ )
437
506
  image_scale = size / orignal_size
438
507
  image_info = tf.stack([
439
508
  tf.cast(orignal_size, dtype=tf.float32),
440
509
  tf.cast(size, dtype=tf.float32),
441
510
  tf.cast(image_scale, tf.float32),
442
- tf.constant([0.0, 0.0], dtype=tf.float32)
511
+ tf.constant([0.0, 0.0], dtype=tf.float32),
443
512
  ])
444
513
  return rescaled_image, image_info
445
514
 
446
515
 
447
516
  def center_crop_image(
448
- image, center_crop_fraction: float = CENTER_CROP_FRACTION):
517
+ image, center_crop_fraction: float = CENTER_CROP_FRACTION
518
+ ):
449
519
  """Center crop a square shape slice from the input image.
450
520
 
451
521
  It crops a square shape slice from the image. The side of the actual crop
@@ -465,13 +535,16 @@ def center_crop_image(
465
535
  """
466
536
  with tf.name_scope('center_crop_image'):
467
537
  image_size = tf.cast(tf.shape(image)[:2], dtype=tf.float32)
468
- crop_size = (
469
- center_crop_fraction * tf.math.minimum(image_size[0], image_size[1]))
538
+ crop_size = center_crop_fraction * tf.math.minimum(
539
+ image_size[0], image_size[1]
540
+ )
470
541
  crop_offset = tf.cast((image_size - crop_size) / 2.0, dtype=tf.int32)
471
542
  crop_size = tf.cast(crop_size, dtype=tf.int32)
472
543
  cropped_image = image[
473
- crop_offset[0]:crop_offset[0] + crop_size,
474
- crop_offset[1]:crop_offset[1] + crop_size, :]
544
+ crop_offset[0] : crop_offset[0] + crop_size,
545
+ crop_offset[1] : crop_offset[1] + crop_size,
546
+ :,
547
+ ]
475
548
  return cropped_image
476
549
 
477
550
 
@@ -508,25 +581,29 @@ def center_crop_image_v2(
508
581
  crop_offset = tf.cast((image_shape - crop_size) / 2.0, dtype=tf.int32)
509
582
  crop_size = tf.cast(crop_size, dtype=tf.int32)
510
583
  crop_window = tf.stack(
511
- [crop_offset[0], crop_offset[1], crop_size, crop_size])
584
+ [crop_offset[0], crop_offset[1], crop_size, crop_size]
585
+ )
512
586
  cropped_image = tf.image.decode_and_crop_jpeg(
513
- image_bytes, crop_window, channels=3)
587
+ image_bytes, crop_window, channels=3
588
+ )
514
589
  return cropped_image
515
590
 
516
591
 
517
- def random_crop_image(image,
518
- aspect_ratio_range=(3. / 4., 4. / 3.),
519
- area_range=(0.08, 1.0),
520
- max_attempts=10,
521
- seed=1):
592
+ def random_crop_image(
593
+ image,
594
+ aspect_ratio_range=(3.0 / 4.0, 4.0 / 3.0),
595
+ area_range=(0.08, 1.0),
596
+ max_attempts=10,
597
+ seed=1,
598
+ ):
522
599
  """Randomly crop an arbitrary shaped slice from the input image.
523
600
 
524
601
  Args:
525
602
  image: a Tensor of shape [height, width, 3] representing the input image.
526
603
  aspect_ratio_range: a list of floats. The cropped area of the image must
527
604
  have an aspect ratio = width / height within this range.
528
- area_range: a list of floats. The cropped reas of the image must contain
529
- a fraction of the input image within this range.
605
+ area_range: a list of floats. The cropped reas of the image must contain a
606
+ fraction of the input image within this range.
530
607
  max_attempts: the number of attempts at generating a cropped region of the
531
608
  image of the specified constraints. After max_attempts failures, return
532
609
  the entire image.
@@ -544,17 +621,20 @@ def random_crop_image(image,
544
621
  min_object_covered=area_range[0],
545
622
  aspect_ratio_range=aspect_ratio_range,
546
623
  area_range=area_range,
547
- max_attempts=max_attempts)
624
+ max_attempts=max_attempts,
625
+ )
548
626
  cropped_image = tf.slice(image, crop_offset, crop_size)
549
627
  return cropped_image
550
628
 
551
629
 
552
- def random_crop_image_v2(image_bytes,
553
- image_shape,
554
- aspect_ratio_range=(3. / 4., 4. / 3.),
555
- area_range=(0.08, 1.0),
556
- max_attempts=10,
557
- seed=1):
630
+ def random_crop_image_v2(
631
+ image_bytes,
632
+ image_shape,
633
+ aspect_ratio_range=(3.0 / 4.0, 4.0 / 3.0),
634
+ area_range=(0.08, 1.0),
635
+ max_attempts=10,
636
+ seed=1,
637
+ ):
558
638
  """Randomly crop an arbitrary shaped slice from the input image.
559
639
 
560
640
  This is a faster version of `random_crop_image` which takes the original
@@ -566,8 +646,8 @@ def random_crop_image_v2(image_bytes,
566
646
  image_shape: a Tensor specifying the shape of the raw image.
567
647
  aspect_ratio_range: a list of floats. The cropped area of the image must
568
648
  have an aspect ratio = width / height within this range.
569
- area_range: a list of floats. The cropped reas of the image must contain
570
- a fraction of the input image within this range.
649
+ area_range: a list of floats. The cropped reas of the image must contain a
650
+ fraction of the input image within this range.
571
651
  max_attempts: the number of attempts at generating a cropped region of the
572
652
  image of the specified constraints. After max_attempts failures, return
573
653
  the entire image.
@@ -585,19 +665,18 @@ def random_crop_image_v2(image_bytes,
585
665
  min_object_covered=area_range[0],
586
666
  aspect_ratio_range=aspect_ratio_range,
587
667
  area_range=area_range,
588
- max_attempts=max_attempts)
668
+ max_attempts=max_attempts,
669
+ )
589
670
  offset_y, offset_x, _ = tf.unstack(crop_offset)
590
671
  crop_height, crop_width, _ = tf.unstack(crop_size)
591
672
  crop_window = tf.stack([offset_y, offset_x, crop_height, crop_width])
592
673
  cropped_image = tf.image.decode_and_crop_jpeg(
593
- image_bytes, crop_window, channels=3)
674
+ image_bytes, crop_window, channels=3
675
+ )
594
676
  return cropped_image
595
677
 
596
678
 
597
- def resize_and_crop_boxes(boxes,
598
- image_scale,
599
- output_size,
600
- offset):
679
+ def resize_and_crop_boxes(boxes, image_scale, output_size, offset):
601
680
  """Resizes boxes to output size with scale and offset.
602
681
 
603
682
  Args:
@@ -621,7 +700,9 @@ def resize_and_crop_boxes(boxes,
621
700
  return boxes
622
701
 
623
702
 
624
- def resize_and_crop_masks(masks, image_scale, output_size, offset):
703
+ def resize_and_crop_masks(
704
+ masks, image_scale, output_size, offset, centered_crop: bool = False
705
+ ):
625
706
  """Resizes boxes to output size with scale and offset.
626
707
 
627
708
  Args:
@@ -632,6 +713,9 @@ def resize_and_crop_masks(masks, image_scale, output_size, offset):
632
713
  output image size.
633
714
  offset: 2D `Tensor` representing top-left corner [y0, x0] to crop scaled
634
715
  boxes.
716
+ centered_crop: If `centered_crop` is set to True, then resized crop (if
717
+ smaller than padded size) is place in the center of the image. Default
718
+ behaviour is to place it at left top corner.
635
719
 
636
720
  Returns:
637
721
  masks: `Tensor` of shape [N, H, W, C] representing the scaled masks.
@@ -640,24 +724,43 @@ def resize_and_crop_masks(masks, image_scale, output_size, offset):
640
724
  mask_size = tf.cast(tf.shape(masks)[1:3], tf.float32)
641
725
  num_channels = tf.shape(masks)[3]
642
726
  # Pad masks to avoid empty mask annotations.
643
- masks = tf.concat([
644
- tf.zeros([1, mask_size[0], mask_size[1], num_channels],
645
- dtype=masks.dtype), masks
646
- ],
647
- axis=0)
727
+ masks = tf.concat(
728
+ [
729
+ tf.zeros(
730
+ [1, mask_size[0], mask_size[1], num_channels], dtype=masks.dtype
731
+ ),
732
+ masks,
733
+ ],
734
+ axis=0,
735
+ )
648
736
 
649
737
  scaled_size = tf.cast(image_scale * mask_size, tf.int32)
650
738
  scaled_masks = tf.image.resize(
651
- masks, scaled_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
739
+ masks, scaled_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
740
+ )
741
+
652
742
  offset = tf.cast(offset, tf.int32)
653
743
  scaled_masks = scaled_masks[
654
744
  :,
655
- offset[0]:offset[0] + output_size[0],
656
- offset[1]:offset[1] + output_size[1],
657
- :]
745
+ offset[0] : offset[0] + output_size[0],
746
+ offset[1] : offset[1] + output_size[1],
747
+ :,
748
+ ]
749
+
750
+ if centered_crop:
751
+ scaled_mask_size = tf.cast(tf.shape(scaled_masks)[1:3], tf.int32)
752
+ output_masks = tf.image.pad_to_bounding_box(
753
+ scaled_masks,
754
+ tf.maximum((output_size[0] - scaled_mask_size[0]) // 2, 0),
755
+ tf.maximum((output_size[1] - scaled_mask_size[1]) // 2, 0),
756
+ output_size[0],
757
+ output_size[1],
758
+ )
759
+ else:
760
+ output_masks = tf.image.pad_to_bounding_box(
761
+ scaled_masks, 0, 0, output_size[0], output_size[1]
762
+ )
658
763
 
659
- output_masks = tf.image.pad_to_bounding_box(
660
- scaled_masks, 0, 0, output_size[0], output_size[1])
661
764
  # Remove padding.
662
765
  output_masks = output_masks[1::]
663
766
  return output_masks
@@ -704,21 +807,20 @@ def random_horizontal_flip(
704
807
  do_flip = tf.less(tf.random.uniform([], seed=seed), prob)
705
808
 
706
809
  image = tf.cond(
707
- do_flip,
708
- lambda: horizontal_flip_image(image),
709
- lambda: image)
810
+ do_flip, lambda: horizontal_flip_image(image), lambda: image
811
+ )
710
812
 
711
813
  if normalized_boxes is not None:
712
814
  normalized_boxes = tf.cond(
713
815
  do_flip,
714
816
  lambda: horizontal_flip_boxes(normalized_boxes),
715
- lambda: normalized_boxes)
817
+ lambda: normalized_boxes,
818
+ )
716
819
 
717
820
  if masks is not None:
718
821
  masks = tf.cond(
719
- do_flip,
720
- lambda: horizontal_flip_masks(masks),
721
- lambda: masks)
822
+ do_flip, lambda: horizontal_flip_masks(masks), lambda: masks
823
+ )
722
824
 
723
825
  return image, normalized_boxes, masks
724
826
 
@@ -728,9 +830,10 @@ def random_horizontal_flip_with_roi(
728
830
  boxes: Optional[tf.Tensor] = None,
729
831
  masks: Optional[tf.Tensor] = None,
730
832
  roi_boxes: Optional[tf.Tensor] = None,
731
- seed: int = 1
732
- ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[tf.Tensor],
733
- Optional[tf.Tensor]]:
833
+ seed: int = 1,
834
+ ) -> Tuple[
835
+ tf.Tensor, Optional[tf.Tensor], Optional[tf.Tensor], Optional[tf.Tensor]
836
+ ]:
734
837
  """Randomly flips input image and bounding boxes horizontally.
735
838
 
736
839
  Extends preprocess_ops.random_horizontal_flip to also flip roi_boxes used
@@ -752,20 +855,24 @@ def random_horizontal_flip_with_roi(
752
855
  with tf.name_scope('random_horizontal_flip'):
753
856
  do_flip = tf.greater(tf.random.uniform([], seed=seed), 0.5)
754
857
 
755
- image = tf.cond(do_flip, lambda: horizontal_flip_image(image),
756
- lambda: image)
858
+ image = tf.cond(
859
+ do_flip, lambda: horizontal_flip_image(image), lambda: image
860
+ )
757
861
 
758
862
  if boxes is not None:
759
- boxes = tf.cond(do_flip, lambda: horizontal_flip_boxes(boxes),
760
- lambda: boxes)
863
+ boxes = tf.cond(
864
+ do_flip, lambda: horizontal_flip_boxes(boxes), lambda: boxes
865
+ )
761
866
 
762
867
  if masks is not None:
763
- masks = tf.cond(do_flip, lambda: horizontal_flip_masks(masks),
764
- lambda: masks)
868
+ masks = tf.cond(
869
+ do_flip, lambda: horizontal_flip_masks(masks), lambda: masks
870
+ )
765
871
 
766
872
  if roi_boxes is not None:
767
- roi_boxes = tf.cond(do_flip, lambda: horizontal_flip_boxes(roi_boxes),
768
- lambda: roi_boxes)
873
+ roi_boxes = tf.cond(
874
+ do_flip, lambda: horizontal_flip_boxes(roi_boxes), lambda: roi_boxes
875
+ )
769
876
 
770
877
  return image, boxes, masks, roi_boxes
771
878
 
@@ -778,30 +885,33 @@ def random_vertical_flip(
778
885
  do_flip = tf.less(tf.random.uniform([], seed=seed), prob)
779
886
 
780
887
  image = tf.cond(
781
- do_flip,
782
- lambda: tf.image.flip_up_down(image),
783
- lambda: image)
888
+ do_flip, lambda: tf.image.flip_up_down(image), lambda: image
889
+ )
784
890
 
785
891
  if normalized_boxes is not None:
786
892
  normalized_boxes = tf.cond(
787
893
  do_flip,
788
894
  lambda: vertical_flip_boxes(normalized_boxes),
789
- lambda: normalized_boxes)
895
+ lambda: normalized_boxes,
896
+ )
790
897
 
791
898
  if masks is not None:
792
899
  masks = tf.cond(
793
900
  do_flip,
794
901
  lambda: tf.image.flip_up_down(masks[..., None])[..., 0],
795
- lambda: masks)
902
+ lambda: masks,
903
+ )
796
904
 
797
905
  return image, normalized_boxes, masks
798
906
 
799
907
 
800
- def color_jitter(image: tf.Tensor,
801
- brightness: Optional[float] = 0.,
802
- contrast: Optional[float] = 0.,
803
- saturation: Optional[float] = 0.,
804
- seed: Optional[int] = None) -> tf.Tensor:
908
+ def color_jitter(
909
+ image: tf.Tensor,
910
+ brightness: Optional[float] = 0.0,
911
+ contrast: Optional[float] = 0.0,
912
+ saturation: Optional[float] = 0.0,
913
+ seed: Optional[int] = None,
914
+ ) -> tf.Tensor:
805
915
  """Applies color jitter to an image, similarly to torchvision`s ColorJitter.
806
916
 
807
917
  Args:
@@ -823,9 +933,9 @@ def color_jitter(image: tf.Tensor,
823
933
  return image
824
934
 
825
935
 
826
- def random_brightness(image: tf.Tensor,
827
- brightness: float = 0.,
828
- seed: Optional[int] = None) -> tf.Tensor:
936
+ def random_brightness(
937
+ image: tf.Tensor, brightness: float = 0.0, seed: Optional[int] = None
938
+ ) -> tf.Tensor:
829
939
  """Jitters brightness of an image.
830
940
 
831
941
  Args:
@@ -838,17 +948,15 @@ def random_brightness(image: tf.Tensor,
838
948
  tf.Tensor: The augmented `image` of type uint8.
839
949
  """
840
950
  assert brightness >= 0, '`brightness` must be positive'
841
- brightness = tf.random.uniform([],
842
- max(0, 1 - brightness),
843
- 1 + brightness,
844
- seed=seed,
845
- dtype=tf.float32)
951
+ brightness = tf.random.uniform(
952
+ [], max(0, 1 - brightness), 1 + brightness, seed=seed, dtype=tf.float32
953
+ )
846
954
  return augment.brightness(image, brightness)
847
955
 
848
956
 
849
- def random_contrast(image: tf.Tensor,
850
- contrast: float = 0.,
851
- seed: Optional[int] = None) -> tf.Tensor:
957
+ def random_contrast(
958
+ image: tf.Tensor, contrast: float = 0.0, seed: Optional[int] = None
959
+ ) -> tf.Tensor:
852
960
  """Jitters contrast of an image, similarly to torchvision`s ColorJitter.
853
961
 
854
962
  Args:
@@ -860,17 +968,15 @@ def random_contrast(image: tf.Tensor,
860
968
  tf.Tensor: The augmented `image` of type uint8.
861
969
  """
862
970
  assert contrast >= 0, '`contrast` must be positive'
863
- contrast = tf.random.uniform([],
864
- max(0, 1 - contrast),
865
- 1 + contrast,
866
- seed=seed,
867
- dtype=tf.float32)
971
+ contrast = tf.random.uniform(
972
+ [], max(0, 1 - contrast), 1 + contrast, seed=seed, dtype=tf.float32
973
+ )
868
974
  return augment.contrast(image, contrast)
869
975
 
870
976
 
871
- def random_saturation(image: tf.Tensor,
872
- saturation: float = 0.,
873
- seed: Optional[int] = None) -> tf.Tensor:
977
+ def random_saturation(
978
+ image: tf.Tensor, saturation: float = 0.0, seed: Optional[int] = None
979
+ ) -> tf.Tensor:
874
980
  """Jitters saturation of an image, similarly to torchvision`s ColorJitter.
875
981
 
876
982
  Args:
@@ -883,24 +989,29 @@ def random_saturation(image: tf.Tensor,
883
989
  tf.Tensor: The augmented `image` of type uint8.
884
990
  """
885
991
  assert saturation >= 0, '`saturation` must be positive'
886
- saturation = tf.random.uniform([],
887
- max(0, 1 - saturation),
888
- 1 + saturation,
889
- seed=seed,
890
- dtype=tf.float32)
992
+ saturation = tf.random.uniform(
993
+ [], max(0, 1 - saturation), 1 + saturation, seed=seed, dtype=tf.float32
994
+ )
891
995
  return _saturation(image, saturation)
892
996
 
893
997
 
894
- def _saturation(image: tf.Tensor,
895
- saturation: Optional[float] = 0.) -> tf.Tensor:
998
+ def _saturation(
999
+ image: tf.Tensor, saturation: Optional[float] = 0.0
1000
+ ) -> tf.Tensor:
896
1001
  return augment.blend(
897
- tf.repeat(tf.image.rgb_to_grayscale(image), 3, axis=-1), image,
898
- saturation)
1002
+ tf.repeat(tf.image.rgb_to_grayscale(image), 3, axis=-1), image, saturation
1003
+ )
899
1004
 
900
1005
 
901
- def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
902
- aspect_ratio_range,
903
- min_overlap_params, max_retry):
1006
+ def random_crop_image_with_boxes_and_labels(
1007
+ img,
1008
+ boxes,
1009
+ labels,
1010
+ min_scale,
1011
+ aspect_ratio_range,
1012
+ min_overlap_params,
1013
+ max_retry,
1014
+ ):
904
1015
  """Crops a random slice from the input image.
905
1016
 
906
1017
  The function will correspondingly recompute the bounding boxes and filter out
@@ -946,8 +1057,13 @@ def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
946
1057
 
947
1058
  minval, maxval, step, offset = min_overlap_params
948
1059
 
949
- min_overlap = tf.math.floordiv(
950
- tf.random.uniform([], minval=minval, maxval=maxval), step) * step - offset
1060
+ min_overlap = (
1061
+ tf.math.floordiv(
1062
+ tf.random.uniform([], minval=minval, maxval=maxval), step
1063
+ )
1064
+ * step
1065
+ - offset
1066
+ )
951
1067
 
952
1068
  min_overlap = tf.clip_by_value(min_overlap, 0.0, 1.1)
953
1069
 
@@ -961,9 +1077,11 @@ def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
961
1077
  scale_h = tf.random.uniform([], min_scale, 1.0)
962
1078
  scale_w = tf.random.uniform([], min_scale, 1.0)
963
1079
  new_h = tf.cast(
964
- scale_h * tf.cast(original_h, dtype=tf.float32), dtype=tf.int32)
1080
+ scale_h * tf.cast(original_h, dtype=tf.float32), dtype=tf.int32
1081
+ )
965
1082
  new_w = tf.cast(
966
- scale_w * tf.cast(original_w, dtype=tf.float32), dtype=tf.int32)
1083
+ scale_w * tf.cast(original_w, dtype=tf.float32), dtype=tf.int32
1084
+ )
967
1085
 
968
1086
  # Aspect ratio has to be in the prespecified range
969
1087
  aspect_ratio = new_h / new_w
@@ -975,18 +1093,18 @@ def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
975
1093
  top = tf.random.uniform([], 0, original_h - new_h, dtype=tf.int32)
976
1094
  bottom = top + new_h
977
1095
 
978
- normalized_left = tf.cast(
979
- left, dtype=tf.float32) / tf.cast(
980
- original_w, dtype=tf.float32)
981
- normalized_right = tf.cast(
982
- right, dtype=tf.float32) / tf.cast(
983
- original_w, dtype=tf.float32)
984
- normalized_top = tf.cast(
985
- top, dtype=tf.float32) / tf.cast(
986
- original_h, dtype=tf.float32)
987
- normalized_bottom = tf.cast(
988
- bottom, dtype=tf.float32) / tf.cast(
989
- original_h, dtype=tf.float32)
1096
+ normalized_left = tf.cast(left, dtype=tf.float32) / tf.cast(
1097
+ original_w, dtype=tf.float32
1098
+ )
1099
+ normalized_right = tf.cast(right, dtype=tf.float32) / tf.cast(
1100
+ original_w, dtype=tf.float32
1101
+ )
1102
+ normalized_top = tf.cast(top, dtype=tf.float32) / tf.cast(
1103
+ original_h, dtype=tf.float32
1104
+ )
1105
+ normalized_bottom = tf.cast(bottom, dtype=tf.float32) / tf.cast(
1106
+ original_h, dtype=tf.float32
1107
+ )
990
1108
 
991
1109
  cropped_box = tf.expand_dims(
992
1110
  tf.stack([
@@ -995,10 +1113,11 @@ def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
995
1113
  normalized_bottom,
996
1114
  normalized_right,
997
1115
  ]),
998
- axis=0)
1116
+ axis=0,
1117
+ )
999
1118
  iou = box_ops.bbox_overlap(
1000
- tf.expand_dims(cropped_box, axis=0),
1001
- tf.expand_dims(boxes, axis=0)) # (1, 1, n_ground_truth)
1119
+ tf.expand_dims(cropped_box, axis=0), tf.expand_dims(boxes, axis=0)
1120
+ ) # (1, 1, n_ground_truth)
1002
1121
  iou = tf.squeeze(iou, axis=[0, 1])
1003
1122
 
1004
1123
  # If not a single bounding box has a Jaccard overlap of greater than
@@ -1008,10 +1127,15 @@ def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
1008
1127
 
1009
1128
  centroids = box_ops.yxyx_to_cycxhw(boxes)
1010
1129
  mask = tf.math.logical_and(
1011
- tf.math.logical_and(centroids[:, 0] > normalized_top,
1012
- centroids[:, 0] < normalized_bottom),
1013
- tf.math.logical_and(centroids[:, 1] > normalized_left,
1014
- centroids[:, 1] < normalized_right))
1130
+ tf.math.logical_and(
1131
+ centroids[:, 0] > normalized_top,
1132
+ centroids[:, 0] < normalized_bottom,
1133
+ ),
1134
+ tf.math.logical_and(
1135
+ centroids[:, 1] > normalized_left,
1136
+ centroids[:, 1] < normalized_right,
1137
+ ),
1138
+ )
1015
1139
  # If not a single bounding box has its center in the crop, try again.
1016
1140
  if tf.reduce_sum(tf.cast(mask, dtype=tf.int32)) > 0:
1017
1141
  indices = tf.squeeze(tf.where(mask), axis=1)
@@ -1019,15 +1143,22 @@ def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
1019
1143
  filtered_boxes = tf.gather(boxes, indices)
1020
1144
 
1021
1145
  boxes = tf.clip_by_value(
1022
- (filtered_boxes[..., :] * tf.cast(
1023
- tf.stack([original_h, original_w, original_h, original_w]),
1024
- dtype=tf.float32) -
1025
- tf.cast(tf.stack([top, left, top, left]), dtype=tf.float32)) /
1026
- tf.cast(tf.stack([new_h, new_w, new_h, new_w]), dtype=tf.float32),
1027
- 0.0, 1.0)
1028
-
1029
- img = tf.image.crop_to_bounding_box(img, top, left, bottom - top,
1030
- right - left)
1146
+ (
1147
+ filtered_boxes[..., :]
1148
+ * tf.cast(
1149
+ tf.stack([original_h, original_w, original_h, original_w]),
1150
+ dtype=tf.float32,
1151
+ )
1152
+ - tf.cast(tf.stack([top, left, top, left]), dtype=tf.float32)
1153
+ )
1154
+ / tf.cast(tf.stack([new_h, new_w, new_h, new_w]), dtype=tf.float32),
1155
+ 0.0,
1156
+ 1.0,
1157
+ )
1158
+
1159
+ img = tf.image.crop_to_bounding_box(
1160
+ img, top, left, bottom - top, right - left
1161
+ )
1031
1162
 
1032
1163
  labels = tf.gather(labels, indices)
1033
1164
  break
@@ -1035,14 +1166,16 @@ def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
1035
1166
  return img, boxes, labels
1036
1167
 
1037
1168
 
1038
- def random_crop(image,
1039
- boxes,
1040
- labels,
1041
- min_scale=0.3,
1042
- aspect_ratio_range=(0.5, 2.0),
1043
- min_overlap_params=(0.0, 1.4, 0.2, 0.1),
1044
- max_retry=50,
1045
- seed=None):
1169
+ def random_crop(
1170
+ image,
1171
+ boxes,
1172
+ labels,
1173
+ min_scale=0.3,
1174
+ aspect_ratio_range=(0.5, 2.0),
1175
+ min_overlap_params=(0.0, 1.4, 0.2, 0.1),
1176
+ max_retry=50,
1177
+ seed=None,
1178
+ ):
1046
1179
  """Randomly crop the image and boxes, filtering labels.
1047
1180
 
1048
1181
  Args:
@@ -1069,11 +1202,15 @@ def random_crop(image,
1069
1202
  with tf.name_scope('random_crop'):
1070
1203
  do_crop = tf.greater(tf.random.uniform([], seed=seed), 0.5)
1071
1204
  if do_crop:
1072
- return random_crop_image_with_boxes_and_labels(image, boxes, labels,
1073
- min_scale,
1074
- aspect_ratio_range,
1075
- min_overlap_params,
1076
- max_retry)
1205
+ return random_crop_image_with_boxes_and_labels(
1206
+ image,
1207
+ boxes,
1208
+ labels,
1209
+ min_scale,
1210
+ aspect_ratio_range,
1211
+ min_overlap_params,
1212
+ max_retry,
1213
+ )
1077
1214
  else:
1078
1215
  return image, boxes, labels
1079
1216