tf-models-nightly 2.18.0.dev20240820__py2.py3-none-any.whl → 2.18.0.dev20240821__py2.py3-none-any.whl

@@ -0,0 +1,641 @@
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Panoptic MaskRCNN task definition."""
+ from typing import Any, Dict, List, Mapping, Optional, Tuple
+
+ import tensorflow as tf, tf_keras
+
+ from official.common import dataset_fn
+ from official.core import task_factory
+ from official.projects.centernet.ops import loss_ops
+ from official.projects.maskconver.configs import maskconver as exp_cfg
+ from official.projects.maskconver.dataloaders import maskconver_segmentation_input
+ from official.projects.maskconver.dataloaders import panoptic_maskrcnn_input
+ from official.projects.maskconver.losses import maskconver_losses
+ from official.projects.maskconver.modeling import factory
+ from official.projects.volumetric_models.losses import segmentation_losses as volumetric_segmentation_losses
+ from official.vision.dataloaders import input_reader_factory
+ from official.vision.dataloaders import segmentation_input
+ from official.vision.evaluation import panoptic_quality_evaluator
+ from official.vision.evaluation import segmentation_metrics
+ from official.vision.tasks import maskrcnn
+ from official.vision.tasks import semantic_segmentation
+
+
+ @task_factory.register_task_cls(exp_cfg.MaskConverTask)
+ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
+   """A single-replica view of training procedure.
+
+   Panoptic Mask R-CNN task provides artifacts for training/evaluation
+   procedures, including loading/iterating over Datasets, initializing the
+   model, calculating the loss, post-processing, and customized metrics with
+   reduction.
+   """
+
+   def build_model(self) -> tf_keras.Model:
+     """Build Panoptic Mask R-CNN model."""
+
+     tf_keras.utils.set_random_seed(0)
+     tf.config.experimental.enable_op_determinism()
+     input_specs = tf_keras.layers.InputSpec(
+         shape=[None] + self.task_config.model.input_size)
+
+     l2_weight_decay = self.task_config.losses.l2_weight_decay
+     # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
+     # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
+     # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
+     l2_regularizer = (tf_keras.regularizers.l2(
+         l2_weight_decay / 2.0) if l2_weight_decay else None)
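+     # Why the factor of 2: tf_keras.regularizers.l2(l) penalizes
+     # l * sum(w ** 2), whereas tf.nn.l2_loss(w) computes sum(w ** 2) / 2,
+     # so passing l2_weight_decay / 2.0 makes the Keras regularizer equal to
+     # l2_weight_decay * tf.nn.l2_loss(w).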
+
+     model = factory.build_maskconver_model(
+         input_specs=input_specs,
+         model_config=self.task_config.model,
+         l2_regularizer=l2_regularizer)
+     return model
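+   # A minimal usage sketch (hypothetical driver code; `task_config` is
+   # assumed to be a parsed exp_cfg.MaskConverTask config):
+   #
+   #   task = PanopticMaskRCNNTask(task_config)
+   #   model = task.build_model()
+   #   train_ds = task.build_inputs(task_config.train_data)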
+
+   def build_inputs(
+       self,
+       params: exp_cfg.DataConfig,
+       input_context: Optional[tf.distribute.InputContext] = None
+   ) -> tf.data.Dataset:
+     """Build input dataset."""
+     decoder_cfg = params.decoder.get()
+     if params.decoder.type == 'simple_decoder':
+       decoder = panoptic_maskrcnn_input.TfExampleDecoder(
+           regenerate_source_id=decoder_cfg.regenerate_source_id,
+           mask_binarize_threshold=decoder_cfg.mask_binarize_threshold,
+           include_panoptic_masks=decoder_cfg.include_panoptic_masks,
+           panoptic_category_mask_key=decoder_cfg.panoptic_category_mask_key,
+           panoptic_instance_mask_key=decoder_cfg.panoptic_instance_mask_key)
+     else:
+       raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
+
+     parser = panoptic_maskrcnn_input.Parser(
+         output_size=self.task_config.model.input_size[:2],
+         min_level=self.task_config.model.min_level,
+         max_level=self.task_config.model.max_level,
+         num_scales=self.task_config.model.anchor.num_scales,
+         aspect_ratios=self.task_config.model.anchor.aspect_ratios,
+         anchor_size=self.task_config.model.anchor.anchor_size,
+         dtype=params.dtype,
+         rpn_match_threshold=params.parser.rpn_match_threshold,
+         rpn_unmatched_threshold=params.parser.rpn_unmatched_threshold,
+         rpn_batch_size_per_im=params.parser.rpn_batch_size_per_im,
+         rpn_fg_fraction=params.parser.rpn_fg_fraction,
+         aug_rand_hflip=params.parser.aug_rand_hflip,
+         aug_scale_min=params.parser.aug_scale_min,
+         aug_scale_max=params.parser.aug_scale_max,
+         skip_crowd_during_training=params.parser.skip_crowd_during_training,
+         max_num_instances=self.task_config.model.num_instances,
+         mask_crop_size=params.parser.mask_crop_size,
+         segmentation_resize_eval_groundtruth=params.parser
+         .segmentation_resize_eval_groundtruth,
+         segmentation_groundtruth_padded_size=params.parser
+         .segmentation_groundtruth_padded_size,
+         segmentation_ignore_label=params.parser.segmentation_ignore_label,
+         panoptic_ignore_label=params.parser.panoptic_ignore_label,
+         include_panoptic_masks=params.parser.include_panoptic_masks,
+         num_panoptic_categories=self.task_config.model.num_classes,
+         num_thing_categories=self.task_config.model.num_thing_classes,
+         level=self.task_config.model.level,
+         # Note: the config field name carries this spelling.
+         gaussian_iou=params.parser.gaussaian_iou,
+         aug_type=params.parser.aug_type,
+         max_num_stuff_centers=params.parser.max_num_stuff_centers)
+
+     reader = input_reader_factory.input_reader_generator(
+         params,
+         dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
+         decoder_fn=decoder.decode,
+         parser_fn=parser.parse_fn(params.is_training))
+     dataset = reader.read(input_context=input_context)
+
+     return dataset
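+   # Sketch (hypothetical, assuming a configured tf.distribute strategy); the
+   # strategy passes an InputContext so the reader can shard per replica:
+   #
+   #   dataset = strategy.distribute_datasets_from_function(
+   #       lambda ctx: task.build_inputs(task_config.train_data, ctx))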
+
+   def build_losses(self,
+                    outputs: Mapping[str, Any],
+                    labels: Mapping[str, Any],
+                    aux_losses: Optional[Any] = None,
+                    step=None) -> Dict[str, tf.Tensor]:
+     """Build Panoptic Mask R-CNN losses."""
+     loss_params = self._task_config.losses
+
+     batch_size = tf.cast(tf.shape(labels['num_instances'])[0], tf.float32)
+     center_loss_fn = maskconver_losses.PenaltyReducedLogisticFocalLoss(
+         alpha=loss_params.alpha, beta=loss_params.beta)
+
+     # Calculate center heatmap loss.
+     # TODO(arashwan): add valid weights.
+     # output_unpad_image_shapes = labels['image_info'][:, 0, :]
+     # valid_anchor_weights = loss_ops.get_valid_anchor_weights_in_flattened_image(  # pylint: disable=line-too-long
+     #     output_unpad_image_shapes, h, w)
+     # valid_anchor_weights = tf.expand_dims(valid_anchor_weights, 2)
+
+     true_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
+         labels['panoptic_heatmaps'])
+     true_flattened_ct_heatmap = tf.cast(true_flattened_ct_heatmap, tf.float32)
+
+     pred_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
+         outputs['class_heatmaps'])
+     pred_flattened_ct_heatmap = tf.cast(pred_flattened_ct_heatmap, tf.float32)
+     center_padding_mask = 1 - labels['panoptic_padding_mask'][:, :, :, None]
+     center_padding_mask = tf.image.resize(
+         center_padding_mask,
+         tf.shape(labels['panoptic_heatmaps'])[1:3],
+         method='nearest')
+     center_padding_mask = tf.maximum(center_padding_mask, 0.0)
+     center_padding_mask = center_padding_mask * tf.ones_like(
+         labels['panoptic_heatmaps'])
+     weights_flattened_mask = loss_ops.flatten_spatial_dimensions(
+         center_padding_mask)
+     center_loss = center_loss_fn(
+         target_tensor=true_flattened_ct_heatmap,
+         prediction_tensor=pred_flattened_ct_heatmap,
+         weights=weights_flattened_mask)
+
+     center_loss = tf.reduce_sum(
+         center_loss /
+         (labels['num_instances'][:, None, None] + 1.0)) / batch_size
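+     # PenaltyReducedLogisticFocalLoss is the CenterNet-style focal loss: for
+     # groundtruth heat y and predicted probability p,
+     #   y == 1:  -(1 - p)**alpha * log(p)
+     #   y <  1:  -(1 - y)**beta * p**alpha * log(1 - p)
+     # so the penalty is reduced near centers. Normalizing by
+     # (num_instances + 1) keeps the loss finite for images with no instances.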
+
+     gt_masks = labels['panoptic_masks']
+     gt_mask_weights = labels['panoptic_mask_weights'][
+         :, None, None, :] * tf.ones_like(gt_masks)
+     panoptic_padding_mask = labels['panoptic_padding_mask'][
+         :, :, :, None] * tf.ones_like(gt_masks)
+
+     true_flattened_masks = loss_ops.flatten_spatial_dimensions(gt_masks)
+     predicted_masks = tf.cast(outputs['mask_proposal_logits'], tf.float32)
+     predicted_masks = tf.image.resize(
+         predicted_masks, tf.shape(gt_masks)[1:3], method='bilinear')
+     pred_flattened_masks = loss_ops.flatten_spatial_dimensions(predicted_masks)
+     mask_loss_fn = tf_keras.losses.BinaryCrossentropy(
+         from_logits=True,
+         label_smoothing=0.0,
+         axis=-1,
+         reduction=tf_keras.losses.Reduction.NONE,
+         name='binary_crossentropy')
+     mask_weights = tf.reshape(
+         tf.cast(true_flattened_masks >= 0, tf.float32),
+         [-1, 1]) * tf.reshape(gt_mask_weights, [-1, 1]) * tf.reshape(
+             1 - panoptic_padding_mask, [-1, 1])
+     mask_loss = mask_loss_fn(
+         tf.reshape(gt_masks, [-1, 1]),
+         tf.reshape(pred_flattened_masks, [-1, 1]),
+         sample_weight=mask_weights)
+     mask_loss = tf.reduce_sum(mask_loss) / (tf.reduce_sum(mask_weights) + 1.0)
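+     # mask_weights zeroes out three kinds of pixels in the per-pixel BCE:
+     # invalid groundtruth (mask < 0), mask slots with zero instance weight,
+     # and padded image regions. The +1.0 in the denominator guards against
+     # division by zero when no pixel is valid.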
+
+     # Dice loss.
+     _, h, w, _ = gt_masks.get_shape().as_list()
+     masked_predictions = tf.sigmoid(predicted_masks) * gt_mask_weights * (
+         1 - panoptic_padding_mask)
+     masked_gt_masks = gt_masks * gt_mask_weights * (1 - panoptic_padding_mask)
+
+     # Folds the per-instance mask channels into the batch dimension so the
+     # dice score is computed independently per instance mask.
+     masked_predictions = tf.transpose(masked_predictions, [0, 3, 1, 2])
+     masked_predictions = tf.reshape(masked_predictions, [-1, h, w, 1])
+     masked_gt_masks = tf.transpose(masked_gt_masks, [0, 3, 1, 2])
+     masked_gt_masks = tf.reshape(masked_gt_masks, [-1, h, w, 1])
+
+     dice_loss_fn = volumetric_segmentation_losses.SegmentationLossDiceScore(
+         metric_type='adaptive', axis=(2, 3))
+     dice_loss = dice_loss_fn(logits=masked_predictions, labels=masked_gt_masks)
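+     # For a (prediction p, groundtruth g) pair, the dice score is
+     #   2 * sum(p * g) / (sum(p) + sum(g))
+     # and the dice loss is 1 - score; the 'adaptive' variant reweights each
+     # sample's contribution (see SegmentationLossDiceScore for the details).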
+
+     total_loss = (
+         center_loss + loss_params.mask_weight * mask_loss +
+         loss_params.mask_weight * dice_loss)
+
+     if aux_losses:
+       total_loss += tf.add_n(aux_losses)
+
+     total_loss = loss_params.loss_weight * total_loss
+
+     losses = {'total_loss': total_loss,
+               'mask_loss': mask_loss,
+               'center_loss': center_loss,
+               'dice_loss': dice_loss}
+     return losses
+
+   def build_metrics(self, training: bool = True) -> List[
+       tf_keras.metrics.Metric]:
+     """Build detection metrics."""
+     metrics = []
+     if training:
+       metric_names = [
+           'total_loss',
+           'center_loss',
+           'mask_loss',
+           'dice_loss',
+       ]
+       for name in metric_names:
+         metrics.append(tf_keras.metrics.Mean(name, dtype=tf.float32))
+     else:
+       pq_config = self.task_config.panoptic_quality_evaluator
+       self.panoptic_quality_metric = (
+           panoptic_quality_evaluator.PanopticQualityEvaluator(
+               num_categories=pq_config.num_categories,
+               ignored_label=pq_config.ignored_label,
+               max_instances_per_category=pq_config.max_instances_per_category,
+               offset=pq_config.offset,
+               is_thing=pq_config.is_thing,
+               rescale_predictions=pq_config.rescale_predictions))
+     return metrics
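+   # Note: the panoptic quality evaluator is deliberately not appended to the
+   # returned `metrics` list; it accumulates over the full evaluation set via
+   # `aggregate_logs`/`reduce_aggregated_logs` below instead of acting as a
+   # per-step Keras metric.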
+
+   def train_step(self,
+                  inputs: Tuple[Any, Any],
+                  model: tf_keras.Model,
+                  optimizer: tf_keras.optimizers.Optimizer,
+                  metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
+     """Does forward and backward.
+
+     Args:
+       inputs: a tuple of (images, labels) input tensors.
+       model: the model, forward pass definition.
+       optimizer: the optimizer for this training step.
+       metrics: a nested structure of metrics objects.
+
+     Returns:
+       A dictionary of logs.
+     """
+     images, labels = inputs
+     num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+
+     with tf.GradientTape() as tape:
+       outputs = model(
+           images,
+           box_indices=labels['panoptic_box_indices'],
+           classes=labels['panoptic_classes'],
+           training=True)
+       outputs = tf.nest.map_structure(
+           lambda x: tf.cast(x, tf.float32), outputs)
+
+       # Computes per-replica loss.
+       losses = self.build_losses(
+           outputs=outputs,
+           labels=labels,
+           aux_losses=model.losses,
+           step=optimizer.iterations)
+       scaled_loss = losses['total_loss'] / num_replicas
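+       # Dividing by num_replicas yields the cross-replica mean: under
+       # tf.distribute, per-replica gradients are summed when applied, and
+       # summing num_replicas copies of loss / num_replicas recovers the mean.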
+
+       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
+       # scaled for numerical stability.
+       if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
+         scaled_loss = optimizer.get_scaled_loss(scaled_loss)
+
+     tvars = model.trainable_variables
+     grads = tape.gradient(scaled_loss, tvars)
+     # Scales back gradient when LossScaleOptimizer is used.
+     if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
+       grads = optimizer.get_unscaled_gradients(grads)
+     optimizer.apply_gradients(list(zip(grads, tvars)))
+
+     logs = {self.loss: losses['total_loss']}
+
+     if metrics:
+       for m in metrics:
+         m.update_state(losses[m.name])
+
+     return logs
+
+   def validation_step(self,
+                       inputs: Tuple[Any, Any],
+                       model: tf_keras.Model,
+                       metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
+     """Validation step.
+
+     Args:
+       inputs: a tuple of (images, labels) input tensors.
+       model: the keras.Model.
+       metrics: a nested structure of metrics objects.
+
+     Returns:
+       A dictionary of logs.
+     """
+     images, labels = inputs
+
+     outputs = model(
+         images,
+         image_info=labels['image_info'],
+         training=False)
+
+     logs = {self.loss: 0}
+
+     pq_metric_labels = {
+         'category_mask':
+             labels['groundtruths']['gt_panoptic_category_mask'],
+         'instance_mask':
+             labels['groundtruths']['gt_panoptic_instance_mask'],
+         'image_info': labels['image_info'],
+     }
+     logs.update({
+         self.panoptic_quality_metric.name:
+             (pq_metric_labels, outputs['panoptic_outputs'])})
+     return logs
+
+   def aggregate_logs(self, state=None, step_outputs=None):
+     if state is None:
+       self.panoptic_quality_metric.reset_states()
+       state = [self.panoptic_quality_metric]
+
+     self.panoptic_quality_metric.update_state(
+         step_outputs[self.panoptic_quality_metric.name][0],
+         step_outputs[self.panoptic_quality_metric.name][1])
+
+     return state
+
+   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
+     result = {}
+
+     report_per_class_metrics = (
+         self.task_config.panoptic_quality_evaluator.report_per_class_metrics)
+     panoptic_quality_results = self.panoptic_quality_metric.result()
+     for k, value in panoptic_quality_results.items():
+       if k.endswith('per_class'):
+         if report_per_class_metrics:
+           for i, per_class_value in enumerate(value):
+             metric_key = 'panoptic_quality/{}/class_{}'.format(k, i)
+             result[metric_key] = per_class_value
+         else:
+           continue
+       else:
+         result['panoptic_quality/{}'.format(k)] = value
+     return result
+
+
+ @task_factory.register_task_cls(exp_cfg.MaskConverSegTask)
+ class MaskConverSegmentation(semantic_segmentation.SemanticSegmentationTask):
+   """A single-replica view of training procedure.
+
+   MaskConver task provides artifacts for training/evaluation procedures,
+   including loading/iterating over Datasets, initializing the model,
+   calculating the loss, post-processing, and customized metrics with
+   reduction.
+   """
+
+   def build_model(self) -> tf_keras.Model:
+     """Build maskconver model."""
+
+     tf_keras.utils.set_random_seed(0)
+     tf.config.experimental.enable_op_determinism()
+     input_specs = tf_keras.layers.InputSpec(
+         shape=[None] + self.task_config.model.input_size)
+
+     l2_weight_decay = self.task_config.losses.l2_weight_decay
+     # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
+     # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
+     # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
+     l2_regularizer = (tf_keras.regularizers.l2(
+         l2_weight_decay / 2.0) if l2_weight_decay else None)
+
+     model = factory.build_maskconver_model(
+         input_specs=input_specs,
+         model_config=self.task_config.model,
+         l2_regularizer=l2_regularizer,
+         segmentation_inference=True)
+     return model
+
+   def build_inputs(self,
+                    params: exp_cfg.DataConfig,
+                    input_context: Optional[tf.distribute.InputContext] = None):
+     """Builds the segmentation input."""
+
+     ignore_label = self.task_config.losses.ignore_label
+
+     decoder = segmentation_input.Decoder()
+
+     parser = maskconver_segmentation_input.Parser(
+         output_size=params.output_size,
+         num_classes=self.task_config.model.num_classes,
+         crop_size=params.crop_size,
+         ignore_label=ignore_label,
+         resize_eval_groundtruth=params.resize_eval_groundtruth,
+         groundtruth_padded_size=params.groundtruth_padded_size,
+         aug_scale_min=params.aug_scale_min,
+         aug_scale_max=params.aug_scale_max,
+         aug_rand_hflip=params.aug_rand_hflip,
+         preserve_aspect_ratio=params.preserve_aspect_ratio,
+         level=self.task_config.model.level,
+         aug_type=params.aug_type,
+         max_num_stuff_centers=params.max_num_stuff_centers,
+         dtype=params.dtype)
+
+     reader = input_reader_factory.input_reader_generator(
+         params,
+         dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
+         decoder_fn=decoder.decode,
+         parser_fn=parser.parse_fn(params.is_training))
+
+     dataset = reader.read(input_context=input_context)
+
+     return dataset
+
+   def build_losses(self,
+                    outputs: Mapping[str, Any],
+                    labels: Mapping[str, Any],
+                    aux_losses: Optional[Any] = None,
+                    step=None) -> Dict[str, tf.Tensor]:
+     """Builds MaskConver segmentation losses."""
+     loss_params = self._task_config.losses
+
+     # b, h, w, c = outputs['class_heatmaps'].get_shape().as_list()
+     batch_size = tf.cast(tf.shape(labels['num_instances'])[0], tf.float32)
+     center_loss_fn = maskconver_losses.PenaltyReducedLogisticFocalLoss(
+         alpha=loss_params.alpha, beta=loss_params.beta)
+
+     true_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
+         labels['seg_ct_heatmaps'])
+     true_flattened_ct_heatmap = tf.cast(true_flattened_ct_heatmap, tf.float32)
+
+     pred_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
+         outputs['class_heatmaps'])
+     pred_flattened_ct_heatmap = tf.cast(pred_flattened_ct_heatmap, tf.float32)
+     center_valid_mask = labels['seg_valid_mask'][:, :, :, None]
+     center_valid_mask = tf.image.resize(
+         center_valid_mask,
+         tf.shape(labels['seg_ct_heatmaps'])[1:3],
+         method='nearest')
+     center_valid_mask = tf.maximum(center_valid_mask, 0.0)
+     center_valid_mask = center_valid_mask * tf.ones_like(
+         labels['seg_ct_heatmaps'])
+     weights_flattened_mask = loss_ops.flatten_spatial_dimensions(
+         center_valid_mask)
+     center_loss = center_loss_fn(
+         target_tensor=true_flattened_ct_heatmap,
+         prediction_tensor=pred_flattened_ct_heatmap,
+         weights=weights_flattened_mask)
+
+     center_loss = tf.reduce_sum(
+         center_loss /
+         (labels['num_instances'][:, None, None] + 1.0)) / batch_size
+
+     gt_masks = labels['seg_masks']
+     gt_mask_weights = labels['seg_mask_weights'][
+         :, None, None, :] * tf.ones_like(gt_masks)
+     valid_mask = labels['seg_valid_mask'][:, :, :, None] * tf.ones_like(
+         gt_masks)
+
+     true_flattened_masks = loss_ops.flatten_spatial_dimensions(gt_masks)
+     predicted_masks = tf.cast(outputs['mask_proposal_logits'], tf.float32)
+     predicted_masks = tf.image.resize(
+         predicted_masks, tf.shape(gt_masks)[1:3], method='bilinear')
+     pred_flattened_masks = loss_ops.flatten_spatial_dimensions(predicted_masks)
+
+     mask_loss_fn = tf_keras.losses.BinaryCrossentropy(
+         from_logits=True,
+         label_smoothing=0.0,
+         axis=-1,
+         reduction=tf_keras.losses.Reduction.NONE,
+         name='binary_crossentropy')
+     mask_weights = tf.reshape(
+         tf.cast(true_flattened_masks >= 0, tf.float32),
+         [-1, 1]) * tf.reshape(gt_mask_weights, [-1, 1]) * tf.reshape(
+             valid_mask, [-1, 1])
+     mask_loss = mask_loss_fn(
+         tf.reshape(gt_masks, [-1, 1]),
+         tf.reshape(pred_flattened_masks, [-1, 1]),
+         sample_weight=mask_weights)
+
+     mask_loss = tf.reduce_sum(mask_loss) / (tf.reduce_sum(mask_weights) + 1.0)
+
+     total_loss = center_loss + loss_params.mask_weight * mask_loss
+
+     if aux_losses:
+       total_loss += tf.add_n(aux_losses)
+
+     total_loss = loss_params.loss_weight * total_loss
+
+     losses = {'total_loss': total_loss,
+               'mask_loss': mask_loss,
+               'center_loss': center_loss}
+     return losses
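+   # This mirrors PanopticMaskRCNNTask.build_losses above, with the panoptic
+   # targets replaced by their semantic-segmentation counterparts
+   # ('seg_ct_heatmaps', 'seg_masks', 'seg_valid_mask', ...) and without the
+   # dice loss term.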
+
+   def build_metrics(self, training: bool = True) -> List[
+       tf_keras.metrics.Metric]:
+     """Build segmentation metrics."""
+     metrics = []
+     if training:
+       metric_names = [
+           'total_loss',
+           'center_loss',
+           'mask_loss',
+       ]
+       for name in metric_names:
+         metrics.append(tf_keras.metrics.Mean(name, dtype=tf.float32))
+     else:
+       self.iou_metric = segmentation_metrics.PerClassIoU(
+           name='per_class_iou',
+           num_classes=self.task_config.model.num_classes,
+           rescale_predictions=False,
+           dtype=tf.float32)
+
+     return metrics
+
+   def train_step(self,
+                  inputs: Tuple[Any, Any],
+                  model: tf_keras.Model,
+                  optimizer: tf_keras.optimizers.Optimizer,
+                  metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
+     """Does forward and backward.
+
+     Args:
+       inputs: a tuple of (images, labels) input tensors.
+       model: the model, forward pass definition.
+       optimizer: the optimizer for this training step.
+       metrics: a nested structure of metrics objects.
+
+     Returns:
+       A dictionary of logs.
+     """
+     images, labels = inputs
+     num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+
+     with tf.GradientTape() as tape:
+       outputs = model(
+           images,
+           box_indices=labels['seg_box_indices'],
+           classes=labels['seg_classes'],
+           training=True)
+       outputs = tf.nest.map_structure(
+           lambda x: tf.cast(x, tf.float32), outputs)
+
+       # Computes per-replica loss.
+       losses = self.build_losses(
+           outputs=outputs,
+           labels=labels,
+           aux_losses=model.losses,
+           step=optimizer.iterations)
+       scaled_loss = losses['total_loss'] / num_replicas
+
+       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
+       # scaled for numerical stability.
+       if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
+         scaled_loss = optimizer.get_scaled_loss(scaled_loss)
+
+     tvars = model.trainable_variables
+     grads = tape.gradient(scaled_loss, tvars)
+     # Scales back gradient when LossScaleOptimizer is used.
+     if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
+       grads = optimizer.get_unscaled_gradients(grads)
+     optimizer.apply_gradients(list(zip(grads, tvars)))
+
+     logs = {self.loss: losses['total_loss']}
+
+     if metrics:
+       for m in metrics:
+         m.update_state(losses[m.name])
+
+     return logs
+
+   def validation_step(self,
+                       inputs: Tuple[Any, Any],
+                       model: tf_keras.Model,
+                       metrics: Optional[List[Any]] = None):
+     """Validation step.
+
+     Args:
+       inputs: a tuple of (features, labels) input tensors.
+       model: the keras.Model.
+       metrics: a nested structure of metrics objects.
+
+     Returns:
+       A dictionary of logs.
+     """
+     features, labels = inputs
+
+     outputs = model(
+         features,
+         image_info=labels['image_info'],
+         training=False)
+     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
+
+     logs = {self.loss: 0}
+     # Converts the predicted category mask to one-hot so PerClassIoU can
+     # compare it against the dense groundtruth labels.
+     outputs = tf.one_hot(
+         tf.cast(outputs['panoptic_outputs']['category_mask'], tf.int32),
+         self.task_config.model.num_classes)
+
+     self.iou_metric.update_state(labels, tf.cast(outputs, tf.float32))
+     return logs
+
+   def aggregate_logs(self, state=None, step_outputs=None):
+     if state is None:
+       self.iou_metric.reset_states()
+       state = self.iou_metric
+     return state
+
+   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
+     result = {}
+     ious = self.iou_metric.result()
+     for i, value in enumerate(ious.numpy()):
+       result.update({'iou/{}'.format(i): value})
+     # Computes the mean IoU.
+     result.update({'mean_iou': tf.reduce_mean(ious).numpy()})
+     return result