tf-models-nightly 2.11.0.dev20230321__py2.py3-none-any.whl → 2.11.0.dev20230323__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/modeling/hyperparams/base_config.py +20 -2
- official/modeling/hyperparams/base_config_test.py +29 -0
- official/projects/yt8m/configs/yt8m.py +1 -4
- official/projects/yt8m/modeling/nn_layers.py +167 -26
- official/projects/yt8m/modeling/yt8m_model.py +44 -182
- official/projects/yt8m/modeling/yt8m_model_utils.py +5 -6
- official/projects/yt8m/tasks/yt8m_task.py +42 -25
- official/vision/dataloaders/maskrcnn_input.py +12 -13
- official/vision/dataloaders/tfds_classification_decoders.py +1 -0
- official/vision/evaluation/instance_metrics.py +176 -225
- official/vision/ops/augment.py +45 -33
- official/vision/ops/augment_test.py +9 -0
- official/vision/ops/preprocess_ops.py +20 -6
- official/vision/serving/export_tflite_lib.py +20 -8
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/RECORD +20 -20
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.11.0.dev20230321.dist-info → tf_models_nightly-2.11.0.dev20230323.dist-info}/top_level.txt +0 -0
@@ -45,15 +45,10 @@ class YT8MTask(base_task.Task):
     l2_weight_decay = self.task_config.losses.l2_weight_decay
     # Model configuration.
     model_config = self.task_config.model
-    norm_activation_config = model_config.norm_activation
     model = DbofModel(
         params=model_config,
         input_specs=input_specs,
         num_classes=train_cfg.num_classes,
-        activation=norm_activation_config.activation,
-        use_sync_bn=norm_activation_config.use_sync_bn,
-        norm_momentum=norm_activation_config.norm_momentum,
-        norm_epsilon=norm_activation_config.norm_epsilon,
         l2_weight_decay=l2_weight_decay)

     non_trainable_batch_norm_variables = []
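The hunk above drops the explicit `activation`/`use_sync_bn`/`norm_momentum`/`norm_epsilon` keyword arguments from the `DbofModel` call, leaving only `params=model_config`, so the model presumably pulls the `norm_activation` settings out of the config object itself. Below is a minimal, hypothetical sketch of that pattern; `NormActivation`, `MyModelConfig`, and `MyModel` are illustrative names, not part of the package:

```python
import dataclasses

import tensorflow as tf


@dataclasses.dataclass
class NormActivation:
  activation: str = 'relu'
  use_sync_bn: bool = True
  norm_momentum: float = 0.99
  norm_epsilon: float = 1e-3


@dataclasses.dataclass
class MyModelConfig:
  hidden_size: int = 128
  norm_activation: NormActivation = dataclasses.field(
      default_factory=NormActivation
  )


class MyModel(tf.keras.Model):
  """Reads normalization settings from the config instead of extra kwargs."""

  def __init__(self, params: MyModelConfig, num_classes: int, **kwargs):
    super().__init__(**kwargs)
    norm = params.norm_activation  # Single source of truth for norm settings.
    # Selecting a synchronized BatchNormalization variant based on
    # norm.use_sync_bn is omitted from this sketch.
    self._dense = tf.keras.layers.Dense(params.hidden_size)
    self._bn = tf.keras.layers.BatchNormalization(
        momentum=norm.norm_momentum, epsilon=norm.norm_epsilon
    )
    self._act = tf.keras.layers.Activation(norm.activation)
    self._head = tf.keras.layers.Dense(num_classes)

  def call(self, inputs, training=False):
    x = self._act(self._bn(self._dense(inputs), training=training))
    return self._head(x)


model = MyModel(params=MyModelConfig(), num_classes=10)
```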
@@ -66,18 +61,32 @@ class YT8MTask(base_task.Task):
         non_trainable_extra_variables.append(var)

     logging.info(
-        'Trainable model variables:\n%s',
-
+        'Trainable model variables:\n%s',
+        '\n'.join(
+            [f'{var.name}\t{var.shape}' for var in model.trainable_variables]
+        ),
+    )
     logging.info(
-
-
-
-
-
+        (
+            'Non-trainable batch norm variables (get updated in training'
+            ' mode):\n%s'
+        ),
+        '\n'.join(
+            [
+                f'{var.name}\t{var.shape}'
+                for var in non_trainable_batch_norm_variables
+            ]
+        ),
+    )
     logging.info(
-        'Non-trainable frozen model variables:\n%s',
-
-
+        'Non-trainable frozen model variables:\n%s',
+        '\n'.join(
+            [
+                f'{var.name}\t{var.shape}'
+                for var in non_trainable_extra_variables
+            ]
+        ),
+    )
     return model

   def build_inputs(self, params: yt8m_cfg.DataConfig, input_context=None):
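The logging change above only rewraps the calls to satisfy the formatter; the name-and-shape listing itself is unchanged. A tiny standalone illustration of that listing (the two-layer `tf.keras.Sequential` model is purely illustrative, not from the package):

```python
import logging

import tensorflow as tf

logging.basicConfig(level=logging.INFO)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,)),
    tf.keras.layers.Dense(2),
])
# Same name/shape listing that the task logs for its trainable variables.
logging.info(
    'Trainable model variables:\n%s',
    '\n'.join(
        [f'{var.name}\t{var.shape}' for var in model.trainable_variables]
    ),
)
```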
@@ -173,7 +182,10 @@ class YT8MTask(base_task.Task):
     for name in metric_names:
       metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))

-    if
+    if (
+        self.task_config.evaluation.average_precision is not None
+        and not training
+    ):
       # Cannot run in train step.
       num_classes = self.task_config.validation_data.num_classes
       top_k = self.task_config.evaluation.average_precision.top_k
@@ -183,14 +195,16 @@ class YT8MTask(base_task.Task):

     return metrics

-  def process_metrics(
-
-
-
-
-
-
-
+  def process_metrics(
+      self,
+      metrics: List[tf.keras.metrics.Metric],
+      labels: tf.Tensor,
+      outputs: tf.Tensor,
+      model_losses: Optional[Dict[str, tf.Tensor]] = None,
+      label_weights: Optional[tf.Tensor] = None,
+      training: bool = True,
+      **kwargs,
+  ) -> Dict[str, Tuple[tf.Tensor, ...]]:
     """Updates metrics.

     Args:
@@ -210,7 +224,10 @@ class YT8MTask(base_task.Task):
       model_losses = {}

     logs = {}
-    if
+    if (
+        self.task_config.evaluation.average_precision is not None
+        and not training
+    ):
       logs.update({self.avg_prec_metric.name: (labels, outputs)})

     for m in metrics:
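The three hunks above reformat the metric-related code in `YT8MTask`: `build_metrics` and `process_metrics` now wrap the `average_precision is not None and not training` check in parentheses, and the `process_metrics` signature is rewrapped one parameter per line. A hypothetical call site consistent with that signature is sketched below; `task`, `metrics`, `labels`, `outputs`, and `model_loss` stand for objects built elsewhere in the task and are not defined in this diff:

```python
# Hypothetical usage, mirroring the signature shown above. When average
# precision evaluation is configured and training=False, the returned dict
# maps the AP metric name to the (labels, outputs) tuple, per the last hunk.
logs = task.process_metrics(
    metrics=metrics,
    labels=labels,
    outputs=outputs,
    model_losses={'model_loss': model_loss},
    label_weights=None,
    training=False,
)
```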
@@ -211,19 +211,18 @@ class Parser(parser.Parser):
     image = preprocess_ops.normalize_image(image)

     # Flips image randomly during training.
-
-
-
-
-
-
-
-
-
-
-
-
-      image, boxes, _ = preprocess_ops.random_vertical_flip(image, boxes)
+    image, boxes, masks = preprocess_ops.random_horizontal_flip(
+        image,
+        boxes,
+        masks=None if not self._include_mask else masks,
+        prob=tf.where(self._aug_rand_hflip, 0.5, 0.0),
+    )
+    image, boxes, masks = preprocess_ops.random_vertical_flip(
+        image,
+        boxes,
+        masks=None if not self._include_mask else masks,
+        prob=tf.where(self._aug_rand_vflip, 0.5, 0.0),
+    )

     # Converts boxes from normalized coordinates to pixel coordinates.
     # Now the coordinates of boxes are w.r.t. the original image.
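In the parser hunk above, the flips are no longer guarded at the Python level (the old code appears to have called `random_vertical_flip` without a probability and to have dropped the masks); instead the ops are always invoked and the config flag is folded into the probability via `prob=tf.where(self._aug_rand_hflip, 0.5, 0.0)`, which makes the flip a no-op when the flag is off while keeping boxes and masks transformed together. A standalone sketch of the same idea, using hypothetical helpers (`flip_prob`, `random_hflip`) rather than the package's `preprocess_ops`:

```python
import tensorflow as tf


def flip_prob(enabled: bool) -> tf.Tensor:
  """Returns 0.5 when augmentation is enabled, else 0.0 (flip never fires)."""
  return tf.where(tf.constant(enabled), 0.5, 0.0)


def random_hflip(image, boxes, prob):
  """Flips an image and its normalized [ymin, xmin, ymax, xmax] boxes."""
  do_flip = tf.random.uniform([]) < prob
  flipped_image = tf.image.flip_left_right(image)
  flipped_boxes = tf.stack(
      [boxes[:, 0], 1.0 - boxes[:, 3], boxes[:, 2], 1.0 - boxes[:, 1]], axis=-1
  )
  image = tf.cond(do_flip, lambda: flipped_image, lambda: image)
  boxes = tf.cond(do_flip, lambda: flipped_boxes, lambda: boxes)
  return image, boxes


image = tf.zeros([32, 32, 3])
boxes = tf.constant([[0.1, 0.2, 0.6, 0.8]])
# With enabled=False the probability is 0.0 and the inputs pass through.
image, boxes = random_hflip(image, boxes, prob=flip_prob(True))
```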
@@ -465,213 +465,6 @@ def _count_detection_type(
   return count


-def _compute_fp_tp_gt_count(
-    y_true: Dict[str, tf.Tensor],
-    y_pred: Dict[str, tf.Tensor],
-    num_classes: int,
-    mask_output_boundary: Tuple[int, int] = (640, 640),
-    iou_thresholds: Tuple[float, ...] = (0.5,),
-    matching_algorithm: Optional[MatchingAlgorithm] = None,
-    num_confidence_bins: int = 1000,
-    use_masks: bool = False,
-) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-  """Computes the true and false positives."""
-
-  if matching_algorithm is None:
-    matching_algorithm = COCOMatchingAlgorithm(iou_thresholds)
-
-  # (batch_size, num_detections, 4) in absolute coordinates.
-  detection_boxes = tf.cast(y_pred['detection_boxes'], tf.float32)
-  # (batch_size, num_detections)
-  detection_classes = tf.cast(y_pred['detection_classes'], tf.int32)
-  # (batch_size, num_detections)
-  detection_scores = tf.cast(y_pred['detection_scores'], tf.float32)
-  # (batch_size, num_gts, 4) in absolute coordinates.
-  gt_boxes = tf.cast(y_true['boxes'], tf.float32)
-  # (batch_size, num_gts)
-  gt_classes = tf.cast(y_true['classes'], tf.int32)
-  # (batch_size, num_gts)
-  if 'is_crowds' in y_true:
-    gt_is_crowd = tf.cast(y_true['is_crowds'], tf.bool)
-  else:
-    gt_is_crowd = tf.zeros_like(gt_classes, dtype=tf.bool)
-
-  image_scale = tf.tile(y_true['image_info'][:, 2:3, :], multiples=[1, 1, 2])
-  detection_boxes = detection_boxes / tf.cast(
-      image_scale, dtype=detection_boxes.dtype
-  )
-
-  # Step 1: Computes IoUs between the detections and the non-crowd ground
-  # truths and IoAs between the detections and the crowd ground truths.
-  if not use_masks:
-    # (batch_size, num_detections, num_gts)
-    detection_to_gt_ious = box_ops.bbox_overlap(detection_boxes, gt_boxes)
-    detection_to_gt_ioas = box_ops.bbox_intersection_over_area(
-        detection_boxes, gt_boxes
-    )
-  else:
-    # (batch_size, num_detections, mask_height, mask_width)
-    detection_masks = tf.cast(y_pred['detection_masks'], tf.float32)
-    # (batch_size, num_gts, gt_mask_height, gt_mask_width)
-    gt_masks = tf.cast(y_true['masks'], tf.float32)
-
-    num_detections = detection_boxes.get_shape()[1]
-    # (batch_size, num_detections + num_gts, 4)
-    all_boxes = _shift_and_rescale_boxes(
-        tf.concat([detection_boxes, gt_boxes], axis=1),
-        mask_output_boundary,
-    )
-    detection_boxes = all_boxes[:, :num_detections, :]
-    gt_boxes = all_boxes[:, num_detections:, :]
-    # (batch_size, num_detections, num_gts)
-    detection_to_gt_ious, detection_to_gt_ioas = (
-        mask_ops.instance_masks_overlap(
-            detection_boxes,
-            detection_masks,
-            gt_boxes,
-            gt_masks,
-            output_size=mask_output_boundary,
-        )
-    )
-
-  # (batch_size, num_detections, num_gts)
-  detection_to_gt_ious = tf.where(
-      gt_is_crowd[:, tf.newaxis, :], 0.0, detection_to_gt_ious
-  )
-  detection_to_crowd_ioas = tf.where(
-      gt_is_crowd[:, tf.newaxis, :], detection_to_gt_ioas, 0.0
-  )
-
-  # Step 2: counts true positives grouped by IoU thresholds, classes and
-  # confidence bins.
-
-  # (batch_size, num_detections, num_iou_thresholds)
-  detection_is_tp, _ = matching_algorithm(
-      detection_to_gt_ious, detection_classes, detection_scores, gt_classes
-  )
-  # (batch_size * num_detections,)
-  flattened_binned_confidence = tf.reshape(
-      tf.cast(detection_scores * num_confidence_bins, tf.int32), [-1]
-  )
-  # (batch_size * num_detections, num_confidence_bins + 1)
-  flattened_binned_confidence_one_hot = tf.one_hot(
-      flattened_binned_confidence, num_confidence_bins + 1, axis=1
-  )
-  # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
-  tp_count = _count_detection_type(
-      detection_is_tp,
-      detection_classes,
-      flattened_binned_confidence_one_hot,
-      num_classes,
-  )
-
-  # Step 3: Counts false positives grouped by IoU thresholds, classes and
-  # confidence bins.
-  # False positive: detection is not true positive (see above) and not part of
-  # the crowd ground truth with the same class.
-
-  # (batch_size, num_detections, num_gts, num_iou_thresholds)
-  detection_matches_crowd = (
-      (detection_to_crowd_ioas[..., tf.newaxis] > iou_thresholds)
-      & (
-          detection_classes[:, :, tf.newaxis, tf.newaxis]
-          == gt_classes[:, tf.newaxis, :, tf.newaxis]
-      )
-      & (detection_classes[:, :, tf.newaxis, tf.newaxis] > 0)
-  )
-  # (batch_size, num_detections, num_iou_thresholds)
-  detection_matches_any_crowd = tf.reduce_any(
-      detection_matches_crowd & ~detection_is_tp[:, :, tf.newaxis, :], axis=2
-  )
-  detection_is_fp = ~detection_is_tp & ~detection_matches_any_crowd
-  # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
-  fp_count = _count_detection_type(
-      detection_is_fp,
-      detection_classes,
-      flattened_binned_confidence_one_hot,
-      num_classes,
-  )
-
-  # Step 4: Counts non-crowd groundtruths grouped by classes.
-  # (num_classes, )
-  gt_count = tf.reduce_sum(
-      tf.one_hot(
-          tf.where(gt_is_crowd, -1, gt_classes), num_classes, axis=-1
-      ),
-      axis=[0, 1],
-  )
-  # Clears the count of class 0 (background).
-  gt_count *= 1.0 - tf.eye(1, num_classes, dtype=gt_count.dtype)[0]
-
-  return tp_count, fp_count, gt_count
-
-
-def _compute_metrics(
-    tp_count: tf.Tensor,
-    fp_count: tf.Tensor,
-    gt_count: tf.Tensor,
-    confidence_thresholds: Tuple[float, ...] = (),
-    num_confidence_bins: int = 1000,
-    average_precision_algorithms: Optional[
-        Dict[str, AveragePrecision]] = None,
-) -> Dict[str, tf.Tensor]:
-  """Returns the metrics values as a dict."""
-
-  if average_precision_algorithms is None:
-    average_precision_algorithms = {'ap': COCOAveragePrecision()}
-
-  result = {
-      # (num_classes,)
-      'valid_classes': gt_count != 0,
-  }
-
-  # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
-  tp_count_cum_by_confidence = tf.math.cumsum(
-      tp_count, axis=-1, reverse=True
-  )
-  # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
-  fp_count_cum_by_confidence = tf.math.cumsum(
-      fp_count, axis=-1, reverse=True
-  )
-
-  # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
-  precisions = tf.math.divide_no_nan(
-      tp_count_cum_by_confidence,
-      tp_count_cum_by_confidence + fp_count_cum_by_confidence,
-  )
-  # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
-  recalls = tf.math.divide_no_nan(
-      tp_count_cum_by_confidence, gt_count[..., tf.newaxis]
-  )
-
-  if confidence_thresholds:
-    # If confidence_thresholds is set, reports precision and recall at each
-    # confidence threshold.
-    confidence_thresholds = tf.cast(
-        tf.constant(confidence_thresholds, dtype=tf.float32)
-        * num_confidence_bins,
-        dtype=tf.int32,
-    )
-    # (num_confidence_thresholds, num_iou_thresholds, num_classes)
-    result['precisions'] = tf.gather(
-        tf.transpose(precisions, [2, 0, 1]), confidence_thresholds
-    )
-    result['recalls'] = tf.gather(
-        tf.transpose(recalls, [2, 0, 1]), confidence_thresholds
-    )
-
-  precisions = tf.reverse(precisions, axis=[-1])
-  recalls = tf.reverse(recalls, axis=[-1])
-  result.update(
-      {
-          # (num_iou_thresholds, num_classes)
-          key: ap_algorithm(precisions, recalls)
-          for key, ap_algorithm in average_precision_algorithms.items()
-      }
-  )
-  return result
-
-
 class InstanceMetrics(tf.keras.metrics.Metric):
   """Reports the metrics of instance detection & segmentation."""

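This hunk deletes the module-level helpers `_compute_fp_tp_gt_count` and `_compute_metrics`; as the following hunks show, their bodies move into `InstanceMetrics.update_state` and `InstanceMetrics.result`, with the former function parameters replaced by `self._...` attributes. The surrounding structure is the standard Keras streaming-metric pattern: `update_state` accumulates counts into variables, `result` derives the final values, and `reset_state` zeroes the accumulators. A toy, self-contained example of that pattern (not taken from the package):

```python
import tensorflow as tf


class StreamingPrecision(tf.keras.metrics.Metric):
  """Toy metric with the same accumulate-then-compute structure."""

  def __init__(self, name='streaming_precision', **kwargs):
    super().__init__(name=name, **kwargs)
    self.tp = self.add_weight(name='tp', initializer='zeros')
    self.fp = self.add_weight(name='fp', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_pred = tf.cast(y_pred, tf.float32) > 0.5
    y_true = tf.cast(y_true, tf.bool)
    self.tp.assign_add(tf.reduce_sum(tf.cast(y_true & y_pred, tf.float32)))
    self.fp.assign_add(tf.reduce_sum(tf.cast(~y_true & y_pred, tf.float32)))

  def result(self):
    return tf.math.divide_no_nan(self.tp, self.tp + self.fp)

  def reset_state(self):
    self.tp.assign(0.0)
    self.fp.assign(0.0)


metric = StreamingPrecision()
metric.update_state([1.0, 0.0, 1.0], [0.9, 0.8, 0.2])
print(float(metric.result()))  # 0.5: one true positive, one false positive.
```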
@@ -780,22 +573,138 @@ class InstanceMetrics(tf.keras.metrics.Metric):

   def reset_state(self):
     """Resets all of the metric state variables."""
-
-
+    self.tp_count.assign(tf.zeros_like(self.tp_count))
+    self.fp_count.assign(tf.zeros_like(self.fp_count))
+    self.gt_count.assign(tf.zeros_like(self.gt_count))

   def update_state(
       self, y_true: Dict[str, tf.Tensor], y_pred: Dict[str, tf.Tensor]
   ):
+    # (batch_size, num_detections, 4) in absolute coordinates.
+    detection_boxes = tf.cast(y_pred['detection_boxes'], tf.float32)
+    # (batch_size, num_detections)
+    detection_classes = tf.cast(y_pred['detection_classes'], tf.int32)
+    # (batch_size, num_detections)
+    detection_scores = tf.cast(y_pred['detection_scores'], tf.float32)
+    # (batch_size, num_gts, 4) in absolute coordinates.
+    gt_boxes = tf.cast(y_true['boxes'], tf.float32)
+    # (batch_size, num_gts)
+    gt_classes = tf.cast(y_true['classes'], tf.int32)
+    # (batch_size, num_gts)
+    if 'is_crowds' in y_true:
+      gt_is_crowd = tf.cast(y_true['is_crowds'], tf.bool)
+    else:
+      gt_is_crowd = tf.zeros_like(gt_classes, dtype=tf.bool)

-
-
-
-
-
-
-
-
-
+    image_scale = tf.tile(y_true['image_info'][:, 2:3, :], multiples=[1, 1, 2])
+    detection_boxes = detection_boxes / tf.cast(
+        image_scale, dtype=detection_boxes.dtype
+    )
+
+    # Step 1: Computes IoUs between the detections and the non-crowd ground
+    # truths and IoAs between the detections and the crowd ground truths.
+    if not self._use_masks:
+      # (batch_size, num_detections, num_gts)
+      detection_to_gt_ious = box_ops.bbox_overlap(detection_boxes, gt_boxes)
+      detection_to_gt_ioas = box_ops.bbox_intersection_over_area(
+          detection_boxes, gt_boxes
+      )
+    else:
+      # Use outer boxes to generate the masks if available.
+      if 'detection_outer_boxes' in y_pred:
+        detection_boxes = tf.cast(y_pred['detection_outer_boxes'], tf.float32)
+
+      # (batch_size, num_detections, mask_height, mask_width)
+      detection_masks = tf.cast(y_pred['detection_masks'], tf.float32)
+      # (batch_size, num_gts, gt_mask_height, gt_mask_width)
+      gt_masks = tf.cast(y_true['masks'], tf.float32)
+
+      num_detections = detection_boxes.get_shape()[1]
+      # (batch_size, num_detections + num_gts, 4)
+      all_boxes = _shift_and_rescale_boxes(
+          tf.concat([detection_boxes, gt_boxes], axis=1),
+          self._mask_output_boundary,
+      )
+      detection_boxes = all_boxes[:, :num_detections, :]
+      gt_boxes = all_boxes[:, num_detections:, :]
+      # (batch_size, num_detections, num_gts)
+      detection_to_gt_ious, detection_to_gt_ioas = (
+          mask_ops.instance_masks_overlap(
+              detection_boxes,
+              detection_masks,
+              gt_boxes,
+              gt_masks,
+              output_size=self._mask_output_boundary,
+          )
+      )
+    # (batch_size, num_detections, num_gts)
+    detection_to_gt_ious = tf.where(
+        gt_is_crowd[:, tf.newaxis, :], 0.0, detection_to_gt_ious
+    )
+    detection_to_crowd_ioas = tf.where(
+        gt_is_crowd[:, tf.newaxis, :], detection_to_gt_ioas, 0.0
+    )
+
+    # Step 2: counts true positives grouped by IoU thresholds, classes and
+    # confidence bins.
+
+    # (batch_size, num_detections, num_iou_thresholds)
+    detection_is_tp, _ = self._matching_algorithm(
+        detection_to_gt_ious, detection_classes, detection_scores, gt_classes
+    )
+    # (batch_size * num_detections,)
+    flattened_binned_confidence = tf.reshape(
+        tf.cast(detection_scores * self._num_confidence_bins, tf.int32), [-1]
+    )
+    # (batch_size * num_detections, num_confidence_bins + 1)
+    flattened_binned_confidence_one_hot = tf.one_hot(
+        flattened_binned_confidence, self._num_confidence_bins + 1, axis=1
+    )
+    # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
+    tp_count = _count_detection_type(
+        detection_is_tp,
+        detection_classes,
+        flattened_binned_confidence_one_hot,
+        self._num_classes,
+    )
+
+    # Step 3: Counts false positives grouped by IoU thresholds, classes and
+    # confidence bins.
+    # False positive: detection is not true positive (see above) and not part of
+    # the crowd ground truth with the same class.
+
+    # (batch_size, num_detections, num_gts, num_iou_thresholds)
+    detection_matches_crowd = (
+        (detection_to_crowd_ioas[..., tf.newaxis] > self._iou_thresholds)
+        & (
+            detection_classes[:, :, tf.newaxis, tf.newaxis]
+            == gt_classes[:, tf.newaxis, :, tf.newaxis]
+        )
+        & (detection_classes[:, :, tf.newaxis, tf.newaxis] > 0)
+    )
+    # (batch_size, num_detections, num_iou_thresholds)
+    detection_matches_any_crowd = tf.reduce_any(
+        detection_matches_crowd & ~detection_is_tp[:, :, tf.newaxis, :], axis=2
+    )
+    detection_is_fp = ~detection_is_tp & ~detection_matches_any_crowd
+    # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
+    fp_count = _count_detection_type(
+        detection_is_fp,
+        detection_classes,
+        flattened_binned_confidence_one_hot,
+        self._num_classes,
+    )
+
+    # Step 4: Counts non-crowd groundtruths grouped by classes.
+    # (num_classes, )
+    gt_count = tf.reduce_sum(
+        tf.one_hot(
+            tf.where(gt_is_crowd, -1, gt_classes), self._num_classes, axis=-1
+        ),
+        axis=[0, 1],
+    )
+    # Clears the count of class 0 (background).
+    gt_count *= 1.0 - tf.eye(1, self._num_classes, dtype=gt_count.dtype)[0]

     # Accumulates the variables.
     self.fp_count.assign_add(tf.cast(fp_count, self.fp_count.dtype))
@@ -818,13 +727,55 @@ class InstanceMetrics(tf.keras.metrics.Metric):
         'valid_classes': a bool tensor in shape (num_classes,). If False, there
           is no instance of the class in the ground truth.
     """
-    result =
-
-
-
-
-
-
+    result = {
+        # (num_classes,)
+        'valid_classes': self.gt_count != 0,
+    }
+
+    # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
+    tp_count_cum_by_confidence = tf.math.cumsum(
+        self.tp_count, axis=-1, reverse=True
+    )
+    # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
+    fp_count_cum_by_confidence = tf.math.cumsum(
+        self.fp_count, axis=-1, reverse=True
+    )
+
+    # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
+    precisions = tf.math.divide_no_nan(
+        tp_count_cum_by_confidence,
+        tp_count_cum_by_confidence + fp_count_cum_by_confidence,
+    )
+    # (num_iou_thresholds, num_classes, num_confidence_bins + 1)
+    recalls = tf.math.divide_no_nan(
+        tp_count_cum_by_confidence, self.gt_count[..., tf.newaxis]
+    )
+
+    if self._confidence_thresholds:
+      # If confidence_thresholds is set, reports precision and recall at each
+      # confidence threshold.
+      confidence_thresholds = tf.cast(
+          tf.constant(self._confidence_thresholds, dtype=tf.float32)
+          * self._num_confidence_bins,
+          dtype=tf.int32,
+      )
+      # (num_confidence_thresholds, num_iou_thresholds, num_classes)
+      result['precisions'] = tf.gather(
+          tf.transpose(precisions, [2, 0, 1]), confidence_thresholds
+      )
+      result['recalls'] = tf.gather(
+          tf.transpose(recalls, [2, 0, 1]), confidence_thresholds
+      )
+
+    precisions = tf.reverse(precisions, axis=[-1])
+    recalls = tf.reverse(recalls, axis=[-1])
+    result.update(
+        {
+            # (num_iou_thresholds, num_classes)
+            key: ap_algorithm(precisions, recalls)
+            for key, ap_algorithm in self._average_precision_algorithms.items()
+        }
+    )
     return result

   def get_average_precision_metrics_keys(self):