tf-models-nightly 2.11.0.dev20230321__py2.py3-none-any.whl → 2.11.0.dev20230323__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -45,15 +45,10 @@ class YT8MTask(base_task.Task):
45
45
  l2_weight_decay = self.task_config.losses.l2_weight_decay
46
46
  # Model configuration.
47
47
  model_config = self.task_config.model
48
- norm_activation_config = model_config.norm_activation
49
48
  model = DbofModel(
50
49
  params=model_config,
51
50
  input_specs=input_specs,
52
51
  num_classes=train_cfg.num_classes,
53
- activation=norm_activation_config.activation,
54
- use_sync_bn=norm_activation_config.use_sync_bn,
55
- norm_momentum=norm_activation_config.norm_momentum,
56
- norm_epsilon=norm_activation_config.norm_epsilon,
57
52
  l2_weight_decay=l2_weight_decay)
58
53
 
59
54
  non_trainable_batch_norm_variables = []
@@ -66,18 +61,32 @@ class YT8MTask(base_task.Task):
66
61
  non_trainable_extra_variables.append(var)
67
62
 
68
63
  logging.info(
69
- 'Trainable model variables:\n%s', '\n'.join(
70
- [f'{var.name}\t{var.shape}' for var in model.trainable_variables]))
64
+ 'Trainable model variables:\n%s',
65
+ '\n'.join(
66
+ [f'{var.name}\t{var.shape}' for var in model.trainable_variables]
67
+ ),
68
+ )
71
69
  logging.info(
72
- 'Non-trainable batch norm variables (get updated in training mode):\n%s',
73
- '\n'.join([
74
- f'{var.name}\t{var.shape}'
75
- for var in non_trainable_batch_norm_variables
76
- ]))
70
+ (
71
+ 'Non-trainable batch norm variables (get updated in training'
72
+ ' mode):\n%s'
73
+ ),
74
+ '\n'.join(
75
+ [
76
+ f'{var.name}\t{var.shape}'
77
+ for var in non_trainable_batch_norm_variables
78
+ ]
79
+ ),
80
+ )
77
81
  logging.info(
78
- 'Non-trainable frozen model variables:\n%s', '\n'.join([
79
- f'{var.name}\t{var.shape}' for var in non_trainable_extra_variables
80
- ]))
82
+ 'Non-trainable frozen model variables:\n%s',
83
+ '\n'.join(
84
+ [
85
+ f'{var.name}\t{var.shape}'
86
+ for var in non_trainable_extra_variables
87
+ ]
88
+ ),
89
+ )
81
90
  return model
82
91
 
83
92
  def build_inputs(self, params: yt8m_cfg.DataConfig, input_context=None):
@@ -173,7 +182,10 @@ class YT8MTask(base_task.Task):
173
182
  for name in metric_names:
174
183
  metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
175
184
 
176
- if self.task_config.evaluation.average_precision is not None and not training:
185
+ if (
186
+ self.task_config.evaluation.average_precision is not None
187
+ and not training
188
+ ):
177
189
  # Cannot run in train step.
178
190
  num_classes = self.task_config.validation_data.num_classes
179
191
  top_k = self.task_config.evaluation.average_precision.top_k
@@ -183,14 +195,16 @@ class YT8MTask(base_task.Task):
183
195
 
184
196
  return metrics
185
197
 
186
- def process_metrics(self,
187
- metrics: List[tf.keras.metrics.Metric],
188
- labels: tf.Tensor,
189
- outputs: tf.Tensor,
190
- model_losses: Optional[Dict[str, tf.Tensor]] = None,
191
- label_weights: Optional[tf.Tensor] = None,
192
- training: bool = True,
193
- **kwargs) -> Dict[str, Tuple[tf.Tensor, ...]]:
198
+ def process_metrics(
199
+ self,
200
+ metrics: List[tf.keras.metrics.Metric],
201
+ labels: tf.Tensor,
202
+ outputs: tf.Tensor,
203
+ model_losses: Optional[Dict[str, tf.Tensor]] = None,
204
+ label_weights: Optional[tf.Tensor] = None,
205
+ training: bool = True,
206
+ **kwargs,
207
+ ) -> Dict[str, Tuple[tf.Tensor, ...]]:
194
208
  """Updates metrics.
195
209
 
196
210
  Args:
@@ -210,7 +224,10 @@ class YT8MTask(base_task.Task):
210
224
  model_losses = {}
211
225
 
212
226
  logs = {}
213
- if self.task_config.evaluation.average_precision is not None and not training:
227
+ if (
228
+ self.task_config.evaluation.average_precision is not None
229
+ and not training
230
+ ):
214
231
  logs.update({self.avg_prec_metric.name: (labels, outputs)})
215
232
 
216
233
  for m in metrics:
@@ -211,19 +211,18 @@ class Parser(parser.Parser):
211
211
  image = preprocess_ops.normalize_image(image)
212
212
 
213
213
  # Flips image randomly during training.
214
- if self._aug_rand_hflip:
215
- if self._include_mask:
216
- image, boxes, masks = preprocess_ops.random_horizontal_flip(
217
- image, boxes, masks)
218
- else:
219
- image, boxes, _ = preprocess_ops.random_horizontal_flip(
220
- image, boxes)
221
- if self._aug_rand_vflip:
222
- if self._include_mask:
223
- image, boxes, masks = preprocess_ops.random_vertical_flip(
224
- image, boxes, masks)
225
- else:
226
- image, boxes, _ = preprocess_ops.random_vertical_flip(image, boxes)
214
+ image, boxes, masks = preprocess_ops.random_horizontal_flip(
215
+ image,
216
+ boxes,
217
+ masks=None if not self._include_mask else masks,
218
+ prob=tf.where(self._aug_rand_hflip, 0.5, 0.0),
219
+ )
220
+ image, boxes, masks = preprocess_ops.random_vertical_flip(
221
+ image,
222
+ boxes,
223
+ masks=None if not self._include_mask else masks,
224
+ prob=tf.where(self._aug_rand_vflip, 0.5, 0.0),
225
+ )
227
226
 
228
227
  # Converts boxes from normalized coordinates to pixel coordinates.
229
228
  # Now the coordinates of boxes are w.r.t. the original image.
@@ -35,4 +35,5 @@ TFDS_ID_TO_DECODER_MAP = {
35
35
  'cifar10': ClassificationDecorder,
36
36
  'cifar100': ClassificationDecorder,
37
37
  'imagenet2012': ClassificationDecorder,
38
+ 'imagenet2012_fewshot/10shot': ClassificationDecorder,
38
39
  }
@@ -465,213 +465,6 @@ def _count_detection_type(
465
465
  return count
466
466
 
467
467
 
468
- def _compute_fp_tp_gt_count(
469
- y_true: Dict[str, tf.Tensor],
470
- y_pred: Dict[str, tf.Tensor],
471
- num_classes: int,
472
- mask_output_boundary: Tuple[int, int] = (640, 640),
473
- iou_thresholds: Tuple[float, ...] = (0.5,),
474
- matching_algorithm: Optional[MatchingAlgorithm] = None,
475
- num_confidence_bins: int = 1000,
476
- use_masks: bool = False,
477
- ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
478
- """Computes the true and false positives."""
479
-
480
- if matching_algorithm is None:
481
- matching_algorithm = COCOMatchingAlgorithm(iou_thresholds)
482
-
483
- # (batch_size, num_detections, 4) in absolute coordinates.
484
- detection_boxes = tf.cast(y_pred['detection_boxes'], tf.float32)
485
- # (batch_size, num_detections)
486
- detection_classes = tf.cast(y_pred['detection_classes'], tf.int32)
487
- # (batch_size, num_detections)
488
- detection_scores = tf.cast(y_pred['detection_scores'], tf.float32)
489
- # (batch_size, num_gts, 4) in absolute coordinates.
490
- gt_boxes = tf.cast(y_true['boxes'], tf.float32)
491
- # (batch_size, num_gts)
492
- gt_classes = tf.cast(y_true['classes'], tf.int32)
493
- # (batch_size, num_gts)
494
- if 'is_crowds' in y_true:
495
- gt_is_crowd = tf.cast(y_true['is_crowds'], tf.bool)
496
- else:
497
- gt_is_crowd = tf.zeros_like(gt_classes, dtype=tf.bool)
498
-
499
- image_scale = tf.tile(y_true['image_info'][:, 2:3, :], multiples=[1, 1, 2])
500
- detection_boxes = detection_boxes / tf.cast(
501
- image_scale, dtype=detection_boxes.dtype
502
- )
503
-
504
- # Step 1: Computes IoUs between the detections and the non-crowd ground
505
- # truths and IoAs between the detections and the crowd ground truths.
506
- if not use_masks:
507
- # (batch_size, num_detections, num_gts)
508
- detection_to_gt_ious = box_ops.bbox_overlap(detection_boxes, gt_boxes)
509
- detection_to_gt_ioas = box_ops.bbox_intersection_over_area(
510
- detection_boxes, gt_boxes
511
- )
512
- else:
513
- # (batch_size, num_detections, mask_height, mask_width)
514
- detection_masks = tf.cast(y_pred['detection_masks'], tf.float32)
515
- # (batch_size, num_gts, gt_mask_height, gt_mask_width)
516
- gt_masks = tf.cast(y_true['masks'], tf.float32)
517
-
518
- num_detections = detection_boxes.get_shape()[1]
519
- # (batch_size, num_detections + num_gts, 4)
520
- all_boxes = _shift_and_rescale_boxes(
521
- tf.concat([detection_boxes, gt_boxes], axis=1),
522
- mask_output_boundary,
523
- )
524
- detection_boxes = all_boxes[:, :num_detections, :]
525
- gt_boxes = all_boxes[:, num_detections:, :]
526
- # (batch_size, num_detections, num_gts)
527
- detection_to_gt_ious, detection_to_gt_ioas = (
528
- mask_ops.instance_masks_overlap(
529
- detection_boxes,
530
- detection_masks,
531
- gt_boxes,
532
- gt_masks,
533
- output_size=mask_output_boundary,
534
- )
535
- )
536
-
537
- # (batch_size, num_detections, num_gts)
538
- detection_to_gt_ious = tf.where(
539
- gt_is_crowd[:, tf.newaxis, :], 0.0, detection_to_gt_ious
540
- )
541
- detection_to_crowd_ioas = tf.where(
542
- gt_is_crowd[:, tf.newaxis, :], detection_to_gt_ioas, 0.0
543
- )
544
-
545
- # Step 2: counts true positives grouped by IoU thresholds, classes and
546
- # confidence bins.
547
-
548
- # (batch_size, num_detections, num_iou_thresholds)
549
- detection_is_tp, _ = matching_algorithm(
550
- detection_to_gt_ious, detection_classes, detection_scores, gt_classes
551
- )
552
- # (batch_size * num_detections,)
553
- flattened_binned_confidence = tf.reshape(
554
- tf.cast(detection_scores * num_confidence_bins, tf.int32), [-1]
555
- )
556
- # (batch_size * num_detections, num_confidence_bins + 1)
557
- flattened_binned_confidence_one_hot = tf.one_hot(
558
- flattened_binned_confidence, num_confidence_bins + 1, axis=1
559
- )
560
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
561
- tp_count = _count_detection_type(
562
- detection_is_tp,
563
- detection_classes,
564
- flattened_binned_confidence_one_hot,
565
- num_classes,
566
- )
567
-
568
- # Step 3: Counts false positives grouped by IoU thresholds, classes and
569
- # confidence bins.
570
- # False positive: detection is not true positive (see above) and not part of
571
- # the crowd ground truth with the same class.
572
-
573
- # (batch_size, num_detections, num_gts, num_iou_thresholds)
574
- detection_matches_crowd = (
575
- (detection_to_crowd_ioas[..., tf.newaxis] > iou_thresholds)
576
- & (
577
- detection_classes[:, :, tf.newaxis, tf.newaxis]
578
- == gt_classes[:, tf.newaxis, :, tf.newaxis]
579
- )
580
- & (detection_classes[:, :, tf.newaxis, tf.newaxis] > 0)
581
- )
582
- # (batch_size, num_detections, num_iou_thresholds)
583
- detection_matches_any_crowd = tf.reduce_any(
584
- detection_matches_crowd & ~detection_is_tp[:, :, tf.newaxis, :], axis=2
585
- )
586
- detection_is_fp = ~detection_is_tp & ~detection_matches_any_crowd
587
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
588
- fp_count = _count_detection_type(
589
- detection_is_fp,
590
- detection_classes,
591
- flattened_binned_confidence_one_hot,
592
- num_classes,
593
- )
594
-
595
- # Step 4: Counts non-crowd groundtruths grouped by classes.
596
- # (num_classes, )
597
- gt_count = tf.reduce_sum(
598
- tf.one_hot(
599
- tf.where(gt_is_crowd, -1, gt_classes), num_classes, axis=-1
600
- ),
601
- axis=[0, 1],
602
- )
603
- # Clears the count of class 0 (background).
604
- gt_count *= 1.0 - tf.eye(1, num_classes, dtype=gt_count.dtype)[0]
605
-
606
- return tp_count, fp_count, gt_count
607
-
608
-
609
- def _compute_metrics(
610
- tp_count: tf.Tensor,
611
- fp_count: tf.Tensor,
612
- gt_count: tf.Tensor,
613
- confidence_thresholds: Tuple[float, ...] = (),
614
- num_confidence_bins: int = 1000,
615
- average_precision_algorithms: Optional[
616
- Dict[str, AveragePrecision]] = None,
617
- ) -> Dict[str, tf.Tensor]:
618
- """Returns the metrics values as a dict."""
619
-
620
- if average_precision_algorithms is None:
621
- average_precision_algorithms = {'ap': COCOAveragePrecision()}
622
-
623
- result = {
624
- # (num_classes,)
625
- 'valid_classes': gt_count != 0,
626
- }
627
-
628
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
629
- tp_count_cum_by_confidence = tf.math.cumsum(
630
- tp_count, axis=-1, reverse=True
631
- )
632
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
633
- fp_count_cum_by_confidence = tf.math.cumsum(
634
- fp_count, axis=-1, reverse=True
635
- )
636
-
637
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
638
- precisions = tf.math.divide_no_nan(
639
- tp_count_cum_by_confidence,
640
- tp_count_cum_by_confidence + fp_count_cum_by_confidence,
641
- )
642
- # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
643
- recalls = tf.math.divide_no_nan(
644
- tp_count_cum_by_confidence, gt_count[..., tf.newaxis]
645
- )
646
-
647
- if confidence_thresholds:
648
- # If confidence_thresholds is set, reports precision and recall at each
649
- # confidence threshold.
650
- confidence_thresholds = tf.cast(
651
- tf.constant(confidence_thresholds, dtype=tf.float32)
652
- * num_confidence_bins,
653
- dtype=tf.int32,
654
- )
655
- # (num_confidence_thresholds, num_iou_thresholds, num_classes)
656
- result['precisions'] = tf.gather(
657
- tf.transpose(precisions, [2, 0, 1]), confidence_thresholds
658
- )
659
- result['recalls'] = tf.gather(
660
- tf.transpose(recalls, [2, 0, 1]), confidence_thresholds
661
- )
662
-
663
- precisions = tf.reverse(precisions, axis=[-1])
664
- recalls = tf.reverse(recalls, axis=[-1])
665
- result.update(
666
- {
667
- # (num_iou_thresholds, num_classes)
668
- key: ap_algorithm(precisions, recalls)
669
- for key, ap_algorithm in average_precision_algorithms.items()
670
- }
671
- )
672
- return result
673
-
674
-
675
468
  class InstanceMetrics(tf.keras.metrics.Metric):
676
469
  """Reports the metrics of instance detection & segmentation."""
677
470
 
@@ -780,22 +573,138 @@ class InstanceMetrics(tf.keras.metrics.Metric):
780
573
 
781
574
  def reset_state(self):
782
575
  """Resets all of the metric state variables."""
783
- for v in self.variables:
784
- tf.keras.backend.set_value(v, np.zeros(v.shape))
576
+ self.tp_count.assign(tf.zeros_like(self.tp_count))
577
+ self.fp_count.assign(tf.zeros_like(self.fp_count))
578
+ self.gt_count.assign(tf.zeros_like(self.gt_count))
785
579
 
786
580
  def update_state(
787
581
  self, y_true: Dict[str, tf.Tensor], y_pred: Dict[str, tf.Tensor]
788
582
  ):
583
+ # (batch_size, num_detections, 4) in absolute coordinates.
584
+ detection_boxes = tf.cast(y_pred['detection_boxes'], tf.float32)
585
+ # (batch_size, num_detections)
586
+ detection_classes = tf.cast(y_pred['detection_classes'], tf.int32)
587
+ # (batch_size, num_detections)
588
+ detection_scores = tf.cast(y_pred['detection_scores'], tf.float32)
589
+ # (batch_size, num_gts, 4) in absolute coordinates.
590
+ gt_boxes = tf.cast(y_true['boxes'], tf.float32)
591
+ # (batch_size, num_gts)
592
+ gt_classes = tf.cast(y_true['classes'], tf.int32)
593
+ # (batch_size, num_gts)
594
+ if 'is_crowds' in y_true:
595
+ gt_is_crowd = tf.cast(y_true['is_crowds'], tf.bool)
596
+ else:
597
+ gt_is_crowd = tf.zeros_like(gt_classes, dtype=tf.bool)
789
598
 
790
- tp_count, fp_count, gt_count = _compute_fp_tp_gt_count(
791
- y_true=y_true,
792
- y_pred=y_pred,
793
- num_classes=self._num_classes,
794
- mask_output_boundary=self._mask_output_boundary,
795
- iou_thresholds=self._iou_thresholds,
796
- matching_algorithm=self._matching_algorithm,
797
- num_confidence_bins=self._num_confidence_bins,
798
- use_masks=self._use_masks)
599
+ image_scale = tf.tile(y_true['image_info'][:, 2:3, :], multiples=[1, 1, 2])
600
+ detection_boxes = detection_boxes / tf.cast(
601
+ image_scale, dtype=detection_boxes.dtype
602
+ )
603
+
604
+ # Step 1: Computes IoUs between the detections and the non-crowd ground
605
+ # truths and IoAs between the detections and the crowd ground truths.
606
+ if not self._use_masks:
607
+ # (batch_size, num_detections, num_gts)
608
+ detection_to_gt_ious = box_ops.bbox_overlap(detection_boxes, gt_boxes)
609
+ detection_to_gt_ioas = box_ops.bbox_intersection_over_area(
610
+ detection_boxes, gt_boxes
611
+ )
612
+ else:
613
+ # Use outer boxes to generate the masks if available.
614
+ if 'detection_outer_boxes' in y_pred:
615
+ detection_boxes = tf.cast(y_pred['detection_outer_boxes'], tf.float32)
616
+
617
+ # (batch_size, num_detections, mask_height, mask_width)
618
+ detection_masks = tf.cast(y_pred['detection_masks'], tf.float32)
619
+ # (batch_size, num_gts, gt_mask_height, gt_mask_width)
620
+ gt_masks = tf.cast(y_true['masks'], tf.float32)
621
+
622
+ num_detections = detection_boxes.get_shape()[1]
623
+ # (batch_size, num_detections + num_gts, 4)
624
+ all_boxes = _shift_and_rescale_boxes(
625
+ tf.concat([detection_boxes, gt_boxes], axis=1),
626
+ self._mask_output_boundary,
627
+ )
628
+ detection_boxes = all_boxes[:, :num_detections, :]
629
+ gt_boxes = all_boxes[:, num_detections:, :]
630
+ # (batch_size, num_detections, num_gts)
631
+ detection_to_gt_ious, detection_to_gt_ioas = (
632
+ mask_ops.instance_masks_overlap(
633
+ detection_boxes,
634
+ detection_masks,
635
+ gt_boxes,
636
+ gt_masks,
637
+ output_size=self._mask_output_boundary,
638
+ )
639
+ )
640
+ # (batch_size, num_detections, num_gts)
641
+ detection_to_gt_ious = tf.where(
642
+ gt_is_crowd[:, tf.newaxis, :], 0.0, detection_to_gt_ious
643
+ )
644
+ detection_to_crowd_ioas = tf.where(
645
+ gt_is_crowd[:, tf.newaxis, :], detection_to_gt_ioas, 0.0
646
+ )
647
+
648
+ # Step 2: counts true positives grouped by IoU thresholds, classes and
649
+ # confidence bins.
650
+
651
+ # (batch_size, num_detections, num_iou_thresholds)
652
+ detection_is_tp, _ = self._matching_algorithm(
653
+ detection_to_gt_ious, detection_classes, detection_scores, gt_classes
654
+ )
655
+ # (batch_size * num_detections,)
656
+ flattened_binned_confidence = tf.reshape(
657
+ tf.cast(detection_scores * self._num_confidence_bins, tf.int32), [-1]
658
+ )
659
+ # (batch_size * num_detections, num_confidence_bins + 1)
660
+ flattened_binned_confidence_one_hot = tf.one_hot(
661
+ flattened_binned_confidence, self._num_confidence_bins + 1, axis=1
662
+ )
663
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
664
+ tp_count = _count_detection_type(
665
+ detection_is_tp,
666
+ detection_classes,
667
+ flattened_binned_confidence_one_hot,
668
+ self._num_classes,
669
+ )
670
+
671
+ # Step 3: Counts false positives grouped by IoU thresholds, classes and
672
+ # confidence bins.
673
+ # False positive: detection is not true positive (see above) and not part of
674
+ # the crowd ground truth with the same class.
675
+
676
+ # (batch_size, num_detections, num_gts, num_iou_thresholds)
677
+ detection_matches_crowd = (
678
+ (detection_to_crowd_ioas[..., tf.newaxis] > self._iou_thresholds)
679
+ & (
680
+ detection_classes[:, :, tf.newaxis, tf.newaxis]
681
+ == gt_classes[:, tf.newaxis, :, tf.newaxis]
682
+ )
683
+ & (detection_classes[:, :, tf.newaxis, tf.newaxis] > 0)
684
+ )
685
+ # (batch_size, num_detections, num_iou_thresholds)
686
+ detection_matches_any_crowd = tf.reduce_any(
687
+ detection_matches_crowd & ~detection_is_tp[:, :, tf.newaxis, :], axis=2
688
+ )
689
+ detection_is_fp = ~detection_is_tp & ~detection_matches_any_crowd
690
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
691
+ fp_count = _count_detection_type(
692
+ detection_is_fp,
693
+ detection_classes,
694
+ flattened_binned_confidence_one_hot,
695
+ self._num_classes,
696
+ )
697
+
698
+ # Step 4: Counts non-crowd groundtruths grouped by classes.
699
+ # (num_classes, )
700
+ gt_count = tf.reduce_sum(
701
+ tf.one_hot(
702
+ tf.where(gt_is_crowd, -1, gt_classes), self._num_classes, axis=-1
703
+ ),
704
+ axis=[0, 1],
705
+ )
706
+ # Clears the count of class 0 (background).
707
+ gt_count *= 1.0 - tf.eye(1, self._num_classes, dtype=gt_count.dtype)[0]
799
708
 
800
709
  # Accumulates the variables.
801
710
  self.fp_count.assign_add(tf.cast(fp_count, self.fp_count.dtype))
@@ -818,13 +727,55 @@ class InstanceMetrics(tf.keras.metrics.Metric):
818
727
  'valid_classes': a bool tensor in shape (num_classes,). If False, there
819
728
  is no instance of the class in the ground truth.
820
729
  """
821
- result = _compute_metrics(
822
- fp_count=self.fp_count,
823
- tp_count=self.tp_count,
824
- gt_count=self.gt_count,
825
- confidence_thresholds=self._confidence_thresholds,
826
- num_confidence_bins=self._num_confidence_bins,
827
- average_precision_algorithms=self._average_precision_algorithms)
730
+ result = {
731
+ # (num_classes,)
732
+ 'valid_classes': self.gt_count != 0,
733
+ }
734
+
735
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
736
+ tp_count_cum_by_confidence = tf.math.cumsum(
737
+ self.tp_count, axis=-1, reverse=True
738
+ )
739
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
740
+ fp_count_cum_by_confidence = tf.math.cumsum(
741
+ self.fp_count, axis=-1, reverse=True
742
+ )
743
+
744
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
745
+ precisions = tf.math.divide_no_nan(
746
+ tp_count_cum_by_confidence,
747
+ tp_count_cum_by_confidence + fp_count_cum_by_confidence,
748
+ )
749
+ # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
750
+ recalls = tf.math.divide_no_nan(
751
+ tp_count_cum_by_confidence, self.gt_count[..., tf.newaxis]
752
+ )
753
+
754
+ if self._confidence_thresholds:
755
+ # If confidence_thresholds is set, reports precision and recall at each
756
+ # confidence threshold.
757
+ confidence_thresholds = tf.cast(
758
+ tf.constant(self._confidence_thresholds, dtype=tf.float32)
759
+ * self._num_confidence_bins,
760
+ dtype=tf.int32,
761
+ )
762
+ # (num_confidence_thresholds, num_iou_thresholds, num_classes)
763
+ result['precisions'] = tf.gather(
764
+ tf.transpose(precisions, [2, 0, 1]), confidence_thresholds
765
+ )
766
+ result['recalls'] = tf.gather(
767
+ tf.transpose(recalls, [2, 0, 1]), confidence_thresholds
768
+ )
769
+
770
+ precisions = tf.reverse(precisions, axis=[-1])
771
+ recalls = tf.reverse(recalls, axis=[-1])
772
+ result.update(
773
+ {
774
+ # (num_iou_thresholds, num_classes)
775
+ key: ap_algorithm(precisions, recalls)
776
+ for key, ap_algorithm in self._average_precision_algorithms.items()
777
+ }
778
+ )
828
779
  return result
829
780
 
830
781
  def get_average_precision_metrics_keys(self):