tf-models-nightly 2.18.0.dev20240819__py2.py3-none-any.whl → 2.18.0.dev20240821__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/legacy/transformer/transformer_main.py +0 -2
- official/projects/maskconver/__init__.py +14 -0
- official/projects/maskconver/configs/__init__.py +14 -0
- official/projects/maskconver/configs/backbones.py +43 -0
- official/projects/maskconver/configs/decoders.py +36 -0
- official/projects/maskconver/configs/maskconver.py +523 -0
- official/projects/maskconver/configs/multiscale_maskconver.py +215 -0
- official/projects/maskconver/tasks/__init__.py +14 -0
- official/projects/maskconver/tasks/maskconver.py +641 -0
- official/projects/maskconver/tasks/multiscale_maskconver.py +278 -0
- official/projects/maskconver/train.py +30 -0
- {tf_models_nightly-2.18.0.dev20240819.dist-info → tf_models_nightly-2.18.0.dev20240821.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.18.0.dev20240819.dist-info → tf_models_nightly-2.18.0.dev20240821.dist-info}/RECORD +17 -7
- {tf_models_nightly-2.18.0.dev20240819.dist-info → tf_models_nightly-2.18.0.dev20240821.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.18.0.dev20240819.dist-info → tf_models_nightly-2.18.0.dev20240821.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.18.0.dev20240819.dist-info → tf_models_nightly-2.18.0.dev20240821.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.18.0.dev20240819.dist-info → tf_models_nightly-2.18.0.dev20240821.dist-info}/top_level.txt +0 -0
official/legacy/transformer/transformer_main.py
@@ -457,8 +457,6 @@ def _ensure_dir(log_dir):
 
 def main(_):
   flags_obj = flags.FLAGS
-  if flags_obj.enable_mlir_bridge:
-    tf.config.experimental.enable_mlir_bridge()
   task = TransformerTask(flags_obj)
 
   # Execute flag override logic for better model performance
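For context, a minimal sketch (not part of the package) of the behavior those two deleted lines implemented. The flag definition below is illustrative; in the real module the flag comes from the shared transformer flag definitions, and the sketch assumes a TensorFlow build that still exposes `tf.config.experimental.enable_mlir_bridge`.

```python
import tensorflow as tf
from absl import flags

# Illustrative stand-in for the flag that transformer_main.py consumed.
flags.DEFINE_bool('enable_mlir_bridge', False,
                  'Whether to enable the experimental TF-to-XLA MLIR bridge.')


def maybe_enable_mlir_bridge(flags_obj):
  # Mirrors the two lines removed in the hunk above; after this change,
  # transformer_main.py no longer wires the flag to the runtime toggle.
  if flags_obj.enable_mlir_bridge:
    tf.config.experimental.enable_mlir_bridge()
```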
official/projects/maskconver/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
official/projects/maskconver/configs/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
official/projects/maskconver/configs/backbones.py
@@ -0,0 +1,43 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Backbones configurations."""
+import dataclasses
+
+from typing import List, Optional
+from official.modeling import hyperparams
+from official.vision.configs.google import backbones
+
+
+@dataclasses.dataclass
+class ResNetUNet(hyperparams.Config):
+  """ResNetUNet config."""
+  model_id: int = 50
+  depth_multiplier: float = 1.0
+  stem_type: str = 'v0'
+  se_ratio: float = 0.0
+  stochastic_depth_drop_rate: float = 0.0
+  scale_stem: bool = True
+  resnetd_shortcut: bool = False
+  replace_stem_max_pool: bool = False
+  bn_trainable: bool = True
+  classification_output: bool = False
+  upsample_kernel_sizes: Optional[List[int]] = None
+  upsample_repeats: Optional[List[int]] = None
+  upsample_filters: Optional[List[int]] = None
+
+
+@dataclasses.dataclass
+class Backbone(backbones.Backbone):
+  resnet_unet: ResNetUNet = dataclasses.field(default_factory=ResNetUNet)
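A minimal usage sketch (not from the diff): the new `Backbone` subclass follows the Model Garden oneof-config pattern, so the added ResNet-UNet backbone is selected by its `type` string, which matches the `resnet_unet` field name above. The override values are illustrative.

```python
from official.projects.maskconver.configs import backbones as mc_backbones

# Select the new backbone via the oneof `type` key.
backbone = mc_backbones.Backbone(type='resnet_unet')

# Defaults come from the ResNetUNet dataclass; any field can be overridden.
backbone.resnet_unet.model_id = 101                    # illustrative
backbone.resnet_unet.stochastic_depth_drop_rate = 0.1  # illustrative
```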
official/projects/maskconver/configs/decoders.py
@@ -0,0 +1,36 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Decoders configurations."""
+import dataclasses
+
+from official.modeling import hyperparams
+from official.vision.configs import decoders
+
+
+@dataclasses.dataclass
+class MaskConverFPN(hyperparams.Config):
+  """FPN config."""
+  num_filters: int = 256
+  fusion_type: str = 'sum'
+  use_separable_conv: bool = False
+  use_keras_layer: bool = False
+  use_layer_norm: bool = True
+  depthwise_kernel_size: int = 7
+
+
+@dataclasses.dataclass
+class Decoder(decoders.Decoder):
+  maskconver_fpn: MaskConverFPN = dataclasses.field(
+      default_factory=MaskConverFPN)
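The decoder config mirrors the same pattern; a hedged sketch, assuming the oneof `type` string matches the `maskconver_fpn` field name as usual for these configs:

```python
from official.projects.maskconver.configs import decoders as mc_decoders

decoder = mc_decoders.Decoder(type='maskconver_fpn')
decoder.maskconver_fpn.num_filters = 128         # illustrative override
decoder.maskconver_fpn.use_separable_conv = True
```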
official/projects/maskconver/configs/maskconver.py
@@ -0,0 +1,523 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Panoptic Mask R-CNN configuration definition."""
+
+import dataclasses
+import os
+from typing import List, Optional
+
+from official.core import config_definitions as cfg
+from official.core import exp_factory
+from official.modeling import hyperparams
+from official.modeling import optimization
+from official.projects.maskconver.configs import backbones
+from official.projects.maskconver.configs import decoders
+from official.vision.configs import common
+from official.vision.configs import maskrcnn
+from official.vision.configs import semantic_segmentation
+
+
+_COCO_INPUT_PATH_BASE = 'coco/tfrecords'
+_COCO_TRAIN_EXAMPLES = 118287
+_COCO_VAL_EXAMPLES = 5000
+
+# PASCAL VOC 2012 Dataset
+_PASCAL_TRAIN_EXAMPLES = 10582
+_PASCAL_VAL_EXAMPLES = 1449
+_PASCAL_INPUT_PATH_BASE = 'gs://**/pascal_voc_seg'
+
+# Cityscapes Dataset
+_CITYSCAPES_TRAIN_EXAMPLES = 2975
+_CITYSCAPES_VAL_EXAMPLES = 500
+_CITYSCAPES_INPUT_PATH_BASE = 'cityscapes/tfrecord'
+
+
+# pytype: disable=wrong-keyword-args
+# pylint: disable=unexpected-keyword-arg
+
+
+@dataclasses.dataclass
+class CopyPaste(hyperparams.Config):
+  copypaste_frequency: float = 0.5
+  aug_scale_min: float = 0.1
+  aug_scale_max: float = 1.9
+  copypaste_aug_scale_max: float = 1.0
+  copypaste_aug_scale_min: float = 0.05
+
+
+@dataclasses.dataclass
+class Parser(maskrcnn.Parser):
+  """MaskConver parser config."""
+  # If segmentation_resize_eval_groundtruth is set to False, original image
+  # sizes are used for eval. In that case,
+  # segmentation_groundtruth_padded_size has to be specified too to allow for
+  # batching the variable input sizes of images.
+  segmentation_resize_eval_groundtruth: bool = True
+  segmentation_groundtruth_padded_size: List[int] = dataclasses.field(
+      default_factory=list)
+  segmentation_ignore_label: int = 0
+  panoptic_ignore_label: int = 0
+  # Setting this to true will enable parsing category_mask and instance_mask.
+  include_panoptic_masks: bool = True
+  gaussaian_iou: float = 0.7
+  max_num_stuff_centers: int = 3
+  aug_type: common.Augmentation = dataclasses.field(
+      default_factory=common.Augmentation
+  )
+  copypaste: CopyPaste = dataclasses.field(default_factory=CopyPaste)
+
+
+@dataclasses.dataclass
+class TfExampleDecoder(common.TfExampleDecoder):
+  """A simple TF Example decoder config."""
+  # Setting this to true will enable decoding category_mask and instance_mask.
+  include_panoptic_masks: bool = True
+  panoptic_category_mask_key: str = 'image/panoptic/category_mask'
+  panoptic_instance_mask_key: str = 'image/panoptic/instance_mask'
+
+
+@dataclasses.dataclass
+class DataDecoder(common.DataDecoder):
+  """Data decoder config."""
+  simple_decoder: TfExampleDecoder = dataclasses.field(
+      default_factory=TfExampleDecoder
+  )
+
+
+@dataclasses.dataclass
+class DataConfig(maskrcnn.DataConfig):
+  """Input config for training."""
+  decoder: DataDecoder = dataclasses.field(default_factory=DataDecoder)
+  parser: Parser = dataclasses.field(default_factory=Parser)
+  dtype: str = 'float32'
+  prefetch_buffer_size: int = 8
+
+
+@dataclasses.dataclass
+class Anchor(hyperparams.Config):
+  num_scales: int = 1
+  aspect_ratios: List[float] = dataclasses.field(
+      default_factory=lambda: [0.5, 1.0, 2.0])
+  anchor_size: float = 8.0
+
+
+@dataclasses.dataclass
+class PanopticGenerator(hyperparams.Config):
+  """MaskConver panoptic generator."""
+  object_mask_threshold: float = 0.001
+  small_area_threshold: int = 0
+  overlap_threshold: float = 0.5
+  rescale_predictions: bool = True
+  use_hardware_optimization: bool = False
+
+
+@dataclasses.dataclass
+class SegmentationHead(semantic_segmentation.SegmentationHead):
+  """Segmentation head config."""
+  depthwise_kernel_size: int = 7
+  use_layer_norm: bool = False
+
+
+@dataclasses.dataclass
+class MaskConver(hyperparams.Config):
+  """MaskConver model config."""
+  num_classes: int = 0
+  num_thing_classes: int = 0
+  num_instances: int = 100
+  embedding_size: int = 512
+  padded_output_size: List[int] = dataclasses.field(default_factory=list)
+  input_size: List[int] = dataclasses.field(default_factory=list)
+  min_level: int = 2
+  max_level: int = 6
+  num_anchors: int = 100
+  panoptic_fusion_num_filters: int = 256
+  anchor: Anchor = dataclasses.field(default_factory=Anchor)
+  level: int = 3
+  class_head: SegmentationHead = dataclasses.field(
+      default_factory=SegmentationHead
+  )
+  mask_embedding_head: SegmentationHead = dataclasses.field(
+      default_factory=SegmentationHead
+  )
+  per_pixel_embedding_head: SegmentationHead = dataclasses.field(
+      default_factory=SegmentationHead
+  )
+  backbone: backbones.Backbone = dataclasses.field(
+      default_factory=backbones.Backbone
+  )
+  decoder: decoders.Decoder = dataclasses.field(
+      default_factory=lambda: decoders.Decoder(type='identity')
+  )
+  mask_decoder: Optional[decoders.Decoder] = dataclasses.field(
+      default_factory=lambda: decoders.Decoder(type='identity')
+  )
+  norm_activation: common.NormActivation = dataclasses.field(
+      default_factory=common.NormActivation
+  )
+  panoptic_generator: PanopticGenerator = dataclasses.field(
+      default_factory=PanopticGenerator
+  )
+
+
+@dataclasses.dataclass
+class Losses(hyperparams.Config):
+  """maskconver loss config."""
+  l2_weight_decay: float = 0.0
+  ignore_label: int = 0
+  use_groundtruth_dimension: bool = True
+  top_k_percent_pixels_category: float = 1.0
+  top_k_percent_pixels_instance: float = 1.0
+  loss_weight: float = 1.0
+  mask_weight: float = 10.0
+  beta: float = 4.0
+  alpha: float = 2.0
+
+
+@dataclasses.dataclass
+class PanopticQualityEvaluator(hyperparams.Config):
+  """Panoptic Quality Evaluator config."""
+  num_categories: int = 2
+  ignored_label: int = 0
+  max_instances_per_category: int = 256
+  offset: int = 256 * 256 * 256
+  is_thing: List[float] = dataclasses.field(
+      default_factory=list)
+  rescale_predictions: bool = True
+  report_per_class_metrics: bool = False
+
+###################################
+###### PANOPTIC SEGMENTATION ######
+###################################
+
+
+@dataclasses.dataclass
+class MaskConverTask(cfg.TaskConfig):
+  """MaskConverTask task config."""
+  model: MaskConver = dataclasses.field(default_factory=MaskConver)
+  train_data: DataConfig = dataclasses.field(
+      default_factory=lambda: DataConfig(is_training=True)
+  )
+  # pylint: disable=g-long-lambda
+  validation_data: DataConfig = dataclasses.field(
+      default_factory=lambda: DataConfig(
+          is_training=False, drop_remainder=False
+      )
+  )
+  losses: Losses = dataclasses.field(default_factory=Losses)
+  init_checkpoint: Optional[str] = None
+
+  init_checkpoint_modules: Optional[List[str]] = dataclasses.field(
+      default_factory=list)
+  panoptic_quality_evaluator: PanopticQualityEvaluator = dataclasses.field(
+      default_factory=PanopticQualityEvaluator
+  )
+  # pylint: enable=g-long-lambda
+
+
+@exp_factory.register_config_factory('maskconver_coco')
+def maskconver_coco() -> cfg.ExperimentConfig:
+  """COCO panoptic segmentation with MaskConver."""
+  train_batch_size = 64
+  eval_batch_size = 8
+  # steps_per_epoch = _COCO_TRAIN_EXAMPLES // train_batch_size
+  validation_steps = _COCO_VAL_EXAMPLES // eval_batch_size
+
+  # coco panoptic dataset has category ids ranging from [0-200] inclusive.
+  # 0 is not used and represents the background class
+  # ids 1-91 represent thing categories (91)
+  # ids 92-200 represent stuff categories (109)
+  # for the segmentation task, we continue using id=0 for the background
+  # and map all thing categories to id=1, the remaining 109 stuff categories
+  # are shifted by an offset=90 given by num_thing classes - 1. This shifting
+  # will make all the stuff categories begin from id=2 and end at id=110
+  num_panoptic_categories = 201
+  num_thing_categories = 91
+  # num_semantic_segmentation_classes = 111
+
+  is_thing = [False]
+  for idx in range(1, num_panoptic_categories):
+    is_thing.append(True if idx < num_thing_categories else False)
+
+  config = cfg.ExperimentConfig(
+      runtime=cfg.RuntimeConfig(
+          mixed_precision_dtype='float32', enable_xla=False),
+      task=MaskConverTask(
+          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',  # pylint: disable=line-too-long
+          init_checkpoint_modules=['backbone'],
+          model=MaskConver(
+              num_classes=201, num_thing_classes=91, input_size=[512, 512, 3],
+              padded_output_size=[512, 512]),
+          losses=Losses(l2_weight_decay=0.0),
+          train_data=DataConfig(
+              input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'train-nocrowd*'),
+              is_training=True,
+              global_batch_size=train_batch_size,
+              parser=Parser(
+                  aug_rand_hflip=True, aug_scale_min=0.5, aug_scale_max=1.5,
+                  aug_type=common.Augmentation(
+                      type='autoaug',
+                      autoaug=common.AutoAugment(
+                          augmentation_name='panoptic_deeplab_policy')))),
+          validation_data=DataConfig(
+              input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'val-nocrowd*'),
+              is_training=False,
+              global_batch_size=eval_batch_size,
+              parser=Parser(
+                  segmentation_resize_eval_groundtruth=True,
+                  segmentation_groundtruth_padded_size=[640, 640]),
+              drop_remainder=False),
+          panoptic_quality_evaluator=PanopticQualityEvaluator(
+              num_categories=num_panoptic_categories,
+              ignored_label=0,
+              is_thing=is_thing,
+              rescale_predictions=True)),
+      trainer=cfg.TrainerConfig(
+          train_steps=200000,
+          validation_steps=validation_steps,
+          validation_interval=1000,
+          steps_per_loop=1000,
+          summary_interval=1000,
+          checkpoint_interval=1000,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'sgd',
+                  'sgd': {
+                      'momentum': 0.9
+                  }
+              },
+              'learning_rate': {
+                  'type': 'cosine',
+                  'cosine': {
+                      'initial_learning_rate': 0.08,
+                      'decay_steps': 200000,
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 2000,
+                      'warmup_learning_rate': 0.0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
+
+###################################
+###### SEMANTIC SEGMENTATION ######
+###################################
+
+
+@dataclasses.dataclass
+class SegDataConfig(cfg.DataConfig):
+  """Input config for training."""
+  output_size: List[int] = dataclasses.field(default_factory=list)
+  # If crop_size is specified, image will be resized first to
+  # output_size, then crop of size crop_size will be cropped.
+  crop_size: List[int] = dataclasses.field(default_factory=list)
+  input_path: str = ''
+  global_batch_size: int = 0
+  is_training: bool = True
+  dtype: str = 'float32'
+  shuffle_buffer_size: int = 1000
+  prefetch_buffer_size: int = 8
+  cycle_length: int = 10
+  # If resize_eval_groundtruth is set to False, original image sizes are used
+  # for eval. In that case, groundtruth_padded_size has to be specified too to
+  # allow for batching the variable input sizes of images.
+  resize_eval_groundtruth: bool = True
+  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
+  aug_scale_min: float = 1.0
+  aug_scale_max: float = 1.0
+  aug_rand_hflip: bool = True
+  preserve_aspect_ratio: bool = True
+  aug_policy: Optional[str] = None
+  drop_remainder: bool = True
+  file_type: str = 'tfrecord'
+  gaussaian_iou: float = 0.7
+  max_num_stuff_centers: int = 3
+  max_num_instances: int = 100
+  aug_type: common.Augmentation = dataclasses.field(
+      default_factory=common.Augmentation)
+
+
+@dataclasses.dataclass
+class MaskConverSegTask(cfg.TaskConfig):
+  """MaskConverTask task config."""
+  model: MaskConver = dataclasses.field(default_factory=MaskConver)
+  train_data: DataConfig = dataclasses.field(
+      default_factory=lambda: SegDataConfig(is_training=True)
+  )
+  # pylint: disable=g-long-lambda
+  validation_data: DataConfig = dataclasses.field(
+      default_factory=lambda: SegDataConfig(
+          is_training=False, drop_remainder=False
+      )
+  )
+  # pylint: enable=g-long-lambda
+  losses: Losses = dataclasses.field(default_factory=Losses)
+  init_checkpoint: Optional[str] = None
+
+  init_checkpoint_modules: Optional[List[str]] = dataclasses.field(
+      default_factory=list)
+
+
+@exp_factory.register_config_factory('maskconver_seg_pascal')
+def maskconver_seg_pascal() -> cfg.ExperimentConfig:
+  """COCO panoptic segmentation with MaskConver."""
+  train_batch_size = 64
+  eval_batch_size = 8
+  validation_steps = _PASCAL_VAL_EXAMPLES // eval_batch_size
+
+  config = cfg.ExperimentConfig(
+      runtime=cfg.RuntimeConfig(
+          mixed_precision_dtype='float32', enable_xla=False),
+      task=MaskConverSegTask(
+          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',  # pylint: disable=line-too-long
+          init_checkpoint_modules=['backbone'],
+          model=MaskConver(
+              num_classes=21, num_thing_classes=91, input_size=[512, 512, 3],
+              padded_output_size=[512, 512]),
+          losses=Losses(l2_weight_decay=0.00004),
+          train_data=SegDataConfig(
+              input_path=os.path.join(_PASCAL_INPUT_PATH_BASE, 'train_aug*'),
+              output_size=[512, 512],
+              is_training=True,
+              global_batch_size=train_batch_size,
+              aug_scale_min=0.5,
+              aug_scale_max=2.0,
+              aug_type=common.Augmentation(
+                  type='autoaug',
+                  autoaug=common.AutoAugment(
+                      augmentation_name='panoptic_deeplab_policy'))),
+          validation_data=SegDataConfig(
+              input_path=os.path.join(_PASCAL_INPUT_PATH_BASE, 'val*'),
+              output_size=[512, 512],
+              is_training=False,
+              global_batch_size=eval_batch_size,
+              resize_eval_groundtruth=False,
+              groundtruth_padded_size=[512, 512],
+              drop_remainder=False)),
+      trainer=cfg.TrainerConfig(
+          train_steps=200000,
+          validation_steps=validation_steps,
+          validation_interval=1000,
+          steps_per_loop=1000,
+          summary_interval=1000,
+          checkpoint_interval=1000,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'sgd',
+                  'sgd': {
+                      'momentum': 0.9
+                  }
+              },
+              'learning_rate': {
+                  'type': 'cosine',
+                  'cosine': {
+                      'initial_learning_rate': 0.08,
+                      'decay_steps': 200000,
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 2000,
+                      'warmup_learning_rate': 0.0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
+
+
+@exp_factory.register_config_factory('maskconver_seg_cityscapes')
+def maskconver_seg_cityscapes() -> cfg.ExperimentConfig:
+  """Cityscapes semantic segmentation with MaskConver."""
+  train_batch_size = 32
+  eval_batch_size = 8
+  validation_steps = _CITYSCAPES_VAL_EXAMPLES // eval_batch_size
+
+  config = cfg.ExperimentConfig(
+      runtime=cfg.RuntimeConfig(
+          mixed_precision_dtype='float32', enable_xla=False),
+      task=MaskConverSegTask(
+          init_checkpoint='maskconver_seg_mnv3p5rf_coco_200k/43437096',  # pylint: disable=line-too-long
+          init_checkpoint_modules=['backbone'],
+          model=MaskConver(
+              num_classes=19, input_size=[None, None, 3],
+              padded_output_size=[1024, 2048]),
+          losses=Losses(l2_weight_decay=0.00004),
+          train_data=SegDataConfig(
+              input_path=os.path.join(_CITYSCAPES_INPUT_PATH_BASE,
+                                      'train_fine*'),
+              output_size=[1024, 2048],
+              crop_size=[512, 1024],
+              is_training=True,
+              global_batch_size=train_batch_size,
+              aug_scale_min=0.5,
+              aug_scale_max=2.0,
+              aug_type=common.Augmentation(
+                  type='autoaug',
+                  autoaug=common.AutoAugment(
+                      augmentation_name='panoptic_deeplab_policy'))),
+          validation_data=SegDataConfig(
+              input_path=os.path.join(_CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
+              output_size=[1024, 2048],
+              is_training=False,
+              global_batch_size=eval_batch_size,
+              resize_eval_groundtruth=False,
+              groundtruth_padded_size=[1024, 2048],
+              drop_remainder=False)),
+      trainer=cfg.TrainerConfig(
+          train_steps=100000,
+          validation_steps=validation_steps,
+          validation_interval=185,
+          steps_per_loop=185,
+          summary_interval=185,
+          checkpoint_interval=185,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'sgd',
+                  'sgd': {
+                      'momentum': 0.9
+                  }
+              },
+              'learning_rate': {
+                  'type': 'polynomial',
+                  'polynomial': {
+                      'initial_learning_rate': 0.01,
+                      'decay_steps': 100000,
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 925,
+                      'warmup_learning_rate': 0.0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
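Because the three factory functions above are registered through `exp_factory.register_config_factory`, they become retrievable by name once this build is installed. A minimal sketch, assuming the standard Model Garden experiment-factory API:

```python
from official.core import exp_factory

# Each name registered in this file resolves to a full ExperimentConfig.
for name in ('maskconver_coco', 'maskconver_seg_pascal',
             'maskconver_seg_cityscapes'):
  config = exp_factory.get_exp_config(name)
  print(name, config.task.model.num_classes, config.trainer.train_steps)

# Overrides use the dotted-key convention of hyperparams.Config.override;
# the batch size below is illustrative.
config = exp_factory.get_exp_config('maskconver_coco')
config.override({'task.train_data.global_batch_size': 32}, is_strict=True)
```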