keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. keras_hub/api/layers/__init__.py +3 -0
  2. keras_hub/api/models/__init__.py +16 -0
  3. keras_hub/api/tokenizers/__init__.py +1 -0
  4. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
  5. keras_hub/src/models/clip/clip_preprocessor.py +147 -0
  6. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
  7. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
  8. keras_hub/src/models/densenet/__init__.py +6 -0
  9. keras_hub/src/models/densenet/densenet_backbone.py +11 -8
  10. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
  11. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  12. keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
  13. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  14. keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
  15. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
  16. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
  17. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
  18. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
  19. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
  20. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
  21. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
  22. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
  23. keras_hub/src/models/text_to_image.py +295 -0
  24. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  25. keras_hub/src/utils/timm/preset_loader.py +3 -0
  26. keras_hub/src/version_utils.py +1 -1
  27. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
  28. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +31 -23
  29. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  30. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  31. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  32. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  33. /keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
  34. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
  35. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
@@ -0,0 +1,151 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+     StableDiffusion3Backbone,
+ )
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+     StableDiffusion3TextToImagePreprocessor,
+ )
+ from keras_hub.src.models.text_to_image import TextToImage
+
+
+ @keras_hub_export("keras_hub.models.StableDiffusion3TextToImage")
+ class StableDiffusion3TextToImage(TextToImage):
+     """An end-to-end Stable Diffusion 3 model for text-to-image generation.
+
+     This model has a `generate()` method, which generates images based on a
+     prompt.
+
+     Args:
+         backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+         preprocessor: A
+             `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+     Examples:
+
+     Use `generate()` to do image generation.
+     ```python
+     text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
+         "stable_diffusion_3_medium", height=512, width=512
+     )
+     text_to_image.generate(
+         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+     )
+
+     # Generate with batched prompts.
+     text_to_image.generate(
+         ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+     )
+
+     # Generate with different `num_steps` and `guidance_scale`.
+     text_to_image.generate(
+         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+         num_steps=50,
+         guidance_scale=5.0,
+     )
+     ```
+     """
+
+     backbone_cls = StableDiffusion3Backbone
+     preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+     def __init__(
+         self,
+         backbone,
+         preprocessor,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         # === Functional Model ===
+         inputs = backbone.input
+         outputs = backbone.output
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+     def fit(self, *args, **kwargs):
+         raise NotImplementedError(
+             "Currently, `fit` is not supported for "
+             "`StableDiffusion3TextToImage`."
+         )
+
+     def generate_step(
+         self,
+         latents,
+         token_ids,
+         negative_token_ids,
+         num_steps,
+         guidance_scale,
+     ):
+         """A compilable generation function for a batch of inputs.
+
+         This function represents the inner, XLA-compilable, generation function
+         for batched inputs.
+
+         Args:
+             latents: A (batch_size, height, width, channels) tensor
+                 containing the latents to start generation from. Typically, this
+                 tensor is sampled from the Gaussian distribution.
+             token_ids: A (batch_size, num_tokens) tensor containing the
+                 tokens based on the input prompts.
+             negative_token_ids: A (batch_size, num_tokens) tensor
+                 containing the negative tokens based on the input prompts.
+             num_steps: int. The number of diffusion steps to take.
+             guidance_scale: float. The classifier free guidance scale defined in
+                 [Classifier-Free Diffusion Guidance](
+                 https://arxiv.org/abs/2207.12598). A higher scale encourages
+                 generating images more closely linked to the prompt, usually
+                 at the expense of image quality.
+         """
+         # Encode inputs.
+         embeddings = self.backbone.encode_step(token_ids, negative_token_ids)
+
+         # Denoise.
+         def body_fun(step, latents):
+             return self.backbone.denoise_step(
+                 latents,
+                 embeddings,
+                 step,
+                 num_steps,
+                 guidance_scale,
+             )
+
+         latents = ops.fori_loop(0, num_steps, body_fun, latents)
+
+         # Decode.
+         return self.backbone.decode_step(latents)
+
+     def generate(
+         self,
+         inputs,
+         negative_inputs=None,
+         num_steps=28,
+         guidance_scale=7.0,
+         seed=None,
+     ):
+         return super().generate(
+             inputs,
+             negative_inputs=negative_inputs,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
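Taken together, `generate_step()` runs the full pipeline (encode, `fori_loop` denoise, decode), while `generate()` only fixes the defaults. A minimal usage sketch based solely on the signatures above; the preset name comes from the docstring example, and everything else is an illustrative assumption:

```python
import keras_hub

# Sketch based on the `generate()` signature shown above. The preset name is
# taken from the docstring example; downloading it may need registry access.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", height=512, width=512
)
# `negative_inputs` supplies prompts to steer away from, and `seed` fixes the
# initial Gaussian latents so outputs are reproducible.
images = text_to_image.generate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    negative_inputs="low quality, blurry",
    num_steps=28,
    guidance_scale=7.0,
    seed=42,
)
```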
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py
@@ -0,0 +1,77 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from keras import layers
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.preprocessor import Preprocessor
+
+
+ @keras_hub_export("keras_hub.models.StableDiffusion3TextToImagePreprocessor")
+ class StableDiffusion3TextToImagePreprocessor(Preprocessor):
+     """Stable Diffusion 3 text-to-image model preprocessor.
+
+     This preprocessing layer is meant for use with
+     `keras_hub.models.StableDiffusion3TextToImage`.
+
+     For use with generation, the layer exposes one method,
+     `generate_preprocess()`.
+
+     Args:
+         clip_l_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance.
+         clip_g_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance.
+         t5_preprocessor: An optional `keras_hub.models.T5Preprocessor` instance.
+     """
+
+     def __init__(
+         self,
+         clip_l_preprocessor,
+         clip_g_preprocessor,
+         t5_preprocessor=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.clip_l_preprocessor = clip_l_preprocessor
+         self.clip_g_preprocessor = clip_g_preprocessor
+         self.t5_preprocessor = t5_preprocessor
+
+     def build(self, input_shape):
+         self.built = True
+
+     def generate_preprocess(self, x):
+         token_ids = {}
+         token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
+         token_ids["clip_g"] = self.clip_g_preprocessor(x)["token_ids"]
+         if self.t5_preprocessor is not None:
+             token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
+         return token_ids
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "clip_l_preprocessor": layers.serialize(
+                     self.clip_l_preprocessor
+                 ),
+                 "clip_g_preprocessor": layers.serialize(
+                     self.clip_g_preprocessor
+                 ),
+                 "t5_preprocessor": layers.serialize(self.t5_preprocessor),
+             }
+         )
+         return config
+
+     @property
+     def sequence_length(self):
+         """The padded length of model input sequences."""
+         return self.clip_l_preprocessor.sequence_length
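For reference, a hedged sketch of what `generate_preprocess()` produces; `preprocessor` is an assumed, already-constructed instance, and the dict keys follow the method body above:

```python
# `preprocessor` is an assumed StableDiffusion3TextToImagePreprocessor
# instance with CLIP-L, CLIP-G, and (optionally) T5 sub-preprocessors.
token_ids = preprocessor.generate_preprocess(
    ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
)
token_ids["clip_l"]  # (2, sequence_length) token ids from the CLIP-L tokenizer.
token_ids["clip_g"]  # (2, sequence_length) token ids from the CLIP-G tokenizer.
token_ids["t5"]      # Present only when a T5 preprocessor was configured.
```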
keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py}
@@ -20,7 +20,7 @@ from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
  from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer


- class T5XXLTextEncoder(keras.Model):
+ class T5Encoder(keras.Model):
      def __init__(
          self,
          vocabulary_size,
@@ -81,10 +81,10 @@ class T5XXLTextEncoder(keras.Model):

          # === Functional Model ===
          encoder_token_id_input = keras.Input(
-             shape=(None,), dtype="int32", name="encoder_token_ids"
+             shape=(None,), dtype="int32", name="token_ids"
          )
          encoder_padding_mask_input = keras.Input(
-             shape=(None,), dtype="int32", name="encoder_padding_mask"
+             shape=(None,), dtype="int32", name="padding_mask"
          )
          # Encoder.
          x = self.token_embedding(encoder_token_id_input)
@@ -102,14 +102,14 @@ class T5XXLTextEncoder(keras.Model):
          x, position_bias = output
          x = self.encoder_layer_norm(x)
          x = self.encoder_dropout(x)
-         encoder_output = x
+         sequence_output = x

          super().__init__(
              {
-                 "encoder_token_ids": encoder_token_id_input,
-                 "encoder_padding_mask": encoder_padding_mask_input,
+                 "token_ids": encoder_token_id_input,
+                 "padding_mask": encoder_padding_mask_input,
              },
-             outputs=encoder_output,
+             outputs=sequence_output,
              **kwargs,
          )
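This rename standardizes the encoder's input names to the `token_ids`/`padding_mask` convention used elsewhere in KerasHub. A hedged call sketch, assuming `t5_encoder` is a constructed `T5Encoder` and the sequence length is illustrative:

```python
import numpy as np

# `t5_encoder` is an assumed T5Encoder instance; 77 is an illustrative
# sequence length. The model now accepts the standard input dict keys.
sequence_output = t5_encoder(
    {
        "token_ids": np.ones((1, 77), dtype="int32"),
        "padding_mask": np.ones((1, 77), dtype="int32"),
    }
)  # -> (1, 77, hidden_dim) encoder sequence output.
```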
keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py
@@ -0,0 +1,333 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import math
+
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ class VAEAttention(layers.Layer):
+     def __init__(self, filters, groups=32, data_format=None, **kwargs):
+         super().__init__(**kwargs)
+         self.filters = filters
+         self.data_format = standardize_data_format(data_format)
+         gn_axis = -1 if self.data_format == "channels_last" else 1
+
+         self.group_norm = layers.GroupNormalization(
+             groups=groups,
+             axis=gn_axis,
+             epsilon=1e-6,
+             dtype="float32",
+             name="group_norm",
+         )
+         self.query_conv2d = layers.Conv2D(
+             filters,
+             1,
+             1,
+             data_format=self.data_format,
+             dtype=self.dtype_policy,
+             name="query_conv2d",
+         )
+         self.key_conv2d = layers.Conv2D(
+             filters,
+             1,
+             1,
+             data_format=self.data_format,
+             dtype=self.dtype_policy,
+             name="key_conv2d",
+         )
+         self.value_conv2d = layers.Conv2D(
+             filters,
+             1,
+             1,
+             data_format=self.data_format,
+             dtype=self.dtype_policy,
+             name="value_conv2d",
+         )
+         self.softmax = layers.Softmax(dtype="float32")
+         self.output_conv2d = layers.Conv2D(
+             filters,
+             1,
+             1,
+             data_format=self.data_format,
+             dtype=self.dtype_policy,
+             name="output_conv2d",
+         )
+
+         self.groups = groups
+         self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
+
+     def build(self, input_shape):
+         self.group_norm.build(input_shape)
+         self.query_conv2d.build(input_shape)
+         self.key_conv2d.build(input_shape)
+         self.value_conv2d.build(input_shape)
+         self.output_conv2d.build(input_shape)
+
+     def call(self, inputs, training=None):
+         x = self.group_norm(inputs)
+         query = self.query_conv2d(x)
+         key = self.key_conv2d(x)
+         value = self.value_conv2d(x)
+
+         if self.data_format == "channels_first":
+             query = ops.transpose(query, (0, 2, 3, 1))
+             key = ops.transpose(key, (0, 2, 3, 1))
+             value = ops.transpose(value, (0, 2, 3, 1))
+         shape = ops.shape(inputs)
+         b = shape[0]
+         query = ops.reshape(query, (b, -1, self.filters))
+         key = ops.reshape(key, (b, -1, self.filters))
+         value = ops.reshape(value, (b, -1, self.filters))
+
+         # Compute attention.
+         query = ops.multiply(
+             query, ops.cast(self._inverse_sqrt_filters, query.dtype)
+         )
+         # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
+         attention_scores = ops.einsum("abc,adc->abd", query, key)
+         attention_scores = ops.cast(
+             self.softmax(attention_scores), self.compute_dtype
+         )
+         # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1, C]
+         attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
+         x = ops.reshape(attention_output, shape)
+
+         x = self.output_conv2d(x)
+         if self.data_format == "channels_first":
+             x = ops.transpose(x, (0, 3, 1, 2))
+         x = ops.add(x, inputs)
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "filters": self.filters,
+                 "groups": self.groups,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         return input_shape
+
+
+ def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
+     data_format = standardize_data_format(data_format)
+     gn_axis = -1 if data_format == "channels_last" else 1
+     input_filters = x.shape[gn_axis]
+
+     residual = x
+     x = layers.GroupNormalization(
+         groups=32,
+         axis=gn_axis,
+         epsilon=1e-6,
+         dtype="float32",
+         name=f"{name}_norm1",
+     )(x)
+     x = layers.Activation("swish", dtype=dtype)(x)
+     x = layers.Conv2D(
+         filters,
+         3,
+         1,
+         padding="same",
+         data_format=data_format,
+         dtype=dtype,
+         name=f"{name}_conv1",
+     )(x)
+     x = layers.GroupNormalization(
+         groups=32,
+         axis=gn_axis,
+         epsilon=1e-6,
+         dtype="float32",
+         name=f"{name}_norm2",
+     )(x)
+     x = layers.Activation("swish", dtype=dtype)(x)
+     x = layers.Conv2D(
+         filters,
+         3,
+         1,
+         padding="same",
+         data_format=data_format,
+         dtype=dtype,
+         name=f"{name}_conv2",
+     )(x)
+     if input_filters != filters:
+         residual = layers.Conv2D(
+             filters,
+             1,
+             1,
+             data_format=data_format,
+             dtype=dtype,
+             name=f"{name}_residual_projection",
+         )(residual)
+     x = layers.Add(dtype=dtype)([residual, x])
+     return x
+
+
+ class VAEImageDecoder(Backbone):
+     """Decoder for the VAE model used in Stable Diffusion 3.
+
+     Args:
+         stackwise_num_filters: list of ints. The number of filters for each
+             stack.
+         stackwise_num_blocks: list of ints. The number of blocks for each stack.
+         output_channels: int. The number of channels in the output.
+         latent_shape: tuple. The shape of the latent image.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+             to use for the model's computations and weights.
+     """
+
+     def __init__(
+         self,
+         stackwise_num_filters,
+         stackwise_num_blocks,
+         output_channels=3,
+         latent_shape=(None, None, 16),
+         data_format=None,
+         dtype=None,
+         **kwargs,
+     ):
+         data_format = standardize_data_format(data_format)
+         gn_axis = -1 if data_format == "channels_last" else 1
+
+         # === Functional Model ===
+         latent_inputs = layers.Input(shape=latent_shape)
+
+         x = layers.Conv2D(
+             stackwise_num_filters[0],
+             3,
+             1,
+             padding="same",
+             data_format=data_format,
+             dtype=dtype,
+             name="input_projection",
+         )(latent_inputs)
+         x = apply_resnet_block(
+             x,
+             stackwise_num_filters[0],
+             data_format=data_format,
+             dtype=dtype,
+             name="input_block0",
+         )
+         x = VAEAttention(
+             stackwise_num_filters[0],
+             data_format=data_format,
+             dtype=dtype,
+             name="input_attention",
+         )(x)
+         x = apply_resnet_block(
+             x,
+             stackwise_num_filters[0],
+             data_format=data_format,
+             dtype=dtype,
+             name="input_block1",
+         )
+
+         # Stacks.
+         for i, filters in enumerate(stackwise_num_filters):
+             for j in range(stackwise_num_blocks[i]):
+                 x = apply_resnet_block(
+                     x,
+                     filters,
+                     data_format=data_format,
+                     dtype=dtype,
+                     name=f"block{i}_{j}",
+                 )
+             if i != len(stackwise_num_filters) - 1:
+                 # No upsampling in the last block.
+                 x = layers.UpSampling2D(
+                     2,
+                     data_format=data_format,
+                     dtype=dtype,
+                     name=f"upsample_{i}",
+                 )(x)
+                 x = layers.Conv2D(
+                     filters,
+                     3,
+                     1,
+                     padding="same",
+                     data_format=data_format,
+                     dtype=dtype,
+                     name=f"upsample_{i}_conv",
+                 )(x)
+
+         # Output block.
+         x = layers.GroupNormalization(
+             groups=32,
+             axis=gn_axis,
+             epsilon=1e-6,
+             dtype="float32",
+             name="output_norm",
+         )(x)
+         x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
+         image_outputs = layers.Conv2D(
+             output_channels,
+             3,
+             1,
+             padding="same",
+             data_format=data_format,
+             dtype=dtype,
+             name="output_projection",
+         )(x)
+         super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)
+
+         # === Config ===
+         self.stackwise_num_filters = stackwise_num_filters
+         self.stackwise_num_blocks = stackwise_num_blocks
+         self.output_channels = output_channels
+         self.latent_shape = latent_shape
+
+     @property
+     def scaling_factor(self):
+         """The scaling factor for the latent space.
+
+         This is used to scale the latent space to have unit variance when
+         training the diffusion model.
+         """
+         return 1.5305
+
+     @property
+     def shift_factor(self):
+         """The shift factor for the latent space.
+
+         This is used to shift the latent space to have zero mean when
+         training the diffusion model.
+         """
+         return 0.0609
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "stackwise_num_filters": self.stackwise_num_filters,
+                 "stackwise_num_blocks": self.stackwise_num_blocks,
+                 "output_channels": self.output_channels,
+                 "latent_shape": self.latent_shape,
+             }
+         )
+         return config
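The `scaling_factor`/`shift_factor` docstrings describe the latent normalization applied during training, so decoding plausibly inverts it before the decoder runs. A hedged helper sketch; the `decode_latents` name and the exact inversion order are assumptions, not the backbone's actual decode path:

```python
from keras import ops

def decode_latents(decoder, latents):
    # Assumed inversion of the training-time normalization described above:
    # undo the unit-variance scaling, then undo the zero-mean shift.
    latents = ops.add(
        ops.divide(latents, decoder.scaling_factor), decoder.shift_factor
    )
    # `decoder` is an assumed VAEImageDecoder instance; returns an image
    # tensor with `output_channels` channels.
    return decoder(latents)
```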
keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py}
@@ -13,13 +13,15 @@
  # limitations under the License.
  import keras

+ from keras_hub.src.api_export import keras_hub_export
  from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
  from keras_hub.src.models.preprocessor import Preprocessor
  from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
  from keras_hub.src.utils.tensor_utils import preprocessing_function


- class T5XXLPreprocessor(Preprocessor):
+ @keras_hub_export("keras_hub.models.T5Preprocessor")
+ class T5Preprocessor(Preprocessor):
      tokenizer_cls = T5Tokenizer

      def __init__(
@@ -49,10 +51,17 @@ class T5XXLPreprocessor(Preprocessor):
          self.built = True

      @preprocessing_function
-     def call(self, x, y=None, sample_weight=None, sequence_length=None):
+     def call(
+         self,
+         x,
+         y=None,
+         sample_weight=None,
+         sequence_length=None,
+     ):
+         sequence_length = sequence_length or self.sequence_length
          token_ids, padding_mask = self.packer(
              self.tokenizer(x),
-             sequence_length=sequence_length or self.sequence_length,
+             sequence_length=sequence_length,
              add_start_value=self.add_start_token,
              add_end_value=self.add_end_token,
          )
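Since `call()` now resolves `sequence_length` once up front, per-call overrides still flow through to the packer. A hedged usage sketch, assuming `preprocessor` is a constructed `T5Preprocessor`:

```python
# `preprocessor` is an assumed T5Preprocessor instance. The per-call
# `sequence_length` overrides the value set at construction time.
batch = preprocessor(
    ["Astronaut in a jungle, cold color palette"], sequence_length=128
)
batch["token_ids"]     # Padded/truncated to length 128.
batch["padding_mask"]  # Marks real tokens vs. padding.
```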