keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. keras_hub/api/layers/__init__.py +3 -0
  2. keras_hub/api/models/__init__.py +16 -0
  3. keras_hub/api/tokenizers/__init__.py +1 -0
  4. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
  5. keras_hub/src/models/clip/clip_preprocessor.py +147 -0
  6. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
  7. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
  8. keras_hub/src/models/densenet/__init__.py +6 -0
  9. keras_hub/src/models/densenet/densenet_backbone.py +11 -8
  10. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
  11. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  12. keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
  13. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  14. keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
  15. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
  16. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
  17. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
  18. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
  19. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
  20. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
  21. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
  22. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
  23. keras_hub/src/models/text_to_image.py +295 -0
  24. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  25. keras_hub/src/utils/timm/preset_loader.py +3 -0
  26. keras_hub/src/version_utils.py +1 -1
  27. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
  28. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +31 -23
  29. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  30. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  31. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  32. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  33. /keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
  34. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
  35. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
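The headline change: the experimental `stable_diffusion_v3` directory is dissolved. CLIP and the T5 preprocessor move out as reusable components (`clip/`, `t5/`), the remaining pieces are rebuilt under `stable_diffusion_3/` (backbone, scheduler, VAE decoder, text-to-image task), and a generic `TextToImage` task base class is added. A minimal sketch of how the new task is presumably used; the class name is inferred from the file names above, and the preset string is hypothetical:

```python
import keras_hub

# Hypothetical preset name; this diff only shows the code moving into
# place, not which presets ship with it.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium"
)
# `TextToImage` tasks generate images from text prompts.
images = text_to_image.generate("a photograph of an astronaut riding a horse")
```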
--- a/keras_hub/src/models/stable_diffusion_v3/mmdit_block.py
+++ /dev/null
@@ -1,317 +0,0 @@
- # Copyright 2024 The KerasHub Authors
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     https://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
-
- from keras import layers
- from keras import models
- from keras import ops
-
- from keras_hub.src.utils.keras_utils import gelu_approximate
-
-
- class DismantledBlock(layers.Layer):
-     def __init__(
-         self,
-         num_heads,
-         hidden_dim,
-         mlp_ratio=4.0,
-         use_projection=True,
-         **kwargs,
-     ):
-         super().__init__(**kwargs)
-         self.num_heads = num_heads
-         self.hidden_dim = hidden_dim
-         self.mlp_ratio = mlp_ratio
-         self.use_projection = use_projection
-
-         head_dim = hidden_dim // num_heads
-         self.head_dim = head_dim
-         mlp_hidden_dim = int(hidden_dim * mlp_ratio)
-         self.mlp_hidden_dim = mlp_hidden_dim
-         num_modulations = 6 if use_projection else 2
-         self.num_modulations = num_modulations
-
-         self.adaptive_norm_modulation = models.Sequential(
-             [
-                 layers.Activation("silu", dtype=self.dtype_policy),
-                 layers.Dense(
-                     num_modulations * hidden_dim, dtype=self.dtype_policy
-                 ),
-             ],
-             name="adaptive_norm_modulation",
-         )
-         self.norm1 = layers.LayerNormalization(
-             epsilon=1e-6,
-             center=False,
-             scale=False,
-             dtype=self.dtype_policy,
-             name="norm1",
-         )
-         self.attention_qkv = layers.Dense(
-             hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
-         )
-         if use_projection:
-             self.attention_proj = layers.Dense(
-                 hidden_dim, dtype=self.dtype_policy, name="attention_proj"
-             )
-             self.norm2 = layers.LayerNormalization(
-                 epsilon=1e-6,
-                 center=False,
-                 scale=False,
-                 dtype=self.dtype_policy,
-                 name="norm2",
-             )
-             self.mlp = models.Sequential(
-                 [
-                     layers.Dense(
-                         mlp_hidden_dim,
-                         activation=gelu_approximate,
-                         dtype=self.dtype_policy,
-                     ),
-                     layers.Dense(
-                         hidden_dim,
-                         dtype=self.dtype_policy,
-                     ),
-                 ],
-                 name="mlp",
-             )
-
-     def build(self, inputs_shape, timestep_embedding_shape):
-         self.adaptive_norm_modulation.build(timestep_embedding_shape)
-         self.attention_qkv.build(inputs_shape)
-         self.norm1.build(inputs_shape)
-         if self.use_projection:
-             self.attention_proj.build(inputs_shape)
-             self.norm2.build(inputs_shape)
-             self.mlp.build(inputs_shape)
-
-     def _modulate(self, inputs, shift, scale):
-         shift = ops.expand_dims(shift, axis=1)
-         scale = ops.expand_dims(scale, axis=1)
-         return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
-
-     def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
-         batch_size = ops.shape(inputs)[0]
-         if self.use_projection:
-             modulation = self.adaptive_norm_modulation(
-                 timestep_embedding, training=training
-             )
-             modulation = ops.reshape(
-                 modulation, (batch_size, 6, self.hidden_dim)
-             )
-             (
-                 shift_msa,
-                 scale_msa,
-                 gate_msa,
-                 shift_mlp,
-                 scale_mlp,
-                 gate_mlp,
-             ) = ops.unstack(modulation, 6, axis=1)
-             qkv = self.attention_qkv(
-                 self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                 training=training,
-             )
-             qkv = ops.reshape(
-                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
-             )
-             q, k, v = ops.unstack(qkv, 3, axis=2)
-             return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
-         else:
-             modulation = self.adaptive_norm_modulation(
-                 timestep_embedding, training=training
-             )
-             modulation = ops.reshape(
-                 modulation, (batch_size, 2, self.hidden_dim)
-             )
-             shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1)
-             qkv = self.attention_qkv(
-                 self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                 training=training,
-             )
-             qkv = ops.reshape(
-                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
-             )
-             q, k, v = ops.unstack(qkv, 3, axis=2)
-             return (q, k, v)
-
-     def _compute_post_attention(
-         self, inputs, inputs_intermediates, training=None
-     ):
-         x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
-         attn = self.attention_proj(inputs, training=training)
-         x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
-         x = ops.add(
-             x,
-             ops.multiply(
-                 ops.expand_dims(gate_mlp, axis=1),
-                 self.mlp(
-                     self._modulate(self.norm2(x), shift_mlp, scale_mlp),
-                     training=training,
-                 ),
-             ),
-         )
-         return x
-
-     def call(
-         self,
-         inputs,
-         timestep_embedding=None,
-         inputs_intermediates=None,
-         pre_attention=True,
-         training=None,
-     ):
-         if pre_attention:
-             return self._compute_pre_attention(
-                 inputs, timestep_embedding, training=training
-             )
-         else:
-             return self._compute_post_attention(
-                 inputs, inputs_intermediates, training=training
-             )
-
-     def get_config(self):
-         config = super().get_config()
-         config.update(
-             {
-                 "num_heads": self.num_heads,
-                 "hidden_dim": self.hidden_dim,
-                 "mlp_ratio": self.mlp_ratio,
-                 "use_projection": self.use_projection,
-             }
-         )
-         return config
-
-
- class MMDiTBlock(layers.Layer):
-     def __init__(
-         self,
-         num_heads,
-         hidden_dim,
-         mlp_ratio=4.0,
-         use_context_projection=True,
-         **kwargs,
-     ):
-         super().__init__(**kwargs)
-         self.num_heads = num_heads
-         self.hidden_dim = hidden_dim
-         self.mlp_ratio = mlp_ratio
-         self.use_context_projection = use_context_projection
-
-         head_dim = hidden_dim // num_heads
-         self.head_dim = head_dim
-         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
-         self._dot_product_equation = "aecd,abcd->acbe"
-         self._combine_equation = "acbe,aecd->abcd"
-
-         self.x_block = DismantledBlock(
-             num_heads=num_heads,
-             hidden_dim=hidden_dim,
-             mlp_ratio=mlp_ratio,
-             use_projection=True,
-             dtype=self.dtype_policy,
-             name="x_block",
-         )
-         self.context_block = DismantledBlock(
-             num_heads=num_heads,
-             hidden_dim=hidden_dim,
-             mlp_ratio=mlp_ratio,
-             use_projection=use_context_projection,
-             dtype=self.dtype_policy,
-             name="context_block",
-         )
-
-     def build(self, inputs_shape, context_shape, timestep_embedding_shape):
-         self.x_block.build(inputs_shape, timestep_embedding_shape)
-         self.context_block.build(context_shape, timestep_embedding_shape)
-
-     def _compute_attention(self, query, key, value):
-         query = ops.multiply(
-             query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
-         )
-         attention_scores = ops.einsum(self._dot_product_equation, key, query)
-         attention_scores = ops.nn.softmax(attention_scores, axis=-1)
-         attention_output = ops.einsum(
-             self._combine_equation, attention_scores, value
-         )
-         batch_size = ops.shape(attention_output)[0]
-         attention_output = ops.reshape(
-             attention_output, (batch_size, -1, self.num_heads * self.head_dim)
-         )
-         return attention_output
-
-     def call(self, inputs, context, timestep_embedding, training=None):
-         # Compute pre-attention.
-         x = inputs
-         if self.use_context_projection:
-             context_qkv, context_intermediates = self.context_block(
-                 context,
-                 timestep_embedding=timestep_embedding,
-                 training=training,
-             )
-         else:
-             context_qkv = self.context_block(
-                 context,
-                 timestep_embedding=timestep_embedding,
-                 training=training,
-             )
-         context_len = ops.shape(context_qkv[0])[1]
-         x_qkv, x_intermediates = self.x_block(
-             x, timestep_embedding=timestep_embedding, training=training
-         )
-         q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
-         k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
-         v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
-
-         # Compute attention.
-         attention = self._compute_attention(q, k, v)
-         context_attention = attention[:, :context_len]
-         x_attention = attention[:, context_len:]
-
-         # Compute post-attention.
-         x = self.x_block(
-             x_attention,
-             inputs_intermediates=x_intermediates,
-             pre_attention=False,
-             training=training,
-         )
-         if self.use_context_projection:
-             context = self.context_block(
-                 context_attention,
-                 inputs_intermediates=context_intermediates,
-                 pre_attention=False,
-                 training=training,
-             )
-             return x, context
-         else:
-             return x
-
-     def get_config(self):
-         config = super().get_config()
-         config.update(
-             {
-                 "num_heads": self.num_heads,
-                 "hidden_dim": self.hidden_dim,
-                 "mlp_ratio": self.mlp_ratio,
-                 "use_context_projection": self.use_context_projection,
-             }
-         )
-         return config
-
-     def compute_output_shape(
-         self, inputs_shape, context_shape, timestep_embedding_shape
-     ):
-         if self.use_context_projection:
-             return inputs_shape, context_shape
-         else:
-             return inputs_shape
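The deleted `mmdit_block.py` above (its logic was folded into the much larger `stable_diffusion_3/mmdit.py`, file 16 in the list) combined two ideas: adaLN-style modulation of normalized activations, and joint attention in which context and image tokens are concatenated, attended over in one pass, then split back into their two streams. A self-contained sketch of both computations with `keras.ops`, using a single attention head and illustrative shapes:

```python
from keras import ops, random

batch, context_len, image_len, dim = 2, 4, 8, 16

def modulate(x, shift, scale):
    # Same math as DismantledBlock._modulate: x * (1 + scale) + shift,
    # with per-sample shift/scale broadcast over the sequence axis.
    shift = ops.expand_dims(shift, axis=1)
    scale = ops.expand_dims(scale, axis=1)
    return ops.add(ops.multiply(x, ops.add(scale, 1.0)), shift)

x = random.uniform((batch, image_len, dim))          # image tokens
context = random.uniform((batch, context_len, dim))  # text tokens
x = modulate(x, ops.zeros((batch, dim)), ops.zeros((batch, dim)))

# Joint attention: both streams share a single attention pass.
q = k = v = ops.concatenate([context, x], axis=1)
scores = ops.nn.softmax(
    ops.einsum("bqd,bkd->bqk", q, k) * (1.0 / dim**0.5), axis=-1
)
attention = ops.einsum("bqk,bkd->bqd", scores, v)
context_out = attention[:, :context_len]  # routed back to context_block
x_out = attention[:, context_len:]        # routed back to x_block
```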
--- a/keras_hub/src/models/stable_diffusion_v3/vae_attention.py
+++ /dev/null
@@ -1,126 +0,0 @@
- # Copyright 2024 The KerasHub Authors
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     https://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
-
- from keras import layers
- from keras import ops
-
- from keras_hub.src.utils.keras_utils import standardize_data_format
-
-
- class VAEAttention(layers.Layer):
-     def __init__(self, filters, groups=32, data_format=None, **kwargs):
-         super().__init__(**kwargs)
-         self.filters = filters
-         self.data_format = standardize_data_format(data_format)
-         gn_axis = -1 if self.data_format == "channels_last" else 1
-
-         self.group_norm = layers.GroupNormalization(
-             groups=groups,
-             axis=gn_axis,
-             epsilon=1e-6,
-             dtype=self.dtype_policy,
-             name="group_norm",
-         )
-         self.query_conv2d = layers.Conv2D(
-             filters,
-             1,
-             1,
-             data_format=self.data_format,
-             dtype=self.dtype_policy,
-             name="query_conv2d",
-         )
-         self.key_conv2d = layers.Conv2D(
-             filters,
-             1,
-             1,
-             data_format=self.data_format,
-             dtype=self.dtype_policy,
-             name="key_conv2d",
-         )
-         self.value_conv2d = layers.Conv2D(
-             filters,
-             1,
-             1,
-             data_format=self.data_format,
-             dtype=self.dtype_policy,
-             name="value_conv2d",
-         )
-         self.softmax = layers.Softmax(dtype="float32")
-         self.output_conv2d = layers.Conv2D(
-             filters,
-             1,
-             1,
-             data_format=self.data_format,
-             dtype=self.dtype_policy,
-             name="output_conv2d",
-         )
-
-         self.groups = groups
-         self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
-
-     def build(self, input_shape):
-         self.group_norm.build(input_shape)
-         self.query_conv2d.build(input_shape)
-         self.key_conv2d.build(input_shape)
-         self.value_conv2d.build(input_shape)
-         self.output_conv2d.build(input_shape)
-
-     def call(self, inputs, training=None):
-         x = self.group_norm(inputs)
-         query = self.query_conv2d(x)
-         key = self.key_conv2d(x)
-         value = self.value_conv2d(x)
-
-         if self.data_format == "channels_first":
-             query = ops.transpose(query, (0, 2, 3, 1))
-             key = ops.transpose(key, (0, 2, 3, 1))
-             value = ops.transpose(value, (0, 2, 3, 1))
-         shape = ops.shape(inputs)
-         b = shape[0]
-         query = ops.reshape(query, (b, -1, self.filters))
-         key = ops.reshape(key, (b, -1, self.filters))
-         value = ops.reshape(value, (b, -1, self.filters))
-
-         # Compute attention.
-         query = ops.multiply(
-             query, ops.cast(self._inverse_sqrt_filters, query.dtype)
-         )
-         # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
-         attention_scores = ops.einsum("abc,adc->abd", query, key)
-         attention_scores = ops.cast(
-             self.softmax(attention_scores), self.compute_dtype
-         )
-         # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1, C]
-         attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
-         x = ops.reshape(attention_output, shape)
-
-         x = self.output_conv2d(x)
-         if self.data_format == "channels_first":
-             x = ops.transpose(x, (0, 3, 1, 2))
-         x = ops.add(x, inputs)
-         return x
-
-     def get_config(self):
-         config = super().get_config()
-         config.update(
-             {
-                 "filters": self.filters,
-                 "groups": self.groups,
-             }
-         )
-         return config
-
-     def compute_output_shape(self, input_shape):
-         return input_shape
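`vae_attention.py` implemented single-head self-attention over spatial positions, with 1x1 convolutions as the query/key/value/output projections; the new `stable_diffusion_3/vae_image_decoder.py` absorbs this layer. Stripped of the group norm and the convolutions, the core computation reduces to the following sketch (channels-last, illustrative shapes; the einsum equations are the ones the layer used):

```python
from keras import ops, random

b, h, w, c = 1, 8, 8, 32
inputs = random.uniform((b, h, w, c))  # channels_last feature map

# Flatten spatial positions into a sequence of length H * W.
query = ops.reshape(inputs, (b, -1, c))
key = ops.reshape(inputs, (b, -1, c))
value = ops.reshape(inputs, (b, -1, c))

# Scale queries by 1 / sqrt(C), as in `_inverse_sqrt_filters`.
query = ops.multiply(query, 1.0 / c**0.5)
# [B, HW, C] x [B, HW, C] -> [B, HW, HW]
scores = ops.nn.softmax(ops.einsum("abc,adc->abd", query, key), axis=-1)
# [B, HW, C] x [B, HW, HW] -> [B, HW, C]
out = ops.einsum("abc,adb->adc", value, scores)

# Restore the spatial layout and apply the residual connection.
out = ops.add(ops.reshape(out, (b, h, w, c)), inputs)
```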
--- a/keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py
+++ /dev/null
@@ -1,186 +0,0 @@
- # Copyright 2024 The KerasHub Authors
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     https://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import keras
- from keras import layers
-
- from keras_hub.src.models.stable_diffusion_v3.vae_attention import VAEAttention
- from keras_hub.src.utils.keras_utils import standardize_data_format
-
-
- class VAEImageDecoder(keras.Model):
-     def __init__(
-         self,
-         stackwise_num_filters,
-         stackwise_num_blocks,
-         output_channels=3,
-         latent_shape=(None, None, 16),
-         data_format=None,
-         dtype=None,
-         **kwargs,
-     ):
-         data_format = standardize_data_format(data_format)
-         gn_axis = -1 if data_format == "channels_last" else 1
-
-         # === Functional Model ===
-         latent_inputs = layers.Input(shape=latent_shape)
-
-         x = layers.Conv2D(
-             stackwise_num_filters[0],
-             3,
-             1,
-             padding="same",
-             data_format=data_format,
-             dtype=dtype,
-             name="input_projection",
-         )(latent_inputs)
-         x = apply_resnet_block(
-             x,
-             stackwise_num_filters[0],
-             data_format=data_format,
-             dtype=dtype,
-             name="input_block0",
-         )
-         x = VAEAttention(
-             stackwise_num_filters[0],
-             data_format=data_format,
-             dtype=dtype,
-             name="input_attention",
-         )(x)
-         x = apply_resnet_block(
-             x,
-             stackwise_num_filters[0],
-             data_format=data_format,
-             dtype=dtype,
-             name="input_block1",
-         )
-
-         # Stacks.
-         for i, filters in enumerate(stackwise_num_filters):
-             for j in range(stackwise_num_blocks[i]):
-                 x = apply_resnet_block(
-                     x,
-                     filters,
-                     data_format=data_format,
-                     dtype=dtype,
-                     name=f"block{i}_{j}",
-                 )
-             if i != len(stackwise_num_filters) - 1:
-                 # No upsampling after the last stack.
-                 x = layers.UpSampling2D(
-                     2,
-                     data_format=data_format,
-                     dtype=dtype,
-                     name=f"upsample_{i}",
-                 )(x)
-                 x = layers.Conv2D(
-                     filters,
-                     3,
-                     1,
-                     padding="same",
-                     data_format=data_format,
-                     dtype=dtype,
-                     name=f"upsample_{i}_conv",
-                 )(x)
-
-         # Output block.
-         x = layers.GroupNormalization(
-             groups=32,
-             axis=gn_axis,
-             epsilon=1e-6,
-             dtype=dtype,
-             name="output_norm",
-         )(x)
-         x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
-         image_outputs = layers.Conv2D(
-             output_channels,
-             3,
-             1,
-             padding="same",
-             data_format=data_format,
-             dtype=dtype,
-             name="output_projection",
-         )(x)
-         super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)
-
-         # === Config ===
-         self.stackwise_num_filters = stackwise_num_filters
-         self.stackwise_num_blocks = stackwise_num_blocks
-         self.output_channels = output_channels
-         self.latent_shape = latent_shape
-
-         if dtype is not None:
-             try:
-                 self.dtype_policy = keras.dtype_policies.get(dtype)
-             # Before Keras 3.2, there is no `keras.dtype_policies.get`.
-             except AttributeError:
-                 if isinstance(dtype, keras.DTypePolicy):
-                     dtype = dtype.name
-                 self.dtype_policy = keras.DTypePolicy(dtype)
-
-     def get_config(self):
-         config = super().get_config()
-         config.update(
-             {
-                 "stackwise_num_filters": self.stackwise_num_filters,
-                 "stackwise_num_blocks": self.stackwise_num_blocks,
-                 "output_channels": self.output_channels,
-                 "latent_shape": self.latent_shape,
-             }
-         )
-         return config
-
-
- def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
-     data_format = standardize_data_format(data_format)
-     gn_axis = -1 if data_format == "channels_last" else 1
-     input_filters = x.shape[gn_axis]
-
-     residual = x
-     x = layers.GroupNormalization(
-         groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1"
-     )(x)
-     x = layers.Activation("swish", dtype=dtype)(x)
-     x = layers.Conv2D(
-         filters,
-         3,
-         1,
-         padding="same",
-         data_format=data_format,
-         dtype=dtype,
-         name=f"{name}_conv1",
-     )(x)
-     x = layers.GroupNormalization(
-         groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2"
-     )(x)
-     x = layers.Activation("swish", dtype=dtype)(x)
-     x = layers.Conv2D(
-         filters,
-         3,
-         1,
-         padding="same",
-         data_format=data_format,
-         dtype=dtype,
-         name=f"{name}_conv2",
-     )(x)
-     if input_filters != filters:
-         residual = layers.Conv2D(
-             filters,
-             1,
-             1,
-             data_format=data_format,
-             dtype=dtype,
-             name=f"{name}_residual_projection",
-         )(residual)
-     x = layers.Add(dtype=dtype)([residual, x])
-     return x
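One property of the removed decoder worth noting: each filter stack except the last is followed by a 2x `UpSampling2D`, so a decoder with N stacks maps an (H, W) latent to an (H * 2^(N-1), W * 2^(N-1)) image. A quick check of that arithmetic (stack widths here are illustrative, not taken from a preset):

```python
# Illustrative stack widths; the real values come from the model preset.
stackwise_num_filters = [512, 512, 256, 128]
latent_height = latent_width = 64

# One 2x upsample after each stack except the last.
upsample_factor = 2 ** (len(stackwise_num_filters) - 1)
print(latent_height * upsample_factor, latent_width * upsample_factor)
# -> 512 512: a (64, 64, 16) latent decodes to a 512x512 RGB image.
```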