tf-models-nightly 2.18.0.dev20240820__py2.py3-none-any.whl → 2.18.0.dev20240822__py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
@@ -457,8 +457,6 @@ def _ensure_dir(log_dir):
 
 def main(_):
   flags_obj = flags.FLAGS
-  if flags_obj.enable_mlir_bridge:
-    tf.config.experimental.enable_mlir_bridge()
   task = TransformerTask(flags_obj)
 
   # Execute flag override logic for better model performance
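Note on the hunk above: the flag-gated MLIR bridge setup is removed from the Transformer main(). As a hedged sketch (not part of the diff), a caller that still wants the old behavior could make the same call directly, assuming tf.config.experimental.enable_mlir_bridge is still present in the installed TensorFlow build:

    import tensorflow as tf

    # Replicates the removed flag-gated call; assumes the experimental
    # enable_mlir_bridge API still exists in this TensorFlow version.
    tf.config.experimental.enable_mlir_bridge()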
@@ -0,0 +1,14 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
@@ -0,0 +1,14 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
@@ -0,0 +1,43 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Backbones configurations."""
+import dataclasses
+
+from typing import List, Optional
+from official.modeling import hyperparams
+from official.vision.configs.google import backbones
+
+
+@dataclasses.dataclass
+class ResNetUNet(hyperparams.Config):
+  """ResNetUNet config."""
+  model_id: int = 50
+  depth_multiplier: float = 1.0
+  stem_type: str = 'v0'
+  se_ratio: float = 0.0
+  stochastic_depth_drop_rate: float = 0.0
+  scale_stem: bool = True
+  resnetd_shortcut: bool = False
+  replace_stem_max_pool: bool = False
+  bn_trainable: bool = True
+  classification_output: bool = False
+  upsample_kernel_sizes: Optional[List[int]] = None
+  upsample_repeats: Optional[List[int]] = None
+  upsample_filters: Optional[List[int]] = None
+
+
+@dataclasses.dataclass
+class Backbone(backbones.Backbone):
+  resnet_unet: ResNetUNet = dataclasses.field(default_factory=ResNetUNet)
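For orientation, a minimal usage sketch for the new backbone config. The type-selection pattern below is an assumption based on the OneOfConfig convention used by official.vision backbones; the field names come from the hunk above, and the override values are illustrative only:

    from official.projects.maskconver.configs import backbones

    # Hypothetical example: select the ResNet-UNet backbone and override
    # a few of its fields after construction.
    backbone = backbones.Backbone(type='resnet_unet')
    backbone.resnet_unet.model_id = 101
    backbone.resnet_unet.upsample_kernel_sizes = [7, 7]
    backbone.resnet_unet.upsample_repeats = [2, 2]
    backbone.resnet_unet.upsample_filters = [256, 128]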
@@ -0,0 +1,36 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Decoders configurations."""
+import dataclasses
+
+from official.modeling import hyperparams
+from official.vision.configs import decoders
+
+
+@dataclasses.dataclass
+class MaskConverFPN(hyperparams.Config):
+  """FPN config."""
+  num_filters: int = 256
+  fusion_type: str = 'sum'
+  use_separable_conv: bool = False
+  use_keras_layer: bool = False
+  use_layer_norm: bool = True
+  depthwise_kernel_size: int = 7
+
+
+@dataclasses.dataclass
+class Decoder(decoders.Decoder):
+  maskconver_fpn: MaskConverFPN = dataclasses.field(
+      default_factory=MaskConverFPN)
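A matching sketch for the new decoder config, under the same OneOfConfig assumption; the values are illustrative:

    from official.projects.maskconver.configs import decoders

    # Hypothetical example: select the MaskConver FPN decoder and adjust
    # its depthwise kernel size.
    decoder = decoders.Decoder(type='maskconver_fpn')
    decoder.maskconver_fpn.depthwise_kernel_size = 5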
@@ -0,0 +1,523 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Panoptic Mask R-CNN configuration definition."""
+
+import dataclasses
+import os
+from typing import List, Optional
+
+from official.core import config_definitions as cfg
+from official.core import exp_factory
+from official.modeling import hyperparams
+from official.modeling import optimization
+from official.projects.maskconver.configs import backbones
+from official.projects.maskconver.configs import decoders
+from official.vision.configs import common
+from official.vision.configs import maskrcnn
+from official.vision.configs import semantic_segmentation
+
+
+_COCO_INPUT_PATH_BASE = 'coco/tfrecords'
+_COCO_TRAIN_EXAMPLES = 118287
+_COCO_VAL_EXAMPLES = 5000
+
+# PASCAL VOC 2012 Dataset
+_PASCAL_TRAIN_EXAMPLES = 10582
+_PASCAL_VAL_EXAMPLES = 1449
+_PASCAL_INPUT_PATH_BASE = 'gs://**/pascal_voc_seg'
+
+# Cityscapes Dataset
+_CITYSCAPES_TRAIN_EXAMPLES = 2975
+_CITYSCAPES_VAL_EXAMPLES = 500
+_CITYSCAPES_INPUT_PATH_BASE = 'cityscapes/tfrecord'
+
+
+# pytype: disable=wrong-keyword-args
+# pylint: disable=unexpected-keyword-arg
+
+
+@dataclasses.dataclass
+class CopyPaste(hyperparams.Config):
+  copypaste_frequency: float = 0.5
+  aug_scale_min: float = 0.1
+  aug_scale_max: float = 1.9
+  copypaste_aug_scale_max: float = 1.0
+  copypaste_aug_scale_min: float = 0.05
+
+
+@dataclasses.dataclass
+class Parser(maskrcnn.Parser):
+  """MaskConver parser config."""
+  # If segmentation_resize_eval_groundtruth is set to False, original image
+  # sizes are used for eval. In that case,
+  # segmentation_groundtruth_padded_size has to be specified too to allow for
+  # batching the variable input sizes of images.
+  segmentation_resize_eval_groundtruth: bool = True
+  segmentation_groundtruth_padded_size: List[int] = dataclasses.field(
+      default_factory=list)
+  segmentation_ignore_label: int = 0
+  panoptic_ignore_label: int = 0
+  # Setting this to true will enable parsing category_mask and instance_mask.
+  include_panoptic_masks: bool = True
+  gaussaian_iou: float = 0.7
+  max_num_stuff_centers: int = 3
+  aug_type: common.Augmentation = dataclasses.field(
+      default_factory=common.Augmentation
+  )
+  copypaste: CopyPaste = dataclasses.field(default_factory=CopyPaste)
+
+
+@dataclasses.dataclass
+class TfExampleDecoder(common.TfExampleDecoder):
+  """A simple TF Example decoder config."""
+  # Setting this to true will enable decoding category_mask and instance_mask.
+  include_panoptic_masks: bool = True
+  panoptic_category_mask_key: str = 'image/panoptic/category_mask'
+  panoptic_instance_mask_key: str = 'image/panoptic/instance_mask'
+
+
+@dataclasses.dataclass
+class DataDecoder(common.DataDecoder):
+  """Data decoder config."""
+  simple_decoder: TfExampleDecoder = dataclasses.field(
+      default_factory=TfExampleDecoder
+  )
+
+
+@dataclasses.dataclass
+class DataConfig(maskrcnn.DataConfig):
+  """Input config for training."""
+  decoder: DataDecoder = dataclasses.field(default_factory=DataDecoder)
+  parser: Parser = dataclasses.field(default_factory=Parser)
+  dtype: str = 'float32'
+  prefetch_buffer_size: int = 8
+
+
+@dataclasses.dataclass
+class Anchor(hyperparams.Config):
+  num_scales: int = 1
+  aspect_ratios: List[float] = dataclasses.field(
+      default_factory=lambda: [0.5, 1.0, 2.0])
+  anchor_size: float = 8.0
+
+
+@dataclasses.dataclass
+class PanopticGenerator(hyperparams.Config):
+  """MaskConver panoptic generator."""
+  object_mask_threshold: float = 0.001
+  small_area_threshold: int = 0
+  overlap_threshold: float = 0.5
+  rescale_predictions: bool = True
+  use_hardware_optimization: bool = False
+
+
+@dataclasses.dataclass
+class SegmentationHead(semantic_segmentation.SegmentationHead):
+  """Segmentation head config."""
+  depthwise_kernel_size: int = 7
+  use_layer_norm: bool = False
+
+
+@dataclasses.dataclass
+class MaskConver(hyperparams.Config):
+  """MaskConver model config."""
+  num_classes: int = 0
+  num_thing_classes: int = 0
+  num_instances: int = 100
+  embedding_size: int = 512
+  padded_output_size: List[int] = dataclasses.field(default_factory=list)
+  input_size: List[int] = dataclasses.field(default_factory=list)
+  min_level: int = 2
+  max_level: int = 6
+  num_anchors: int = 100
+  panoptic_fusion_num_filters: int = 256
+  anchor: Anchor = dataclasses.field(default_factory=Anchor)
+  level: int = 3
+  class_head: SegmentationHead = dataclasses.field(
+      default_factory=SegmentationHead
+  )
+  mask_embedding_head: SegmentationHead = dataclasses.field(
+      default_factory=SegmentationHead
+  )
+  per_pixel_embedding_head: SegmentationHead = dataclasses.field(
+      default_factory=SegmentationHead
+  )
+  backbone: backbones.Backbone = dataclasses.field(
+      default_factory=backbones.Backbone
+  )
+  decoder: decoders.Decoder = dataclasses.field(
+      default_factory=lambda: decoders.Decoder(type='identity')
+  )
+  mask_decoder: Optional[decoders.Decoder] = dataclasses.field(
+      default_factory=lambda: decoders.Decoder(type='identity')
+  )
+  norm_activation: common.NormActivation = dataclasses.field(
+      default_factory=common.NormActivation
+  )
+  panoptic_generator: PanopticGenerator = dataclasses.field(
+      default_factory=PanopticGenerator
+  )
+
+
+@dataclasses.dataclass
+class Losses(hyperparams.Config):
+  """maskconver loss config."""
+  l2_weight_decay: float = 0.0
+  ignore_label: int = 0
+  use_groundtruth_dimension: bool = True
+  top_k_percent_pixels_category: float = 1.0
+  top_k_percent_pixels_instance: float = 1.0
+  loss_weight: float = 1.0
+  mask_weight: float = 10.0
+  beta: float = 4.0
+  alpha: float = 2.0
+
+
+@dataclasses.dataclass
+class PanopticQualityEvaluator(hyperparams.Config):
+  """Panoptic Quality Evaluator config."""
+  num_categories: int = 2
+  ignored_label: int = 0
+  max_instances_per_category: int = 256
+  offset: int = 256 * 256 * 256
+  is_thing: List[float] = dataclasses.field(
+      default_factory=list)
+  rescale_predictions: bool = True
+  report_per_class_metrics: bool = False
+
+###################################
+###### PANOPTIC SEGMENTATION ######
+###################################
+
+
+@dataclasses.dataclass
+class MaskConverTask(cfg.TaskConfig):
+  """MaskConverTask task config."""
+  model: MaskConver = dataclasses.field(default_factory=MaskConver)
+  train_data: DataConfig = dataclasses.field(
+      default_factory=lambda: DataConfig(is_training=True)
+  )
+  # pylint: disable=g-long-lambda
+  validation_data: DataConfig = dataclasses.field(
+      default_factory=lambda: DataConfig(
+          is_training=False, drop_remainder=False
+      )
+  )
+  losses: Losses = dataclasses.field(default_factory=Losses)
+  init_checkpoint: Optional[str] = None
+
+  init_checkpoint_modules: Optional[List[str]] = dataclasses.field(
+      default_factory=list)
+  panoptic_quality_evaluator: PanopticQualityEvaluator = dataclasses.field(
+      default_factory=PanopticQualityEvaluator
+  )
+  # pylint: enable=g-long-lambda
+
+
+@exp_factory.register_config_factory('maskconver_coco')
+def maskconver_coco() -> cfg.ExperimentConfig:
+  """COCO panoptic segmentation with MaskConver."""
+  train_batch_size = 64
+  eval_batch_size = 8
+  # steps_per_epoch = _COCO_TRAIN_EXAMPLES // train_batch_size
+  validation_steps = _COCO_VAL_EXAMPLES // eval_batch_size
+
+  # coco panoptic dataset has category ids ranging from [0-200] inclusive.
+  # 0 is not used and represents the background class
+  # ids 1-91 represent thing categories (91)
+  # ids 92-200 represent stuff categories (109)
+  # for the segmentation task, we continue using id=0 for the background
+  # and map all thing categories to id=1, the remaining 109 stuff categories
+  # are shifted by an offset=90 given by num_thing classes - 1. This shifting
+  # will make all the stuff categories begin from id=2 and end at id=110
+  num_panoptic_categories = 201
+  num_thing_categories = 91
+  # num_semantic_segmentation_classes = 111
+
+  is_thing = [False]
+  for idx in range(1, num_panoptic_categories):
+    is_thing.append(True if idx < num_thing_categories else False)
+
+  config = cfg.ExperimentConfig(
+      runtime=cfg.RuntimeConfig(
+          mixed_precision_dtype='float32', enable_xla=False),
+      task=MaskConverTask(
+          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',  # pylint: disable=line-too-long
+          init_checkpoint_modules=['backbone'],
+          model=MaskConver(
+              num_classes=201, num_thing_classes=91, input_size=[512, 512, 3],
+              padded_output_size=[512, 512]),
+          losses=Losses(l2_weight_decay=0.0),
+          train_data=DataConfig(
+              input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'train-nocrowd*'),
+              is_training=True,
+              global_batch_size=train_batch_size,
+              parser=Parser(
+                  aug_rand_hflip=True, aug_scale_min=0.5, aug_scale_max=1.5,
+                  aug_type=common.Augmentation(
+                      type='autoaug',
+                      autoaug=common.AutoAugment(
+                          augmentation_name='panoptic_deeplab_policy')))),
+          validation_data=DataConfig(
+              input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'val-nocrowd*'),
+              is_training=False,
+              global_batch_size=eval_batch_size,
+              parser=Parser(
+                  segmentation_resize_eval_groundtruth=True,
+                  segmentation_groundtruth_padded_size=[640, 640]),
+              drop_remainder=False),
+          panoptic_quality_evaluator=PanopticQualityEvaluator(
+              num_categories=num_panoptic_categories,
+              ignored_label=0,
+              is_thing=is_thing,
+              rescale_predictions=True)),
+      trainer=cfg.TrainerConfig(
+          train_steps=200000,
+          validation_steps=validation_steps,
+          validation_interval=1000,
+          steps_per_loop=1000,
+          summary_interval=1000,
+          checkpoint_interval=1000,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'sgd',
+                  'sgd': {
+                      'momentum': 0.9
+                  }
+              },
+              'learning_rate': {
+                  'type': 'cosine',
+                  'cosine': {
+                      'initial_learning_rate': 0.08,
+                      'decay_steps': 200000,
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 2000,
+                      'warmup_learning_rate': 0.0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
+
+###################################
+###### SEMANTIC SEGMENTATION ######
+###################################
+
+
+@dataclasses.dataclass
+class SegDataConfig(cfg.DataConfig):
+  """Input config for training."""
+  output_size: List[int] = dataclasses.field(default_factory=list)
+  # If crop_size is specified, image will be resized first to
+  # output_size, then crop of size crop_size will be cropped.
+  crop_size: List[int] = dataclasses.field(default_factory=list)
+  input_path: str = ''
+  global_batch_size: int = 0
+  is_training: bool = True
+  dtype: str = 'float32'
+  shuffle_buffer_size: int = 1000
+  prefetch_buffer_size: int = 8
+  cycle_length: int = 10
+  # If resize_eval_groundtruth is set to False, original image sizes are used
+  # for eval. In that case, groundtruth_padded_size has to be specified too to
+  # allow for batching the variable input sizes of images.
+  resize_eval_groundtruth: bool = True
+  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
+  aug_scale_min: float = 1.0
+  aug_scale_max: float = 1.0
+  aug_rand_hflip: bool = True
+  preserve_aspect_ratio: bool = True
+  aug_policy: Optional[str] = None
+  drop_remainder: bool = True
+  file_type: str = 'tfrecord'
+  gaussaian_iou: float = 0.7
+  max_num_stuff_centers: int = 3
+  max_num_instances: int = 100
+  aug_type: common.Augmentation = dataclasses.field(
+      default_factory=common.Augmentation)
+
+
+@dataclasses.dataclass
+class MaskConverSegTask(cfg.TaskConfig):
+  """MaskConverTask task config."""
+  model: MaskConver = dataclasses.field(default_factory=MaskConver)
+  train_data: DataConfig = dataclasses.field(
+      default_factory=lambda: SegDataConfig(is_training=True)
+  )
+  # pylint: disable=g-long-lambda
+  validation_data: DataConfig = dataclasses.field(
+      default_factory=lambda: SegDataConfig(
+          is_training=False, drop_remainder=False
+      )
+  )
+  # pylint: enable=g-long-lambda
+  losses: Losses = dataclasses.field(default_factory=Losses)
+  init_checkpoint: Optional[str] = None
+
+  init_checkpoint_modules: Optional[List[str]] = dataclasses.field(
+      default_factory=list)
+
+
+@exp_factory.register_config_factory('maskconver_seg_pascal')
+def maskconver_seg_pascal() -> cfg.ExperimentConfig:
+  """PASCAL VOC semantic segmentation with MaskConver."""
+  train_batch_size = 64
+  eval_batch_size = 8
+  validation_steps = _PASCAL_VAL_EXAMPLES // eval_batch_size
+
+  config = cfg.ExperimentConfig(
+      runtime=cfg.RuntimeConfig(
+          mixed_precision_dtype='float32', enable_xla=False),
+      task=MaskConverSegTask(
+          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',  # pylint: disable=line-too-long
+          init_checkpoint_modules=['backbone'],
+          model=MaskConver(
+              num_classes=21, num_thing_classes=91, input_size=[512, 512, 3],
+              padded_output_size=[512, 512]),
+          losses=Losses(l2_weight_decay=0.00004),
+          train_data=SegDataConfig(
+              input_path=os.path.join(_PASCAL_INPUT_PATH_BASE, 'train_aug*'),
+              output_size=[512, 512],
+              is_training=True,
+              global_batch_size=train_batch_size,
+              aug_scale_min=0.5,
+              aug_scale_max=2.0,
+              aug_type=common.Augmentation(
+                  type='autoaug',
+                  autoaug=common.AutoAugment(
+                      augmentation_name='panoptic_deeplab_policy'))),
+          validation_data=SegDataConfig(
+              input_path=os.path.join(_PASCAL_INPUT_PATH_BASE, 'val*'),
+              output_size=[512, 512],
+              is_training=False,
+              global_batch_size=eval_batch_size,
+              resize_eval_groundtruth=False,
+              groundtruth_padded_size=[512, 512],
+              drop_remainder=False)),
+      trainer=cfg.TrainerConfig(
+          train_steps=200000,
+          validation_steps=validation_steps,
+          validation_interval=1000,
+          steps_per_loop=1000,
+          summary_interval=1000,
+          checkpoint_interval=1000,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'sgd',
+                  'sgd': {
+                      'momentum': 0.9
+                  }
+              },
+              'learning_rate': {
+                  'type': 'cosine',
+                  'cosine': {
+                      'initial_learning_rate': 0.08,
+                      'decay_steps': 200000,
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 2000,
+                      'warmup_learning_rate': 0.0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
+
+
+@exp_factory.register_config_factory('maskconver_seg_cityscapes')
+def maskconver_seg_cityscapes() -> cfg.ExperimentConfig:
+  """Cityscapes semantic segmentation with MaskConver."""
+  train_batch_size = 32
+  eval_batch_size = 8
+  validation_steps = _CITYSCAPES_VAL_EXAMPLES // eval_batch_size
+
+  config = cfg.ExperimentConfig(
+      runtime=cfg.RuntimeConfig(
+          mixed_precision_dtype='float32', enable_xla=False),
+      task=MaskConverSegTask(
+          init_checkpoint='maskconver_seg_mnv3p5rf_coco_200k/43437096',  # pylint: disable=line-too-long
+          init_checkpoint_modules=['backbone'],
+          model=MaskConver(
+              num_classes=19, input_size=[None, None, 3],
+              padded_output_size=[1024, 2048]),
+          losses=Losses(l2_weight_decay=0.00004),
+          train_data=SegDataConfig(
+              input_path=os.path.join(_CITYSCAPES_INPUT_PATH_BASE,
+                                      'train_fine*'),
+              output_size=[1024, 2048],
+              crop_size=[512, 1024],
+              is_training=True,
+              global_batch_size=train_batch_size,
+              aug_scale_min=0.5,
+              aug_scale_max=2.0,
+              aug_type=common.Augmentation(
+                  type='autoaug',
+                  autoaug=common.AutoAugment(
+                      augmentation_name='panoptic_deeplab_policy'))),
+          validation_data=SegDataConfig(
+              input_path=os.path.join(_CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
+              output_size=[1024, 2048],
+              is_training=False,
+              global_batch_size=eval_batch_size,
+              resize_eval_groundtruth=False,
+              groundtruth_padded_size=[1024, 2048],
+              drop_remainder=False)),
+      trainer=cfg.TrainerConfig(
+          train_steps=100000,
+          validation_steps=validation_steps,
+          validation_interval=185,
+          steps_per_loop=185,
+          summary_interval=185,
+          checkpoint_interval=185,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'sgd',
+                  'sgd': {
+                      'momentum': 0.9
+                  }
+              },
+              'learning_rate': {
+                  'type': 'polynomial',
+                  'polynomial': {
+                      'initial_learning_rate': 0.01,
+                      'decay_steps': 100000,
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 925,
+                      'warmup_learning_rate': 0.0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
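The three factory functions above register their configs under the names 'maskconver_coco', 'maskconver_seg_pascal', and 'maskconver_seg_cityscapes'. A hedged usage sketch for retrieving one of them through the standard Model Garden factory lookup; the batch-size override is illustrative only:

    from official.core import exp_factory

    # Fetch the registered COCO panoptic experiment by name, apply an
    # illustrative override, and re-check the declared restrictions.
    config = exp_factory.get_exp_config('maskconver_coco')
    config.task.train_data.global_batch_size = 32
    config.validate()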