tf-models-nightly 2.17.0.dev20240523__py2.py3-none-any.whl → 2.17.0.dev20240525__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,13 +14,13 @@
14
14
 
15
15
  """Backbones package definition."""
16
16
 
17
- from official.projects.maskconver.modeling.resnet_unet import ResNetUNet
18
17
  from official.vision.modeling.backbones.efficientnet import EfficientNet
19
18
  from official.vision.modeling.backbones.mobiledet import MobileDet
20
19
  from official.vision.modeling.backbones.mobilenet import MobileNet
21
20
  from official.vision.modeling.backbones.resnet import ResNet
22
21
  from official.vision.modeling.backbones.resnet_3d import ResNet3D
23
22
  from official.vision.modeling.backbones.resnet_deeplab import DilatedResNet
23
+ from official.vision.modeling.backbones.resnet_unet import ResNetUNet
24
24
  from official.vision.modeling.backbones.revnet import RevNet
25
25
  from official.vision.modeling.backbones.spinenet import SpineNet
26
26
  from official.vision.modeling.backbones.spinenet_mobile import SpineNetMobile
@@ -0,0 +1,588 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Contains definitions of ResNet UNet style."""
16
+
17
+ from typing import Callable, Optional, List
18
+
19
+ # Import libraries
20
+ import tensorflow as tf, tf_keras
21
+
22
+ from official.modeling import hyperparams
23
+ from official.modeling import tf_utils
24
+ from official.vision.modeling.backbones import factory
25
+ from official.vision.modeling.layers import nn_blocks
26
+ from official.vision.modeling.layers import nn_layers
27
+ from official.vision.ops import spatial_transform_ops
28
+
29
+ layers = tf_keras.layers
30
+
31
+
32
# Specifications for different ResNet variants, keyed by model depth.
# Every variant has four block groups whose base filter counts are 64, 128,
# 256 and 512; only the block type and the per-group repeat counts differ.
# Each value is a list with one entry per block group in the format:
# (block_fn, num_filters, block_repeats)
RESNET_SPECS = {
    depth: [
        (block_fn, num_filters, block_repeats)
        for num_filters, block_repeats in zip((64, 128, 256, 512), repeats)
    ]
    for depth, (block_fn, repeats) in {
        10: ('residual', (1, 1, 1, 1)),
        18: ('residual', (2, 2, 2, 2)),
        34: ('residual', (3, 4, 6, 3)),
        50: ('bottleneck', (3, 4, 6, 3)),
        101: ('bottleneck', (3, 4, 23, 3)),
        152: ('bottleneck', (3, 8, 36, 3)),
        200: ('bottleneck', (3, 24, 36, 3)),
        270: ('bottleneck', (4, 29, 53, 4)),
        350: ('bottleneck', (4, 36, 72, 4)),
        420: ('bottleneck', (4, 44, 87, 4)),
    }.items()
}
98
+
99
+
100
def conv_2d(*args, **kwargs):
  """Creates a `tf_keras.layers.Conv2D` with ConvNeXt-style initializers.

  The kernel defaults to a truncated normal (stddev 0.02) and the bias to
  zeros. Both are defaults only: callers may pass their own
  `kernel_initializer` / `bias_initializer` keyword arguments without
  triggering a duplicate-keyword error.

  Args:
    *args: Positional arguments forwarded to `Conv2D`.
    **kwargs: Keyword arguments forwarded to `Conv2D`.

  Returns:
    The constructed `Conv2D` layer.
  """
  kwargs.setdefault(
      'kernel_initializer',
      tf_keras.initializers.truncated_normal(stddev=0.02))
  kwargs.setdefault('bias_initializer', 'zeros')
  return tf_keras.layers.Conv2D(*args, **kwargs)
107
+
108
+
109
def dense(*args, **kwargs):
  """Creates a `tf_keras.layers.Dense` with ConvNeXt-style initializers.

  The kernel defaults to a truncated normal (stddev 0.02) and the bias to
  zeros. Both are defaults only: callers may pass their own
  `kernel_initializer` / `bias_initializer` keyword arguments without
  triggering a duplicate-keyword error.

  Args:
    *args: Positional arguments forwarded to `Dense`.
    **kwargs: Keyword arguments forwarded to `Dense`.

  Returns:
    The constructed `Dense` layer.
  """
  kwargs.setdefault(
      'kernel_initializer',
      tf_keras.initializers.truncated_normal(stddev=0.02))
  kwargs.setdefault('bias_initializer', 'zeros')
  return tf_keras.layers.Dense(*args, **kwargs)
116
+
117
+
118
class ConvNeXtBlock(tf_keras.Model):
  """ConvNeXt block.

  Computes
  `inputs + drop_path(gamma * pw2(se(gelu(pw1(norm(dw(inputs)))))))`,
  where `dw` is a depthwise conv, `pw1`/`pw2` are `Dense` layers expanding to
  `4 * dim` and projecting back to `dim`, `se` is an optional
  squeeze-and-excitation on the expanded features and `gamma` an optional
  learned layer scale. Because of the residual add, the input must already
  have `dim` channels in its last axis (the axis the `Dense` layers and the
  `(1, 1, 1, dim)` layer scale act on).
  """

  def __init__(
      self,
      dim,
      drop_rate=0.0,
      layer_scale_init_value=1e-6,
      norm_fn=None,
      kernel_size=7,
      se_ratio=0.0625,
      **kwargs,
  ):
    """Initializes the block.

    Args:
      dim: Number of input (and output) channels.
      drop_rate: Stochastic depth rate; values <= 0 disable it.
      layer_scale_init_value: Initial value of the per-channel layer scale;
        values <= 0 disable the layer scale.
      norm_fn: Optional zero-argument callable returning the normalization
        layer. Defaults to `LayerNormalization(epsilon=1e-6)`.
      kernel_size: Kernel size of the depthwise convolution.
      se_ratio: Squeeze-and-excitation ratio applied on the `4 * dim`
        expansion; only used when it is truthy and in (0, 1].
      **kwargs: Forwarded to `tf_keras.Model`.
    """
    super().__init__(**kwargs)
    expanded_dim = 4 * dim

    # Sub-layers are created in a fixed order with fixed attribute names so
    # that checkpoints keep matching.
    self.depthwise_conv = tf_keras.layers.DepthwiseConv2D(
        kernel_size=kernel_size,
        padding='same',
        depthwise_initializer=tf_keras.initializers.truncated_normal(
            stddev=0.02),
        bias_initializer='zeros',
    )
    self.norm = (
        norm_fn() if norm_fn
        else tf_keras.layers.LayerNormalization(epsilon=1e-6))
    self.pointwise_conv1 = dense(expanded_dim)
    self.act = tf_keras.layers.Activation('gelu')
    self.pointwise_conv2 = dense(dim)
    self.gamma = (
        self.add_weight(
            name='layer_scale',
            shape=(1, 1, 1, dim),
            initializer=tf_keras.initializers.Constant(
                layer_scale_init_value))
        if layer_scale_init_value > 0 else None)

    # Identity ('linear' activation) stands in when stochastic depth is off.
    if drop_rate > 0:
      self.drop_path = nn_layers.StochasticDepth(drop_rate)
    else:
      self.drop_path = tf_keras.layers.Activation('linear')

    if se_ratio and 0 < se_ratio <= 1:
      self._squeeze_excitation = nn_layers.SqueezeExcitation(
          activation='gelu',
          in_filters=expanded_dim,
          out_filters=expanded_dim,
          se_ratio=se_ratio,
      )
    else:
      self._squeeze_excitation = None

  def call(self, x, training=None):
    """Applies the block to `x` and adds the result back onto `x`."""
    shortcut = x
    branch = self.act(self.pointwise_conv1(self.norm(self.depthwise_conv(x))))
    if self._squeeze_excitation is not None:
      branch = self._squeeze_excitation(branch)
    branch = self.pointwise_conv2(branch)
    if self.gamma is not None:
      branch = self.gamma * branch
    return shortcut + self.drop_path(branch, training=training)
183
+
184
+
185
@tf_keras.utils.register_keras_serializable(package='Vision')
class ResNetUNet(tf_keras.Model):
  """Creates a ResNet encoder with a ConvNeXt-based UNet-style decoder.

  The encoder is the Deep Residual Network from:
  Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
  Deep Residual Learning for Image Recognition.
  (https://arxiv.org/pdf/1512.03385) and
  Irwan Bello, William Fedus, Xianzhi Du, Ekin D. Cubuk, Aravind Srinivas,
  Tsung-Yi Lin, Jonathon Shlens, Barret Zoph.
  Revisiting ResNets: Improved Training and Scaling Strategies.
  (https://arxiv.org/abs/2103.07579).

  The encoder produces endpoints for levels 2-5. A top-down decoder then walks
  from level 5 downwards. At each decoder level, the encoder feature is
  projected with a 1x1 conv and LayerNorm. For every level after the first,
  the previous decoder output is also projected (1x1 conv and LayerNorm),
  upsampled 2x with nearest neighbour, and added to it. The result is refined
  by a stack of `ConvNeXtBlock`s. Decoder outputs replace the encoder
  endpoints of the levels they cover.

  The model outputs a dict keyed by level strings ('2'-'5', plus '6' when
  `classification_output` is set).
  """

  def __init__(
      self,
      model_id: int,
      input_specs: tf_keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, 3]),
      depth_multiplier: float = 1.0,
      stem_type: str = 'v0',
      resnetd_shortcut: bool = False,
      replace_stem_max_pool: bool = False,
      se_ratio: Optional[float] = None,
      init_stochastic_depth_rate: float = 0.0,
      upsample_repeats: Optional[List[int]] = None,
      upsample_filters: Optional[List[int]] = None,
      upsample_kernel_sizes: Optional[List[int]] = None,
      scale_stem: bool = True,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bn_trainable: bool = True,
      classification_output: bool = False,
      **kwargs):
    """Initializes a ResNet UNet model.

    Args:
      model_id: An `int` of the depth of ResNet backbone model (a key of
        `RESNET_SPECS`).
      input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
      depth_multiplier: A `float` of the depth multiplier to uniformly scale up
        all layers in channel size, including the decoder. This argument is
        also referred to as `width_multiplier` in
        (https://arxiv.org/abs/2103.07579).
      stem_type: A `str` of stem type of ResNet. Default to `v0`. If set to
        `v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
      resnetd_shortcut: A `bool` of whether to use ResNet-D shortcut in
        downsampling blocks.
      replace_stem_max_pool: A `bool` of whether to replace the max pool in stem
        with a stride-2 conv.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
      upsample_repeats: A `list` for upsample repeats of the ConvNext blocks for
        each level starting from L5, then L4, and so on. `None` means no
        decoder levels.
      upsample_filters: A `list` for the upsample filter sizes for the ConvNext
        blocks for each level (scaled by `depth_multiplier`). `None` means no
        decoder levels.
      upsample_kernel_sizes: A `list` for upsample kernel sizes for the ConvNext
        blocks for each level. `None` means no decoder levels.
      scale_stem: A `bool` of whether to scale stem layers by
        `depth_multiplier`.
      activation: A `str` name of the activation function.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A small `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
        Default to None.
      bn_trainable: A `bool` that indicates whether batch norm layers should be
        trainable. Default to True.
      classification_output: A `bool` to additionally expose the last decoder
        level under key '6' for classification; only set to True for
        pretraining.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      ValueError: If `stem_type` or a block type is not supported, if the
        three `upsample_*` lists differ in length or have more than 4 entries,
        or if `classification_output` is set without any decoder level.
    """
    self._model_id = model_id
    self._input_specs = input_specs
    self._depth_multiplier = depth_multiplier
    self._stem_type = stem_type
    self._resnetd_shortcut = resnetd_shortcut
    self._replace_stem_max_pool = replace_stem_max_pool
    self._se_ratio = se_ratio
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    # The raw arguments are what `get_config` reports.
    self._upsample_repeats = upsample_repeats
    self._upsample_filters = upsample_filters
    self._upsample_kernel_sizes = upsample_kernel_sizes
    self._scale_stem = scale_stem
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._bn_trainable = bn_trainable
    self._classification_output = classification_output

    # Normalize the decoder spec (`None` -> no decoder levels) and validate it
    # up front instead of failing later with IndexError/KeyError.
    upsample_repeats = list(upsample_repeats or [])
    upsample_filters = list(upsample_filters or [])
    upsample_kernel_sizes = list(upsample_kernel_sizes or [])
    num_decoder_levels = len(upsample_filters)
    if (len(upsample_repeats) != num_decoder_levels or
        len(upsample_kernel_sizes) != num_decoder_levels):
      raise ValueError(
          '`upsample_repeats`, `upsample_filters` and `upsample_kernel_sizes` '
          'must have the same length, got {}, {} and {}.'.format(
              len(upsample_repeats), num_decoder_levels,
              len(upsample_kernel_sizes)))
    # The encoder only produces endpoints for levels 2-5.
    if num_decoder_levels > 4:
      raise ValueError(
          'At most 4 decoder levels (L5 down to L2) are supported, got '
          '{}.'.format(num_decoder_levels))
    if classification_output and not num_decoder_levels:
      raise ValueError(
          '`classification_output` requires at least one decoder level.')

    if tf_keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    def conv_norm_act(x, filters, kernel_size, strides):
      """Applies a bias-free Conv2D, the configured norm and the activation."""
      x = layers.Conv2D(
          filters=filters,
          kernel_size=kernel_size,
          strides=strides,
          use_bias=False,
          padding='same',
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)(
              x)
      x = self._norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          trainable=bn_trainable)(
              x)
      return tf_utils.get_activation(activation, use_keras_layer=True)(x)

    # Build ResNet.
    inputs = tf_keras.Input(shape=input_specs.shape[1:])

    stem_depth_multiplier = self._depth_multiplier if scale_stem else 1.0
    if stem_type == 'v0':
      x = conv_norm_act(inputs, int(64 * stem_depth_multiplier), 7, 2)
    elif stem_type == 'v1':
      x = conv_norm_act(inputs, int(32 * stem_depth_multiplier), 3, 2)
      x = conv_norm_act(x, int(32 * stem_depth_multiplier), 3, 1)
      x = conv_norm_act(x, int(64 * stem_depth_multiplier), 3, 1)
    else:
      raise ValueError('Stem type {} not supported.'.format(stem_type))

    if replace_stem_max_pool:
      x = conv_norm_act(x, int(64 * self._depth_multiplier), 3, 2)
    else:
      x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

    endpoints = {}
    for i, (block_type, base_filters, block_repeats) in enumerate(
        RESNET_SPECS[model_id]):
      if block_type == 'residual':
        block_fn = nn_blocks.ResidualBlock
      elif block_type == 'bottleneck':
        block_fn = nn_blocks.BottleneckBlock
      else:
        raise ValueError('Block fn `{}` is not supported.'.format(block_type))
      x = self._block_group(
          inputs=x,
          filters=int(base_filters * self._depth_multiplier),
          strides=(1 if i == 0 else 2),
          block_fn=block_fn,
          block_repeats=block_repeats,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 2, 8),
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i + 2)] = x

    def decoder_norm():
      return tf_keras.layers.LayerNormalization(epsilon=1e-6)

    for i, (level_filters, level_repeats, level_kernel_size) in enumerate(
        zip(upsample_filters, upsample_repeats, upsample_kernel_sizes)):
      level = str(5 - i)
      # Every tensor at this decoder level must share one channel width: the
      # skip addition and the residual add inside `ConvNeXtBlock` require it.
      # The decoder is not part of the stem, so it always follows
      # `depth_multiplier`. Previously the projections used the stem
      # multiplier, which broke the build when `scale_stem=False`.
      decoder_filters = int(level_filters * self._depth_multiplier)
      backbone_feature = layers.Conv2D(
          filters=decoder_filters,
          kernel_size=1,
          strides=1,
          use_bias=False,
          padding='same',
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)(
              endpoints[level])
      backbone_feature = decoder_norm()(backbone_feature)

      if i == 0:
        x = backbone_feature
      else:
        x = layers.Conv2D(
            filters=decoder_filters,
            kernel_size=1,
            strides=1,
            use_bias=False,
            padding='same',
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer)(
                x)
        x = decoder_norm()(x)
        x = spatial_transform_ops.nearest_upsampling(
            x, scale=2, use_keras_layer=True) + backbone_feature

      for _ in range(level_repeats):
        x = ConvNeXtBlock(
            decoder_filters,
            drop_rate=nn_layers.get_stochastic_depth_rate(
                self._init_stochastic_depth_rate, i + 6, 8),
            kernel_size=level_kernel_size)(x)
        x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
      endpoints[level] = x

    if classification_output:
      # The last (finest) decoder level is exposed for classification.
      endpoints['6'] = endpoints[str(6 - num_decoder_levels)]

    self._output_specs = {
        level: feature.get_shape() for level, feature in endpoints.items()
    }

    super().__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def _block_group(self,
                   inputs: tf.Tensor,
                   filters: int,
                   strides: int,
                   block_fn: Callable[..., tf_keras.layers.Layer],
                   block_repeats: int = 1,
                   stochastic_depth_drop_rate: float = 0.0,
                   name: str = 'block_group'):
    """Creates one group of blocks for the ResNet model.

    Args:
      inputs: The input feature map `tf.Tensor`.
      filters: An `int` number of filters for the first convolution of the
        layer.
      strides: An `int` stride to use for the first convolution of the layer.
        If greater than 1, this layer will downsample the input.
      block_fn: The type of block group. Either `nn_blocks.ResidualBlock` or
        `nn_blocks.BottleneckBlock`.
      block_repeats: An `int` number of blocks contained in the layer. The
        first (projection) block is always created.
      stochastic_depth_drop_rate: A `float` of drop rate of the current block
        group.
      name: A `str` name for the block.

    Returns:
      The output `tf.Tensor` of the block layer.
    """
    shared_kwargs = dict(
        filters=filters,
        stochastic_depth_drop_rate=stochastic_depth_drop_rate,
        se_ratio=self._se_ratio,
        resnetd_shortcut=self._resnetd_shortcut,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon,
        bn_trainable=self._bn_trainable)

    # Only the first block may change resolution/width, so only it projects.
    x = block_fn(strides=strides, use_projection=True, **shared_kwargs)(inputs)
    for _ in range(1, block_repeats):
      x = block_fn(strides=1, use_projection=False, **shared_kwargs)(x)

    return tf_keras.layers.Activation('linear', name=name)(x)

  def get_config(self):
    config_dict = {
        'model_id': self._model_id,
        'depth_multiplier': self._depth_multiplier,
        'stem_type': self._stem_type,
        'resnetd_shortcut': self._resnetd_shortcut,
        'replace_stem_max_pool': self._replace_stem_max_pool,
        'activation': self._activation,
        'se_ratio': self._se_ratio,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'upsample_repeats': self._upsample_repeats,
        'upsample_filters': self._upsample_filters,
        'upsample_kernel_sizes': self._upsample_kernel_sizes,
        'scale_stem': self._scale_stem,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'bn_trainable': self._bn_trainable,
        'classification_output': self._classification_output,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
555
+
556
+
557
@factory.register_backbone_builder('resnet_unet')
def build_resnet(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: tf_keras.regularizers.Regularizer = None) -> tf_keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Builds ResNet ConvNext Unet backbone from a config.

  Args:
    input_specs: The `tf_keras.layers.InputSpec` of the model input.
    backbone_config: A config exposing `type` (must be 'resnet_unet') and
      `get()`, whose result supplies the architecture fields.
    norm_activation_config: Supplies `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: Passed to `ResNetUNet` as `kernel_regularizer`.

  Returns:
    The constructed `ResNetUNet` model.
  """
  backbone_type = backbone_config.type
  cfg = backbone_config.get()
  assert backbone_type == 'resnet_unet', (
      f'Inconsistent backbone type {backbone_type}')

  norm_activation_kwargs = dict(
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
  )
  return ResNetUNet(
      model_id=cfg.model_id,
      input_specs=input_specs,
      depth_multiplier=cfg.depth_multiplier,
      upsample_repeats=cfg.upsample_repeats,
      upsample_filters=cfg.upsample_filters,
      upsample_kernel_sizes=cfg.upsample_kernel_sizes,
      stem_type=cfg.stem_type,
      resnetd_shortcut=cfg.resnetd_shortcut,
      replace_stem_max_pool=cfg.replace_stem_max_pool,
      se_ratio=cfg.se_ratio,
      init_stochastic_depth_rate=cfg.stochastic_depth_drop_rate,
      scale_stem=cfg.scale_stem,
      kernel_regularizer=l2_regularizer,
      bn_trainable=cfg.bn_trainable,
      classification_output=cfg.classification_output,
      **norm_activation_kwargs)
@@ -0,0 +1,80 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for resnet."""
16
+
17
+ # Import libraries
18
+
19
+ from absl.testing import parameterized
20
+ import tensorflow as tf, tf_keras
21
+
22
+ from official.vision.modeling.backbones import resnet_unet
23
+
24
+
25
class ResNetUNetTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the `ResNetUNet` backbone."""

  @parameterized.parameters(
      (128, 50, 4),
  )
  def test_network_creation(self, input_size, model_id, endpoint_filter_scale):
    """Checks the parameter count and endpoint shapes of a ResNetUNet."""
    resnet_unet_params = {
        50: 55_205_440,
    }
    tf_keras.backend.set_image_data_format('channels_last')

    network = resnet_unet.ResNetUNet(
        model_id=model_id,
        upsample_repeats=[18, 1, 1],
        upsample_filters=[384, 384, 384],
        upsample_kernel_sizes=[7, 7, 7],
    )
    self.assertEqual(network.count_params(), resnet_unet_params[model_id])

    inputs = tf_keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    endpoints = network(inputs)

    # Level 2 is not covered by the three decoder levels, so it keeps the
    # encoder's output width.
    self.assertAllEqual(
        [1, input_size // 2**2, input_size // 2**2,
         64 * endpoint_filter_scale],
        endpoints['2'].shape.as_list(),
    )
    # Levels 3-5 are decoder outputs with `upsample_filters` channels.
    for level in range(3, 6):
      self.assertAllEqual(
          [1, input_size // 2**level, input_size // 2**level, 384],
          endpoints[str(level)].shape.as_list(),
      )

  def test_serialize_deserialize(self):
    # Create a network object that sets all of its config options.
    kwargs = dict(
        model_id=50,
        upsample_repeats=[18, 1, 1],
        upsample_filters=[384, 384, 384],
        upsample_kernel_sizes=[7, 7, 7],
    )
    network = resnet_unet.ResNetUNet(**kwargs)

    # Create another network object from the first object's config.
    new_network = resnet_unet.ResNetUNet.from_config(network.get_config())

    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())


if __name__ == '__main__':
  tf.test.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf-models-nightly
3
- Version: 2.17.0.dev20240523
3
+ Version: 2.17.0.dev20240525
4
4
  Summary: TensorFlow Official Models
5
5
  Home-page: https://github.com/tensorflow/models
6
6
  Author: Google Inc.
@@ -1042,7 +1042,7 @@ official/vision/modeling/segmentation_model.py,sha256=BrH3w0yZ60uwyEGP08iYt0rAWV
1042
1042
  official/vision/modeling/segmentation_model_test.py,sha256=fimmpt0pFMcX214Vuf0C4NcRdWLLB9F2IJeiuTFUO54,2817
1043
1043
  official/vision/modeling/video_classification_model.py,sha256=RKrT22D5yW435S9AyYlgOFCIi0knbmZFvgwh_5JubPI,4713
1044
1044
  official/vision/modeling/video_classification_model_test.py,sha256=HL0hVwOFJCxSjDNFVK42MwzuoDx5r4rvraEbBXVkX8c,3271
1045
- official/vision/modeling/backbones/__init__.py,sha256=21iuq1HPa3KcN-6ljjRlUyklRV8Ynyc3Cz47XivHWMw,1402
1045
+ official/vision/modeling/backbones/__init__.py,sha256=sq2NgKWiD2aqGmM-xGPPFIXtEM26gT9mSVrR6qDAFA0,1399
1046
1046
  official/vision/modeling/backbones/efficientnet.py,sha256=j716OGSpkzgpDA4jV-Hk73mAGPULrPwlwHuS_9j40bE,12438
1047
1047
  official/vision/modeling/backbones/efficientnet_test.py,sha256=TYsUieiLrEU5913s3Yxhv-9eaolQK_kfzJm77lyCF0M,3762
1048
1048
  official/vision/modeling/backbones/factory.py,sha256=coJKJpPMhgM9gAc2Q7I5_CuzAaHZNJwPcvGbaUYp8gU,3504
@@ -1057,6 +1057,8 @@ official/vision/modeling/backbones/resnet_3d_test.py,sha256=hhCkW28UXc2peKHGgFl0
1057
1057
  official/vision/modeling/backbones/resnet_deeplab.py,sha256=RCwLTEGwe4XGPGM5-ELfqdHVJRW35-whoFk6U_PK8Sc,15840
1058
1058
  official/vision/modeling/backbones/resnet_deeplab_test.py,sha256=JXklR7mTi7wWVAPsu48wW5IXzJBLw862uzlKPGdMdps,5520
1059
1059
  official/vision/modeling/backbones/resnet_test.py,sha256=rjIFkLsbsUqobectT96jqMygOyCWWih0l9sZbR-Wi9I,5555
1060
+ official/vision/modeling/backbones/resnet_unet.py,sha256=2B4VIX2jPqT3aYpH69VSip6zqSeq4WcMFp0_cvEDFxM,21303
1061
+ official/vision/modeling/backbones/resnet_unet_test.py,sha256=SPwHtodwfnLnGeEcgK-PZYt_xHDoVe98ULfiKeuI13Q,2615
1060
1062
  official/vision/modeling/backbones/revnet.py,sha256=JscmrEGjUbAv_koxei-hJdhvS6kl8L8knBXYofbEJ5A,8797
1061
1063
  official/vision/modeling/backbones/revnet_test.py,sha256=GPR4CCr9CiOekfzGUCxCR02TQibY8nRGLxJbMvT2o6w,3225
1062
1064
  official/vision/modeling/backbones/spinenet.py,sha256=FOCafyw_ZVIY76gzpiY8Al4mXrlanqknBok5PNR7Wfg,21154
@@ -1206,9 +1208,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
1206
1208
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
1207
1209
  tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
1208
1210
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
1209
- tf_models_nightly-2.17.0.dev20240523.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1210
- tf_models_nightly-2.17.0.dev20240523.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1211
- tf_models_nightly-2.17.0.dev20240523.dist-info/METADATA,sha256=PQ2wFwIVDMEHhGtOiFXie9SuSYTju3AjLVgptVUMUUA,1432
1212
- tf_models_nightly-2.17.0.dev20240523.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1213
- tf_models_nightly-2.17.0.dev20240523.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1214
- tf_models_nightly-2.17.0.dev20240523.dist-info/RECORD,,
1211
+ tf_models_nightly-2.17.0.dev20240525.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
1212
+ tf_models_nightly-2.17.0.dev20240525.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
1213
+ tf_models_nightly-2.17.0.dev20240525.dist-info/METADATA,sha256=jGd6GcfQ1OLHGxh9QiHoAbOsjmdT8tqkQgkC_r2BV1c,1432
1214
+ tf_models_nightly-2.17.0.dev20240525.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
1215
+ tf_models_nightly-2.17.0.dev20240525.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
1216
+ tf_models_nightly-2.17.0.dev20240525.dist-info/RECORD,,