keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +16 -0
- keras_hub/api/tokenizers/__init__.py +1 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
- keras_hub/src/models/clip/clip_preprocessor.py +147 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
- keras_hub/src/models/densenet/__init__.py +6 -0
- keras_hub/src/models/densenet/densenet_backbone.py +11 -8
- keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
- keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
- keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
- keras_hub/src/models/densenet/densenet_presets.py +56 -0
- keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
- keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
- keras_hub/src/models/text_to_image.py +295 -0
- keras_hub/src/utils/timm/convert_densenet.py +107 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +31 -23
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
- /keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
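The bulk of this nightly promotes the `stable_diffusion_v3` prototype into proper `stable_diffusion_3` and `clip` model directories, adds a generic `text_to_image.py` task base plus `StableDiffusion3TextToImage` task and preprocessor files, and registers DenseNet presets and an image converter. As a rough sketch of how the new task class is presumably meant to be used once presets ship (the task-level `generate` API is not shown in this diff, so the call below is an assumption; only the `"stable_diffusion_3_medium"` preset name appears in the backbone docstring):

```python
import keras_hub

# Hypothetical usage of the new text-to-image task added in this diff.
# The class name is inferred from stable_diffusion_3_text_to_image.py; the
# generate() signature is assumed, not confirmed by the diff.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium",
    height=512,
    width=512,
)
image = text_to_image.generate("a photograph of an astronaut riding a horse")
```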
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -0,0 +1,630 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
from keras import layers
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (
    FlowMatchEulerDiscreteScheduler,
)
from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
    VAEImageDecoder,
)
from keras_hub.src.utils.keras_utils import standardize_data_format


class CLIPProjection(layers.Layer):
    def __init__(self, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)

        self.dense = layers.Dense(
            hidden_dim,
            use_bias=False,
            dtype=self.dtype_policy,
            name="dense",
        )

    def build(self, inputs_shape, token_ids_shape):
        inputs_shape = list(inputs_shape)
        self.dense.build([None, inputs_shape[-1]])

        # Assign identity matrix to the kernel as default.
        self.dense._kernel.assign(ops.eye(self.hidden_dim))

    def call(self, inputs, token_ids):
        indices = ops.expand_dims(
            ops.cast(ops.argmax(token_ids, axis=-1), "int32"), axis=-1
        )
        pooled_output = ops.take_along_axis(inputs, indices[:, :, None], axis=1)
        pooled_output = ops.squeeze(pooled_output, axis=1)
        return self.dense(pooled_output)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        return (inputs_shape[0], self.hidden_dim)


class ClassifierFreeGuidanceConcatenate(layers.Layer):
    def __init__(self, axis=0, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(
        self,
        latents,
        positive_contexts,
        negative_contexts,
        positive_pooled_projections,
        negative_pooled_projections,
        timestep,
    ):
        timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
        latents = ops.concatenate([latents, latents], axis=self.axis)
        contexts = ops.concatenate(
            [positive_contexts, negative_contexts], axis=self.axis
        )
        pooled_projections = ops.concatenate(
            [positive_pooled_projections, negative_pooled_projections],
            axis=self.axis,
        )
        timesteps = ops.concatenate([timestep, timestep], axis=self.axis)
        return latents, contexts, pooled_projections, timesteps

    def get_config(self):
        return super().get_config()


class ClassifierFreeGuidance(layers.Layer):
    """Perform classifier-free guidance.

    This layer expects the inputs to be a concatenation of positive and
    negative (or empty) noise. The computation applies the classifier-free
    guidance scale.

    Args:
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.

    Call arguments:
        inputs: A concatenation of positive and negative (or empty) noises.
        guidance_scale: The scale factor for classifier-free guidance.

    Reference:
    - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs, guidance_scale):
        positive_noise, negative_noise = ops.split(inputs, 2, axis=0)
        return ops.add(
            negative_noise,
            ops.multiply(
                guidance_scale, ops.subtract(positive_noise, negative_noise)
            ),
        )

    def get_config(self):
        return super().get_config()

    def compute_output_shape(self, inputs_shape):
        outputs_shape = list(inputs_shape)
        if outputs_shape[0] is not None:
            outputs_shape[0] = outputs_shape[0] // 2
        return outputs_shape


class EulerStep(layers.Layer):
    """A layer that predicts the sample at the next timestep from the
    predicted noise.

    Args:
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.

    Call arguments:
        latents: A current sample created by the diffusion process.
        noise_residual: The direct output from the diffusion model.
        sigma: The amount of noise added at the current timestep.
        sigma_next: The amount of noise added at the next timestep.

    References:
    - [Common Diffusion Noise Schedules and Sample Steps are Flawed](
    https://arxiv.org/abs/2305.08891).
    - [Elucidating the Design Space of Diffusion-Based Generative Models](
    https://arxiv.org/abs/2206.00364).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, latents, noise_residual, sigma, sigma_next):
        sigma_diff = ops.subtract(sigma_next, sigma)
        return ops.add(latents, ops.multiply(sigma_diff, noise_residual))

    def get_config(self):
        return super().get_config()

    def compute_output_shape(self, latents_shape):
        return latents_shape


class LatentSpaceDecoder(layers.Layer):
    """Decoder to transform the latent space back to the original image space.

    During decoding, the latents are transformed back to the original image
    space using the equation: `latents / scale + shift`.

    Args:
        scale: float. The scaling factor.
        shift: float. The shift factor.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.

    Call arguments:
        latents: The latent tensor to be transformed.

    Reference:
    - [High-Resolution Image Synthesis with Latent Diffusion Models](
    https://arxiv.org/abs/2112.10752).
    """

    def __init__(self, scale, shift, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale
        self.shift = shift

    def call(self, latents):
        return ops.add(ops.divide(latents, self.scale), self.shift)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "scale": self.scale,
                "shift": self.shift,
            }
        )
        return config

    def compute_output_shape(self, latents_shape):
        return latents_shape


@keras_hub_export("keras_hub.models.StableDiffusion3Backbone")
class StableDiffusion3Backbone(Backbone):
    """Stable Diffusion 3 core network with hyperparameters.

    This backbone imports CLIP and T5 models as text encoders and implements
    the base MMDiT and VAE networks for the Stable Diffusion 3 model.

    The default constructor gives fully customizable, randomly initialized
    MMDiT and VAE models with any hyperparameters. To load preset
    architectures and weights, use the `from_preset` constructor.

    Args:
        mmdit_patch_size: int. The size of each square patch in the input
            image in MMDiT.
        mmdit_hidden_dim: int. The size of the transformer hidden state at the
            end of each transformer layer in MMDiT.
        mmdit_num_layers: int. The number of transformer layers in MMDiT.
        mmdit_num_heads: int. The number of attention heads for each
            transformer in MMDiT.
        mmdit_position_size: int. The size of the height and width for the
            position embedding in MMDiT.
        vae_stackwise_num_filters: list of ints. The number of filters for
            each stack in VAE.
        vae_stackwise_num_blocks: list of ints. The number of blocks for each
            stack in VAE.
        clip_l: `keras_hub.models.CLIPTextEncoder`. The text encoder for
            encoding the inputs.
        clip_g: `keras_hub.models.CLIPTextEncoder`. The text encoder for
            encoding the inputs.
        t5: optional `keras_hub.models.T5Encoder`. The text encoder for
            encoding the inputs.
        latent_channels: int. The number of channels in the latent. Defaults
            to `16`.
        output_channels: int. The number of channels in the output. Defaults
            to `3`.
        num_train_timesteps: int. The number of diffusion steps to train the
            model. Defaults to `1000`.
        shift: float. The shift value for the timestep schedule. Defaults to
            `1.0`.
        height: optional int. The output height of the image.
        width: optional int. The output width of the image.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the inputs.
            `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)` while `"channels_first"`
            corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for the model's computations and weights. Note that some
            computations, such as softmax and layer normalization, will always
            be done in float32 precision regardless of dtype.

    Example:
    ```python
    # Pretrained Stable Diffusion 3 model.
    model = keras_hub.models.StableDiffusion3Backbone.from_preset(
        "stable_diffusion_3_medium"
    )

    # Randomly initialized Stable Diffusion 3 model with custom config.
    clip_l = keras_hub.models.CLIPTextEncoder(...)
    clip_g = keras_hub.models.CLIPTextEncoder(...)
    model = keras_hub.models.StableDiffusion3Backbone(
        mmdit_patch_size=2,
        mmdit_num_heads=4,
        mmdit_hidden_dim=256,
        mmdit_num_layers=4,
        mmdit_position_size=192,
        vae_stackwise_num_filters=[128, 128, 64, 32],
        vae_stackwise_num_blocks=[1, 1, 1, 1],
        clip_l=clip_l,
        clip_g=clip_g,
    )
    ```
    """

    def __init__(
        self,
        mmdit_patch_size,
        mmdit_hidden_dim,
        mmdit_num_layers,
        mmdit_num_heads,
        mmdit_position_size,
        vae_stackwise_num_filters,
        vae_stackwise_num_blocks,
        clip_l,
        clip_g,
        t5=None,
        latent_channels=16,
        output_channels=3,
        num_train_timesteps=1000,
        shift=1.0,
        height=None,
        width=None,
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        height = int(height or 1024)
        width = int(width or 1024)
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                "`height` and `width` must be divisible by 8. "
                f"Received: height={height}, width={width}"
            )
        data_format = standardize_data_format(data_format)
        if data_format != "channels_last":
            raise NotImplementedError
        latent_shape = (height // 8, width // 8, latent_channels)
        context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
        pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)

        # === Layers ===
        self.clip_l = clip_l
        self.clip_l_projection = CLIPProjection(
            clip_l.hidden_dim, dtype=dtype, name="clip_l_projection"
        )
        self.clip_l_projection.build([None, clip_l.hidden_dim], None)
        self.clip_g = clip_g
        self.clip_g_projection = CLIPProjection(
            clip_g.hidden_dim, dtype=dtype, name="clip_g_projection"
        )
        self.clip_g_projection.build([None, clip_g.hidden_dim], None)
        self.t5 = t5
        self.diffuser = MMDiT(
            mmdit_patch_size,
            mmdit_hidden_dim,
            mmdit_num_layers,
            mmdit_num_heads,
            mmdit_position_size,
            latent_shape=latent_shape,
            context_shape=context_shape,
            pooled_projection_shape=pooled_projection_shape,
            data_format=data_format,
            dtype=dtype,
            name="diffuser",
        )
        self.decoder = VAEImageDecoder(
            vae_stackwise_num_filters,
            vae_stackwise_num_blocks,
            output_channels,
            latent_shape=latent_shape,
            data_format=data_format,
            dtype=dtype,
            name="decoder",
        )
        # Set `dtype="float32"` to ensure the high precision for the noise
        # residual.
        self.scheduler = FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=num_train_timesteps,
            shift=shift,
            dtype="float32",
            name="scheduler",
        )
        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
            dtype="float32", name="classifier_free_guidance_concat"
        )
        self.cfg = ClassifierFreeGuidance(
            dtype="float32", name="classifier_free_guidance"
        )
        self.euler_step = EulerStep(dtype="float32", name="euler_step")
        self.latent_space_decoder = LatentSpaceDecoder(
            scale=self.decoder.scaling_factor,
            shift=self.decoder.shift_factor,
            dtype="float32",
            name="latent_space_decoder",
        )

        # === Functional Model ===
        latent_input = keras.Input(
            shape=latent_shape,
            name="latents",
        )
        clip_l_token_id_input = keras.Input(
            shape=(None,),
            dtype="int32",
            name="clip_l_token_ids",
        )
        clip_l_negative_token_id_input = keras.Input(
            shape=(None,),
            dtype="int32",
            name="clip_l_negative_token_ids",
        )
        clip_g_token_id_input = keras.Input(
            shape=(None,),
            dtype="int32",
            name="clip_g_token_ids",
        )
        clip_g_negative_token_id_input = keras.Input(
            shape=(None,),
            dtype="int32",
            name="clip_g_negative_token_ids",
        )
        token_ids = {
            "clip_l": clip_l_token_id_input,
            "clip_g": clip_g_token_id_input,
        }
        negative_token_ids = {
            "clip_l": clip_l_negative_token_id_input,
            "clip_g": clip_g_negative_token_id_input,
        }
        if self.t5 is not None:
            t5_token_id_input = keras.Input(
                shape=(None,),
                dtype="int32",
                name="t5_token_ids",
            )
            t5_negative_token_id_input = keras.Input(
                shape=(None,),
                dtype="int32",
                name="t5_negative_token_ids",
            )
            token_ids["t5"] = t5_token_id_input
            negative_token_ids["t5"] = t5_negative_token_id_input
        num_step_input = keras.Input(
            shape=(),
            dtype="int32",
            name="num_steps",
        )
        guidance_scale_input = keras.Input(
            shape=(),
            dtype="float32",
            name="guidance_scale",
        )
        embeddings = self.encode_step(token_ids, negative_token_ids)
        # Use `steps=0` to define the functional model.
        latents = self.denoise_step(
            latent_input,
            embeddings,
            0,
            num_step_input[0],
            guidance_scale_input[0],
        )
        outputs = self.decode_step(latents)
        inputs = {
            "latents": latent_input,
            "clip_l_token_ids": clip_l_token_id_input,
            "clip_l_negative_token_ids": clip_l_negative_token_id_input,
            "clip_g_token_ids": clip_g_token_id_input,
            "clip_g_negative_token_ids": clip_g_negative_token_id_input,
            "num_steps": num_step_input,
            "guidance_scale": guidance_scale_input,
        }
        if self.t5 is not None:
            inputs["t5_token_ids"] = t5_token_id_input
            inputs["t5_negative_token_ids"] = t5_negative_token_id_input
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.mmdit_patch_size = mmdit_patch_size
        self.mmdit_hidden_dim = mmdit_hidden_dim
        self.mmdit_num_layers = mmdit_num_layers
        self.mmdit_num_heads = mmdit_num_heads
        self.mmdit_position_size = mmdit_position_size
        self.vae_stackwise_num_filters = vae_stackwise_num_filters
        self.vae_stackwise_num_blocks = vae_stackwise_num_blocks
        self.latent_channels = latent_channels
        self.output_channels = output_channels
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.height = height
        self.width = width

    @property
    def latent_shape(self):
        return (None,) + tuple(self.diffuser.latent_shape)

    @property
    def clip_hidden_dim(self):
        return self.clip_l.hidden_dim + self.clip_g.hidden_dim

    @property
    def t5_hidden_dim(self):
        return 4096 if self.t5 is None else self.t5.hidden_dim

    def encode_step(self, token_ids, negative_token_ids):
        clip_hidden_dim = self.clip_hidden_dim
        t5_hidden_dim = self.t5_hidden_dim

        def encode(token_ids):
            clip_l_outputs = self.clip_l(token_ids["clip_l"], training=False)
            clip_g_outputs = self.clip_g(token_ids["clip_g"], training=False)
            clip_l_projection = self.clip_l_projection(
                clip_l_outputs["sequence_output"],
                token_ids["clip_l"],
                training=False,
            )
            clip_g_projection = self.clip_g_projection(
                clip_g_outputs["sequence_output"],
                token_ids["clip_g"],
                training=False,
            )
            pooled_embeddings = ops.concatenate(
                [clip_l_projection, clip_g_projection],
                axis=-1,
            )
            embeddings = ops.concatenate(
                [
                    clip_l_outputs["intermediate_output"],
                    clip_g_outputs["intermediate_output"],
                ],
                axis=-1,
            )
            embeddings = ops.pad(
                embeddings,
                [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]],
            )
            if self.t5 is not None:
                t5_outputs = self.t5(token_ids["t5"], training=False)
                embeddings = ops.concatenate([embeddings, t5_outputs], axis=-2)
            else:
                padded_size = self.clip_l.max_sequence_length
                embeddings = ops.pad(
                    embeddings, [[0, 0], [0, padded_size], [0, 0]]
                )
            return embeddings, pooled_embeddings

        positive_embeddings, positive_pooled_embeddings = encode(token_ids)
        negative_embeddings, negative_pooled_embeddings = encode(
            negative_token_ids
        )
        return (
            positive_embeddings,
            negative_embeddings,
            positive_pooled_embeddings,
            negative_pooled_embeddings,
        )

    def denoise_step(
        self,
        latents,
        embeddings,
        steps,
        num_steps,
        guidance_scale,
    ):
        steps = ops.convert_to_tensor(steps)
        steps_next = ops.add(steps, 1)
        sigma, timestep = self.scheduler(steps, num_steps)
        sigma_next, _ = self.scheduler(steps_next, num_steps)

        # Concatenation for classifier-free guidance.
        concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat(
            latents, *embeddings, timestep
        )

        # Diffusion.
        predicted_noise = self.diffuser(
            {
                "latent": concated_latents,
                "context": contexts,
                "pooled_projection": pooled_projs,
                "timestep": timesteps,
            },
            training=False,
        )

        # Classifier-free guidance.
        predicted_noise = self.cfg(predicted_noise, guidance_scale)

        # Euler step.
        return self.euler_step(latents, predicted_noise, sigma, sigma_next)

    def decode_step(self, latents):
        latents = self.latent_space_decoder(latents)
        return self.decoder(latents, training=False)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "mmdit_patch_size": self.mmdit_patch_size,
                "mmdit_hidden_dim": self.mmdit_hidden_dim,
                "mmdit_num_layers": self.mmdit_num_layers,
                "mmdit_num_heads": self.mmdit_num_heads,
                "mmdit_position_size": self.mmdit_position_size,
                "vae_stackwise_num_filters": self.vae_stackwise_num_filters,
                "vae_stackwise_num_blocks": self.vae_stackwise_num_blocks,
                "clip_l": layers.serialize(self.clip_l),
                "clip_g": layers.serialize(self.clip_g),
                "t5": layers.serialize(self.t5),
                "latent_channels": self.latent_channels,
                "output_channels": self.output_channels,
                "num_train_timesteps": self.num_train_timesteps,
                "shift": self.shift,
                "height": self.height,
                "width": self.width,
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
        config = config.copy()
        config["clip_l"] = layers.deserialize(
            config["clip_l"], custom_objects=custom_objects
        )
        config["clip_g"] = layers.deserialize(
            config["clip_g"], custom_objects=custom_objects
        )
        if config["t5"] is not None:
            config["t5"] = layers.deserialize(
                config["t5"], custom_objects=custom_objects
            )
        return cls(**config)
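Taken together, `denoise_step` and `decode_step` above reduce each sampling step to three small formulas: the classifier-free guidance combine, the Euler update, and the latent rescale before the VAE decoder. The NumPy sketch below simply restates those formulas with made-up shapes and constants (the real `scale` and `shift` come from the VAE's `scaling_factor` and `shift_factor`, which are not part of this file), so treat it as illustration only:

```python
import numpy as np

guidance_scale = 7.0
sigma, sigma_next = 1.0, 0.98  # from the scheduler at step t and t + 1
latents = np.random.randn(1, 64, 64, 16).astype("float32")

# The diffuser runs on a doubled batch: [positive, negative] conditioning.
positive_noise = np.random.randn(*latents.shape).astype("float32")
negative_noise = np.random.randn(*latents.shape).astype("float32")

# ClassifierFreeGuidance.call: negative + scale * (positive - negative).
predicted_noise = negative_noise + guidance_scale * (
    positive_noise - negative_noise
)

# EulerStep.call: latents + (sigma_next - sigma) * predicted_noise.
latents = latents + (sigma_next - sigma) * predicted_noise

# LatentSpaceDecoder.call: latents / scale + shift (factors assumed here).
scale, shift = 1.5305, 0.0609
image_latents = latents / scale + shift
```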