tf-models-nightly 2.18.0.dev20240820__py2.py3-none-any.whl → 2.18.0.dev20240822__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/legacy/transformer/transformer_main.py +0 -2
- official/projects/maskconver/__init__.py +14 -0
- official/projects/maskconver/configs/__init__.py +14 -0
- official/projects/maskconver/configs/backbones.py +43 -0
- official/projects/maskconver/configs/decoders.py +36 -0
- official/projects/maskconver/configs/maskconver.py +523 -0
- official/projects/maskconver/configs/multiscale_maskconver.py +215 -0
- official/projects/maskconver/tasks/__init__.py +14 -0
- official/projects/maskconver/tasks/maskconver.py +641 -0
- official/projects/maskconver/tasks/multiscale_maskconver.py +278 -0
- official/projects/maskconver/train.py +30 -0
- {tf_models_nightly-2.18.0.dev20240820.dist-info → tf_models_nightly-2.18.0.dev20240822.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.18.0.dev20240820.dist-info → tf_models_nightly-2.18.0.dev20240822.dist-info}/RECORD +17 -7
- {tf_models_nightly-2.18.0.dev20240820.dist-info → tf_models_nightly-2.18.0.dev20240822.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.18.0.dev20240820.dist-info → tf_models_nightly-2.18.0.dev20240822.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.18.0.dev20240820.dist-info → tf_models_nightly-2.18.0.dev20240822.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.18.0.dev20240820.dist-info → tf_models_nightly-2.18.0.dev20240822.dist-info}/top_level.txt +0 -0
official/projects/maskconver/tasks/maskconver.py
@@ -0,0 +1,641 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Panoptic MaskRCNN task definition."""
from typing import Any, Dict, List, Mapping, Optional, Tuple

import tensorflow as tf, tf_keras

from official.common import dataset_fn
from official.core import task_factory
from official.projects.centernet.ops import loss_ops
from official.projects.maskconver.configs import maskconver as exp_cfg
from official.projects.maskconver.dataloaders import maskconver_segmentation_input
from official.projects.maskconver.dataloaders import panoptic_maskrcnn_input
from official.projects.maskconver.losses import maskconver_losses
from official.projects.maskconver.modeling import factory
from official.projects.volumetric_models.losses import segmentation_losses as volumetric_segmentation_losses
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import segmentation_input
from official.vision.evaluation import panoptic_quality_evaluator
from official.vision.evaluation import segmentation_metrics
from official.vision.tasks import maskrcnn
from official.vision.tasks import semantic_segmentation


@task_factory.register_task_cls(exp_cfg.MaskConverTask)
class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
  """A single-replica view of training procedure.

  Panoptic Mask R-CNN task provides artifacts for training/evaluation
  procedures, including loading/iterating over Datasets, initializing the
  model, calculating the loss, post-processing, and customized metrics with
  reduction.
  """

  def build_model(self) -> tf_keras.Model:
    """Build Panoptic Mask R-CNN model."""

    tf_keras.utils.set_random_seed(0)
    tf.config.experimental.enable_op_determinism()
    input_specs = tf_keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf_keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_maskconver_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model
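
  # NOTE (editor's sketch, not in the original file): with the rate divided by
  # two, the Keras regularizer reproduces `l2_weight_decay * tf.nn.l2_loss(w)`
  # exactly, since tf_keras.regularizers.l2(rate) penalizes rate * sum(w**2)
  # while tf.nn.l2_loss(w) == sum(w**2) / 2. For example, with
  # l2_weight_decay=4e-5 and w = [3., 4.] (sum of squares 25):
  #   tf_keras.regularizers.l2(2e-5) penalty -> 2e-5 * 25   = 5e-4
  #   4e-5 * tf.nn.l2_loss([3., 4.])         -> 4e-5 * 12.5 = 5e-4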

  def build_inputs(
      self,
      params: exp_cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Build input dataset."""
    decoder_cfg = params.decoder.get()
    if params.decoder.type == 'simple_decoder':
      decoder = panoptic_maskrcnn_input.TfExampleDecoder(
          regenerate_source_id=decoder_cfg.regenerate_source_id,
          mask_binarize_threshold=decoder_cfg.mask_binarize_threshold,
          include_panoptic_masks=decoder_cfg.include_panoptic_masks,
          panoptic_category_mask_key=decoder_cfg.panoptic_category_mask_key,
          panoptic_instance_mask_key=decoder_cfg.panoptic_instance_mask_key)
    else:
      raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))

    parser = panoptic_maskrcnn_input.Parser(
        output_size=self.task_config.model.input_size[:2],
        min_level=self.task_config.model.min_level,
        max_level=self.task_config.model.max_level,
        num_scales=self.task_config.model.anchor.num_scales,
        aspect_ratios=self.task_config.model.anchor.aspect_ratios,
        anchor_size=self.task_config.model.anchor.anchor_size,
        dtype=params.dtype,
        rpn_match_threshold=params.parser.rpn_match_threshold,
        rpn_unmatched_threshold=params.parser.rpn_unmatched_threshold,
        rpn_batch_size_per_im=params.parser.rpn_batch_size_per_im,
        rpn_fg_fraction=params.parser.rpn_fg_fraction,
        aug_rand_hflip=params.parser.aug_rand_hflip,
        aug_scale_min=params.parser.aug_scale_min,
        aug_scale_max=params.parser.aug_scale_max,
        skip_crowd_during_training=params.parser.skip_crowd_during_training,
        max_num_instances=self.task_config.model.num_instances,
        mask_crop_size=params.parser.mask_crop_size,
        segmentation_resize_eval_groundtruth=params.parser
        .segmentation_resize_eval_groundtruth,
        segmentation_groundtruth_padded_size=params.parser
        .segmentation_groundtruth_padded_size,
        segmentation_ignore_label=params.parser.segmentation_ignore_label,
        panoptic_ignore_label=params.parser.panoptic_ignore_label,
        include_panoptic_masks=params.parser.include_panoptic_masks,
        num_panoptic_categories=self.task_config.model.num_classes,
        num_thing_categories=self.task_config.model.num_thing_classes,
        level=self.task_config.model.level,
        gaussian_iou=params.parser.gaussaian_iou,
        aug_type=params.parser.aug_type,
        max_num_stuff_centers=params.parser.max_num_stuff_centers)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=input_context)

    return dataset
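
  # NOTE (editor's sketch, not in the original file): the input pipeline
  # composes three stages: `decoder.decode` turns serialized tf.Example
  # records into tensors, `parser.parse_fn(is_training)` builds the MaskConver
  # training targets (center heatmaps, mask weights, padding masks), and the
  # reader batches and distributes the result. A typical invocation, assuming
  # a populated exp_cfg.MaskConverTask config `task_cfg`:
  #
  #   task = PanopticMaskRCNNTask(task_cfg)
  #   train_ds = task.build_inputs(task_cfg.train_data)
  #   images, labels = next(iter(train_ds))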

  def build_losses(self,
                   outputs: Mapping[str, Any],
                   labels: Mapping[str, Any],
                   aux_losses: Optional[Any] = None,
                   step=None) -> Dict[str, tf.Tensor]:
    """Build Panoptic Mask R-CNN losses."""
    loss_params = self._task_config.losses

    batch_size = tf.cast(tf.shape(labels['num_instances'])[0], tf.float32)
    center_loss_fn = maskconver_losses.PenaltyReducedLogisticFocalLoss(
        alpha=loss_params.alpha, beta=loss_params.beta)
    mask_loss_fn = maskconver_losses.PenaltyReducedLogisticFocalLoss()

    # Calculate center heatmap loss.
    # TODO(arashwan): add valid weights.
    # output_unpad_image_shapes = labels['image_info'][:, 0, :]
    # valid_anchor_weights = loss_ops.get_valid_anchor_weights_in_flattened_image(  # pylint: disable=line-too-long
    #     output_unpad_image_shapes, h, w)
    # valid_anchor_weights = tf.expand_dims(valid_anchor_weights, 2)

    true_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
        labels['panoptic_heatmaps'])
    true_flattened_ct_heatmap = tf.cast(true_flattened_ct_heatmap, tf.float32)

    pred_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
        outputs['class_heatmaps'])
    pred_flattened_ct_heatmap = tf.cast(pred_flattened_ct_heatmap, tf.float32)
    center_padding_mask = 1 - labels['panoptic_padding_mask'][:, :, :, None]
    center_padding_mask = tf.image.resize(
        center_padding_mask,
        tf.shape(labels['panoptic_heatmaps'])[1:3],
        method='nearest')
    center_padding_mask = tf.maximum(center_padding_mask, 0.0)
    center_padding_mask = center_padding_mask * tf.ones_like(
        labels['panoptic_heatmaps'])
    weights_flattened_mask = loss_ops.flatten_spatial_dimensions(
        center_padding_mask)
    center_loss = center_loss_fn(
        target_tensor=true_flattened_ct_heatmap,
        prediction_tensor=pred_flattened_ct_heatmap,
        weights=weights_flattened_mask)

    center_loss = tf.reduce_sum(
        center_loss /
        (labels['num_instances'][:, None, None] + 1.0)) / batch_size

    gt_masks = labels['panoptic_masks']
    gt_mask_weights = labels['panoptic_mask_weights'][
        :, None, None, :] * tf.ones_like(gt_masks)
    panoptic_padding_mask = labels['panoptic_padding_mask'][
        :, :, :, None] * tf.ones_like(gt_masks)

    true_flattened_masks = loss_ops.flatten_spatial_dimensions(gt_masks)
    true_flattened_ct_heatmap = tf.cast(true_flattened_ct_heatmap, tf.float32)
    predicted_masks = tf.cast(outputs['mask_proposal_logits'], tf.float32)
    predicted_masks = tf.image.resize(
        predicted_masks, tf.shape(gt_masks)[1:3], method='bilinear')
    pred_flattened_masks = loss_ops.flatten_spatial_dimensions(predicted_masks)
    mask_loss = tf.cast(0.0, tf.float32)
    mask_loss_fn = tf_keras.losses.BinaryCrossentropy(
        from_logits=True,
        label_smoothing=0.0,
        axis=-1,
        reduction=tf_keras.losses.Reduction.NONE,
        name='binary_crossentropy')
    mask_weights = tf.reshape(
        tf.cast(true_flattened_masks >= 0, tf.float32),
        [-1, 1]) * tf.reshape(gt_mask_weights, [-1, 1]) * tf.reshape(
            (1 - panoptic_padding_mask), [-1, 1])
    mask_loss = mask_loss_fn(
        tf.reshape(gt_masks, [-1, 1]),
        tf.reshape(pred_flattened_masks, [-1, 1]),
        sample_weight=mask_weights)
    mask_loss = tf.reduce_sum(mask_loss) / (tf.reduce_sum(mask_weights) + 1.0)

    # Dice loss.
    _, h, w, _ = gt_masks.get_shape().as_list()
    masked_predictions = tf.sigmoid(predicted_masks) * gt_mask_weights * (
        1 - panoptic_padding_mask)
    masked_gt_masks = gt_masks * gt_mask_weights * (1 - panoptic_padding_mask)

    masked_predictions = tf.transpose(masked_predictions, [0, 3, 1, 2])
    masked_predictions = tf.reshape(masked_predictions, [-1, h, w, 1])
    masked_gt_masks = tf.transpose(masked_gt_masks, [0, 3, 1, 2])
    masked_gt_masks = tf.reshape(masked_gt_masks, [-1, h, w, 1])

    dice_loss_fn = volumetric_segmentation_losses.SegmentationLossDiceScore(
        metric_type='adaptive', axis=(2, 3))
    dice_loss = dice_loss_fn(logits=masked_predictions, labels=masked_gt_masks)

    total_loss = (center_loss + loss_params.mask_weight * mask_loss +
                  loss_params.mask_weight * dice_loss)

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    total_loss = loss_params.loss_weight * total_loss

    losses = {'total_loss': total_loss,
              'mask_loss': mask_loss,
              'center_loss': center_loss,
              'dice_loss': dice_loss}
    return losses
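
  # NOTE (editor's summary, not in the original file): the composite objective
  # computed above reads
  #
  #   total_loss = loss_weight * (center_loss
  #                               + mask_weight * bce_mask_loss
  #                               + mask_weight * dice_loss
  #                               + sum(aux_losses))
  #
  # where the center term is a penalty-reduced focal loss over class-center
  # heatmaps, normalized per image by (num_instances + 1), and both mask terms
  # are weighted to ignore padded pixels and unweighted mask proposals.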

  def build_metrics(self, training: bool = True) -> List[
      tf_keras.metrics.Metric]:
    """Build detection metrics."""
    metrics = []
    if training:
      metric_names = [
          'total_loss',
          'center_loss',
          'mask_loss',
          'dice_loss',
      ]
      for name in metric_names:
        metrics.append(tf_keras.metrics.Mean(name, dtype=tf.float32))

    else:
      pq_config = self.task_config.panoptic_quality_evaluator
      self.panoptic_quality_metric = (
          panoptic_quality_evaluator.PanopticQualityEvaluator(
              num_categories=pq_config.num_categories,
              ignored_label=pq_config.ignored_label,
              max_instances_per_category=pq_config.max_instances_per_category,
              offset=pq_config.offset,
              is_thing=pq_config.is_thing,
              rescale_predictions=pq_config.rescale_predictions))
    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Does forward and backward.

    Args:
      inputs: a pair of (images, labels) input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    images, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

    with tf.GradientTape() as tape:
      outputs = model(
          images,
          box_indices=labels['panoptic_box_indices'],
          classes=labels['panoptic_classes'],
          training=True)
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      losses = self.build_losses(
          outputs=outputs,
          labels=labels,
          aux_losses=model.losses,
          step=optimizer.iterations)
      scaled_loss = losses['total_loss'] / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: losses['total_loss']}

    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])

    return logs
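
  # NOTE (editor's sketch, not in the original file): dividing by
  # num_replicas_in_sync makes the per-replica gradients sum to the
  # global-batch gradient when the distribution strategy all-reduces them with
  # SUM. Under mixed precision, the loss is additionally multiplied by the
  # loss scale before tape.gradient and the gradients divided by it
  # afterwards, i.e. for a loss scale s:
  #   scaled_loss = s * (total_loss / num_replicas)
  #   grads       = tape.gradient(scaled_loss, tvars) / s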

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf_keras.Model,
                      metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Validation step.

    Args:
      inputs: a pair of (images, labels) input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    images, labels = inputs

    outputs = model(
        images,
        image_info=labels['image_info'],
        training=False)

    logs = {self.loss: 0}

    pq_metric_labels = {
        'category_mask':
            labels['groundtruths']['gt_panoptic_category_mask'],
        'instance_mask':
            labels['groundtruths']['gt_panoptic_instance_mask'],
        'image_info': labels['image_info']
    }
    logs.update({
        self.panoptic_quality_metric.name:
            (pq_metric_labels, outputs['panoptic_outputs'])})
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    if state is None:
      self.panoptic_quality_metric.reset_states()
      state = [self.panoptic_quality_metric]

    self.panoptic_quality_metric.update_state(
        step_outputs[self.panoptic_quality_metric.name][0],
        step_outputs[self.panoptic_quality_metric.name][1])

    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    result = {}

    report_per_class_metrics = (
        self.task_config.panoptic_quality_evaluator.report_per_class_metrics)
    panoptic_quality_results = self.panoptic_quality_metric.result()
    for k, value in panoptic_quality_results.items():
      if k.endswith('per_class'):
        if report_per_class_metrics:
          for i, per_class_value in enumerate(value):
            metric_key = 'panoptic_quality/{}/class_{}'.format(k, i)
            result[metric_key] = per_class_value
        else:
          continue
      else:
        result['panoptic_quality/{}'.format(k)] = value
    return result
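
  # NOTE (editor's sketch, not in the original file): with
  # report_per_class_metrics enabled, reduce_aggregated_logs expands each
  # '*_per_class' vector from PanopticQualityEvaluator.result() into scalar
  # keys following the 'panoptic_quality/{metric}/class_{i}' pattern above,
  # while every other entry is reported once as 'panoptic_quality/{metric}'.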

@task_factory.register_task_cls(exp_cfg.MaskConverSegTask)
class MaskConverSegmentation(semantic_segmentation.SemanticSegmentationTask):
  """A single-replica view of training procedure.

  MaskConver task provides artifacts for training/evaluation procedures,
  including loading/iterating over Datasets, initializing the model,
  calculating the loss, post-processing, and customized metrics with
  reduction.
  """

  def build_model(self) -> tf_keras.Model:
    """Build maskconver model."""

    tf_keras.utils.set_random_seed(0)
    tf.config.experimental.enable_op_determinism()
    input_specs = tf_keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf_keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_maskconver_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer,
        segmentation_inference=True)
    return model

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds segmentation input."""

    ignore_label = self.task_config.losses.ignore_label

    decoder = segmentation_input.Decoder()

    parser = maskconver_segmentation_input.Parser(
        output_size=params.output_size,
        num_classes=self.task_config.model.num_classes,
        crop_size=params.crop_size,
        ignore_label=ignore_label,
        resize_eval_groundtruth=params.resize_eval_groundtruth,
        groundtruth_padded_size=params.groundtruth_padded_size,
        aug_scale_min=params.aug_scale_min,
        aug_scale_max=params.aug_scale_max,
        aug_rand_hflip=params.aug_rand_hflip,
        preserve_aspect_ratio=params.preserve_aspect_ratio,
        level=self.task_config.model.level,
        aug_type=params.aug_type,
        max_num_stuff_centers=params.max_num_stuff_centers,
        dtype=params.dtype)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   outputs: Mapping[str, Any],
                   labels: Mapping[str, Any],
                   aux_losses: Optional[Any] = None,
                   step=None) -> Dict[str, tf.Tensor]:
    """Build MaskConver segmentation losses."""
    loss_params = self._task_config.losses

    # b, h, w, c = outputs['class_heatmaps'].get_shape().as_list()
    batch_size = tf.cast(tf.shape(labels['num_instances'])[0], tf.float32)
    center_loss_fn = maskconver_losses.PenaltyReducedLogisticFocalLoss(
        alpha=loss_params.alpha, beta=loss_params.beta)
    mask_loss_fn = maskconver_losses.PenaltyReducedLogisticFocalLoss()

    true_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
        labels['seg_ct_heatmaps'])
    true_flattened_ct_heatmap = tf.cast(true_flattened_ct_heatmap, tf.float32)

    pred_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
        outputs['class_heatmaps'])
    pred_flattened_ct_heatmap = tf.cast(pred_flattened_ct_heatmap, tf.float32)
    center_valid_mask = labels['seg_valid_mask'][:, :, :, None]
    center_valid_mask = tf.image.resize(
        center_valid_mask,
        tf.shape(labels['seg_ct_heatmaps'])[1:3],
        method='nearest')
    center_valid_mask = tf.maximum(center_valid_mask, 0.0)
    center_valid_mask = center_valid_mask * tf.ones_like(
        labels['seg_ct_heatmaps'])
    weights_flattened_mask = loss_ops.flatten_spatial_dimensions(
        center_valid_mask)
    center_loss = center_loss_fn(
        target_tensor=true_flattened_ct_heatmap,
        prediction_tensor=pred_flattened_ct_heatmap,
        weights=weights_flattened_mask)

    center_loss = tf.reduce_sum(
        center_loss /
        (labels['num_instances'][:, None, None] + 1.0)) / batch_size

    gt_masks = labels['seg_masks']
    gt_mask_weights = labels['seg_mask_weights'][
        :, None, None, :] * tf.ones_like(gt_masks)
    valid_mask = labels['seg_valid_mask'][:, :, :, None] * tf.ones_like(
        gt_masks)

    true_flattened_masks = loss_ops.flatten_spatial_dimensions(gt_masks)
    true_flattened_ct_heatmap = tf.cast(true_flattened_ct_heatmap, tf.float32)
    predicted_masks = tf.cast(outputs['mask_proposal_logits'], tf.float32)
    predicted_masks = tf.image.resize(
        predicted_masks, tf.shape(gt_masks)[1:3], method='bilinear')
    pred_flattened_masks = loss_ops.flatten_spatial_dimensions(predicted_masks)
    mask_loss = tf.cast(0.0, tf.float32)

    mask_loss_fn = tf_keras.losses.BinaryCrossentropy(
        from_logits=True,
        label_smoothing=0.0,
        axis=-1,
        reduction=tf_keras.losses.Reduction.NONE,
        name='binary_crossentropy')
    mask_weights = tf.reshape(
        tf.cast(true_flattened_masks >= 0, tf.float32),
        [-1, 1]) * tf.reshape(gt_mask_weights, [-1, 1]) * tf.reshape(
            valid_mask, [-1, 1])
    mask_loss = mask_loss_fn(
        tf.reshape(gt_masks, [-1, 1]),
        tf.reshape(pred_flattened_masks, [-1, 1]),
        sample_weight=mask_weights)

    mask_loss = tf.reduce_sum(mask_loss) / (tf.reduce_sum(mask_weights) + 1.0)

    total_loss = center_loss + loss_params.mask_weight * mask_loss

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    total_loss = loss_params.loss_weight * total_loss

    losses = {'total_loss': total_loss,
              'mask_loss': mask_loss,
              'center_loss': center_loss}
    return losses
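
  # NOTE (editor's summary, not in the original file): unlike the panoptic
  # task above, this semantic variant keeps only the center and BCE mask
  # terms:
  #
  #   total_loss = loss_weight * (center_loss
  #                               + mask_weight * bce_mask_loss
  #                               + sum(aux_losses))
  #
  # and masks invalid pixels with 'seg_valid_mask' rather than a panoptic
  # padding mask; there is no dice term.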

  def build_metrics(self, training: bool = True) -> List[
      tf_keras.metrics.Metric]:
    """Build detection metrics."""
    metrics = []
    if training:
      metric_names = [
          'total_loss',
          'center_loss',
          'mask_loss',
      ]
      for name in metric_names:
        metrics.append(tf_keras.metrics.Mean(name, dtype=tf.float32))
    else:
      self.iou_metric = segmentation_metrics.PerClassIoU(
          name='per_class_iou',
          num_classes=self.task_config.model.num_classes,
          rescale_predictions=False,
          dtype=tf.float32)

    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Does forward and backward.

    Args:
      inputs: a pair of (images, labels) input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    images, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

    with tf.GradientTape() as tape:
      outputs = model(
          images,
          box_indices=labels['seg_box_indices'],
          classes=labels['seg_classes'],
          training=True)
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      losses = self.build_losses(
          outputs=outputs,
          labels=labels,
          aux_losses=model.losses,
          step=optimizer.iterations)
      scaled_loss = losses['total_loss'] / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: losses['total_loss']}

    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])

    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf_keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a pair of (features, labels) input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs

    outputs = model(
        features,
        image_info=labels['image_info'],
        training=False)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

    logs = {self.loss: 0}
    outputs = tf.one_hot(
        tf.cast(outputs['panoptic_outputs']['category_mask'], tf.int32),
        self.task_config.model.num_classes)

    self.iou_metric.update_state(labels, tf.cast(outputs, tf.float32))
    return logs
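
  # NOTE (editor's sketch, not in the original file): the predicted category
  # mask of shape [batch, height, width] is one-hot encoded into
  # [batch, height, width, num_classes] so PerClassIoU can compare it against
  # the groundtruth, e.g.:
  #   tf.one_hot(tf.cast([[1, 0]], tf.int32), 3)
  #   -> [[[0., 1., 0.], [1., 0., 0.]]]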

  def aggregate_logs(self, state=None, step_outputs=None):
    if state is None:
      self.iou_metric.reset_states()
      state = self.iou_metric
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    result = {}
    ious = self.iou_metric.result()
    for i, value in enumerate(ious.numpy()):
      result.update({'iou/{}'.format(i): value})
    # Computes mean IoU.
    result.update({'mean_iou': tf.reduce_mean(ious).numpy()})
    return result