keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. keras_hub/api/layers/__init__.py +3 -0
  2. keras_hub/api/models/__init__.py +16 -0
  3. keras_hub/api/tokenizers/__init__.py +1 -0
  4. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
  5. keras_hub/src/models/clip/clip_preprocessor.py +147 -0
  6. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
  7. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
  8. keras_hub/src/models/densenet/__init__.py +6 -0
  9. keras_hub/src/models/densenet/densenet_backbone.py +11 -8
  10. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
  11. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  12. keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
  13. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  14. keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
  15. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
  16. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
  17. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
  18. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
  19. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
  20. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
  21. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
  22. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
  23. keras_hub/src/models/text_to_image.py +295 -0
  24. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  25. keras_hub/src/utils/timm/preset_loader.py +3 -0
  26. keras_hub/src/version_utils.py +1 -1
  27. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
  28. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +31 -23
  29. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  30. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  31. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  32. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  33. /keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
  34. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
  35. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -0,0 +1,630 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import keras
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (
+     FlowMatchEulerDiscreteScheduler,
+ )
+ from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
+ from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
+     VAEImageDecoder,
+ )
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ class CLIPProjection(layers.Layer):
+     def __init__(self, hidden_dim, **kwargs):
+         super().__init__(**kwargs)
+         self.hidden_dim = int(hidden_dim)
+
+         self.dense = layers.Dense(
+             hidden_dim,
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="dense",
+         )
+
+     def build(self, inputs_shape, token_ids_shape):
+         inputs_shape = list(inputs_shape)
+         self.dense.build([None, inputs_shape[-1]])
+
+         # Assign identity matrix to the kernel as default.
+         self.dense._kernel.assign(ops.eye(self.hidden_dim))
+
+     def call(self, inputs, token_ids):
+         indices = ops.expand_dims(
+             ops.cast(ops.argmax(token_ids, axis=-1), "int32"), axis=-1
+         )
+         pooled_output = ops.take_along_axis(inputs, indices[:, :, None], axis=1)
+         pooled_output = ops.squeeze(pooled_output, axis=1)
+         return self.dense(pooled_output)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, inputs_shape):
+         return (inputs_shape[0], self.hidden_dim)
+
+
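`CLIPProjection` pools the encoder's sequence output at the position of the end-of-text token, found with `argmax` over the token ids (that token has the largest id in CLIP's vocabulary), then applies a bias-free square projection initialized to the identity. A minimal sketch of just the pooling step, using `keras.ops` on made-up toy tensors:

```python
import numpy as np
from keras import ops

hidden = np.random.rand(2, 5, 4).astype("float32")  # (batch, sequence, hidden)
token_ids = np.array(
    [[49406, 320, 1125, 49407, 0], [49406, 550, 49407, 0, 0]]
)  # 49407 marks end-of-text in CLIP's vocabulary

# The end-of-text token has the largest id, so argmax locates its position.
eos_index = ops.expand_dims(ops.cast(ops.argmax(token_ids, axis=-1), "int32"), -1)
pooled = ops.take_along_axis(hidden, eos_index[:, :, None], axis=1)
pooled = ops.squeeze(pooled, axis=1)  # shape (2, 4): one vector per sequence
```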
+ class ClassifierFreeGuidanceConcatenate(layers.Layer):
+     def __init__(self, axis=0, **kwargs):
+         super().__init__(**kwargs)
+         self.axis = axis
+
+     def call(
+         self,
+         latents,
+         positive_contexts,
+         negative_contexts,
+         positive_pooled_projections,
+         negative_pooled_projections,
+         timestep,
+     ):
+         timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
+         latents = ops.concatenate([latents, latents], axis=self.axis)
+         contexts = ops.concatenate(
+             [positive_contexts, negative_contexts], axis=self.axis
+         )
+         pooled_projections = ops.concatenate(
+             [positive_pooled_projections, negative_pooled_projections],
+             axis=self.axis,
+         )
+         timesteps = ops.concatenate([timestep, timestep], axis=self.axis)
+         return latents, contexts, pooled_projections, timesteps
+
+     def get_config(self):
+         return super().get_config()
+
+
+ class ClassifierFreeGuidance(layers.Layer):
+     """Perform classifier-free guidance.
+
+     This layer expects the inputs to be a concatenation of positive and
+     negative (or empty) noise. The computation applies the classifier-free
+     guidance scale.
+
+     Args:
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+
+     Call arguments:
+         inputs: A concatenation of positive and negative (or empty) noises.
+         guidance_scale: The scale factor for classifier-free guidance.
+
+     Reference:
+     - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+     """
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def call(self, inputs, guidance_scale):
+         positive_noise, negative_noise = ops.split(inputs, 2, axis=0)
+         return ops.add(
+             negative_noise,
+             ops.multiply(
+                 guidance_scale, ops.subtract(positive_noise, negative_noise)
+             ),
+         )
+
+     def get_config(self):
+         return super().get_config()
+
+     def compute_output_shape(self, inputs_shape):
+         outputs_shape = list(inputs_shape)
+         if outputs_shape[0] is not None:
+             outputs_shape[0] = outputs_shape[0] // 2
+         return outputs_shape
+
+
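Together with `ClassifierFreeGuidanceConcatenate`, this layer implements the usual classifier-free guidance recipe: the diffuser is run once on a batch that stacks the positive and the negative conditioning, and the two halves of its output are recombined as `negative + guidance_scale * (positive - negative)`. A toy numeric sketch of the recombination (values are illustrative only):

```python
import numpy as np
from keras import ops

# First half: prediction conditioned on the prompt; second half: prediction
# conditioned on the negative (or empty) prompt, as stacked by
# ClassifierFreeGuidanceConcatenate.
stacked_noise = np.array([[1.0, 2.0], [0.5, 0.5]], dtype="float32")
guidance_scale = 7.0

positive, negative = ops.split(stacked_noise, 2, axis=0)
guided = ops.add(
    negative, ops.multiply(guidance_scale, ops.subtract(positive, negative))
)
# guided == [[0.5 + 7.0 * 0.5, 0.5 + 7.0 * 1.5]] == [[4.0, 11.0]]
```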
+ class EulerStep(layers.Layer):
+     """A layer that advances the sample by one step using the predicted noise.
+
+     Args:
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+
+     Call arguments:
+         latents: A current sample created by the diffusion process.
+         noise_residual: The direct output from the diffusion model.
+         sigma: The amount of noise added at the current timestep.
+         sigma_next: The amount of noise added at the next timestep.
+
+     References:
+     - [Common Diffusion Noise Schedules and Sample Steps are Flawed](
+     https://arxiv.org/abs/2305.08891).
+     - [Elucidating the Design Space of Diffusion-Based Generative Models](
+     https://arxiv.org/abs/2206.00364).
+     """
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def call(self, latents, noise_residual, sigma, sigma_next):
+         sigma_diff = ops.subtract(sigma_next, sigma)
+         return ops.add(latents, ops.multiply(sigma_diff, noise_residual))
+
+     def get_config(self):
+         return super().get_config()
+
+     def compute_output_shape(self, latents_shape):
+         return latents_shape
+
+
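`EulerStep` is the explicit Euler update used by the flow-matching sampler: `latents_next = latents + (sigma_next - sigma) * noise_residual`. A one-step check with toy numbers:

```python
from keras import ops

latents = ops.convert_to_tensor([1.0])
noise_residual = ops.convert_to_tensor([0.2])
sigma, sigma_next = 1.0, 0.75

# latents + (sigma_next - sigma) * noise = 1.0 + (-0.25) * 0.2 = 0.95
next_latents = ops.add(
    latents, ops.multiply(ops.subtract(sigma_next, sigma), noise_residual)
)
```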
+ class LatentSpaceDecoder(layers.Layer):
+     """Decoder to transform the latent space back to the original image space.
+
+     During decoding, the latents are transformed back to the original image
+     space using the equation: `latents / scale + shift`.
+
+     Args:
+         scale: float. The scaling factor.
+         shift: float. The shift factor.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+
+     Call arguments:
+         latents: The latent tensor to be transformed.
+
+     Reference:
+     - [High-Resolution Image Synthesis with Latent Diffusion Models](
+     https://arxiv.org/abs/2112.10752).
+     """
+
+     def __init__(self, scale, shift, **kwargs):
+         super().__init__(**kwargs)
+         self.scale = scale
+         self.shift = shift
+
+     def call(self, latents):
+         return ops.add(ops.divide(latents, self.scale), self.shift)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "scale": self.scale,
+                 "shift": self.shift,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, latents_shape):
+         return latents_shape
+
+
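`LatentSpaceDecoder` un-normalizes the diffusion latents before the VAE decoder sees them; `scale` and `shift` are read from `decoder.scaling_factor` and `decoder.shift_factor` in the backbone constructor. A tiny sketch with placeholder factors (not the preset VAE values):

```python
from keras import ops

scale, shift = 1.5, 0.06  # placeholder factors for illustration
latents = ops.convert_to_tensor([[0.75]])

# latents / scale + shift = 0.75 / 1.5 + 0.06 = 0.56
vae_input = ops.add(ops.divide(latents, scale), shift)
```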
+ @keras_hub_export("keras_hub.models.StableDiffusion3Backbone")
+ class StableDiffusion3Backbone(Backbone):
+     """Stable Diffusion 3 core network with hyperparameters.
+
+     This backbone imports CLIP and T5 models as text encoders and implements the
+     base MMDiT and VAE networks for the Stable Diffusion 3 model.
+
+     The default constructor gives fully customizable, randomly initialized
+     MMDiT and VAE models with any hyperparameters. To load preset architectures
+     and weights, use the `from_preset` constructor.
+
+     Args:
+         mmdit_patch_size: int. The size of each square patch in the input image
+             in MMDiT.
+         mmdit_hidden_dim: int. The size of the transformer hidden state at the
+             end of each transformer layer in MMDiT.
+         mmdit_num_layers: int. The number of transformer layers in MMDiT.
+         mmdit_num_heads: int. The number of attention heads for each
+             transformer in MMDiT.
+         mmdit_position_size: int. The size of the height and width for the
+             position embedding in MMDiT.
+         vae_stackwise_num_filters: list of ints. The number of filters for each
+             stack in VAE.
+         vae_stackwise_num_blocks: list of ints. The number of blocks for each
+             stack in VAE.
+         clip_l: `keras_hub.models.CLIPTextEncoder`. The text encoder for
+             encoding the inputs.
+         clip_g: `keras_hub.models.CLIPTextEncoder`. The text encoder for
+             encoding the inputs.
+         t5: optional `keras_hub.models.T5Encoder`. The text encoder for
+             encoding the inputs.
+         latent_channels: int. The number of channels in the latent. Defaults to
+             `16`.
+         output_channels: int. The number of channels in the output. Defaults to
+             `3`.
+         num_train_timesteps: int. The number of diffusion steps to train the
+             model. Defaults to `1000`.
+         shift: float. The shift value for the timestep schedule. Defaults to
+             `1.0`.
+         height: optional int. The output height of the image.
+         width: optional int. The output width of the image.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for the model's computations and weights. Note that some
+             computations, such as softmax and layer normalization, will always
+             be done in float32 precision regardless of dtype.
+
+     Example:
+     ```python
+     # Pretrained Stable Diffusion 3 model.
+     model = keras_hub.models.StableDiffusion3Backbone.from_preset(
+         "stable_diffusion_3_medium"
+     )
+
+     # Randomly initialized Stable Diffusion 3 model with custom config.
+     clip_l = keras_hub.models.CLIPTextEncoder(...)
+     clip_g = keras_hub.models.CLIPTextEncoder(...)
+     model = keras_hub.models.StableDiffusion3Backbone(
+         mmdit_patch_size=2,
+         mmdit_num_heads=4,
+         mmdit_hidden_dim=256,
+         mmdit_num_layers=4,
+         mmdit_position_size=192,
+         vae_stackwise_num_filters=[128, 128, 64, 32],
+         vae_stackwise_num_blocks=[1, 1, 1, 1],
+         clip_l=clip_l,
+         clip_g=clip_g,
+     )
+     ```
+     """
+
+     def __init__(
+         self,
+         mmdit_patch_size,
+         mmdit_hidden_dim,
+         mmdit_num_layers,
+         mmdit_num_heads,
+         mmdit_position_size,
+         vae_stackwise_num_filters,
+         vae_stackwise_num_blocks,
+         clip_l,
+         clip_g,
+         t5=None,
+         latent_channels=16,
+         output_channels=3,
+         num_train_timesteps=1000,
+         shift=1.0,
+         height=None,
+         width=None,
+         data_format=None,
+         dtype=None,
+         **kwargs,
+     ):
+         height = int(height or 1024)
+         width = int(width or 1024)
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(
+                 "`height` and `width` must be divisible by 8. "
+                 f"Received: height={height}, width={width}"
+             )
+         data_format = standardize_data_format(data_format)
+         if data_format != "channels_last":
+             raise NotImplementedError
+         latent_shape = (height // 8, width // 8, latent_channels)
+         context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
+         pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
+
+         # === Layers ===
+         self.clip_l = clip_l
+         self.clip_l_projection = CLIPProjection(
+             clip_l.hidden_dim, dtype=dtype, name="clip_l_projection"
+         )
+         self.clip_l_projection.build([None, clip_l.hidden_dim], None)
+         self.clip_g = clip_g
+         self.clip_g_projection = CLIPProjection(
+             clip_g.hidden_dim, dtype=dtype, name="clip_g_projection"
+         )
+         self.clip_g_projection.build([None, clip_g.hidden_dim], None)
+         self.t5 = t5
+         self.diffuser = MMDiT(
+             mmdit_patch_size,
+             mmdit_hidden_dim,
+             mmdit_num_layers,
+             mmdit_num_heads,
+             mmdit_position_size,
+             latent_shape=latent_shape,
+             context_shape=context_shape,
+             pooled_projection_shape=pooled_projection_shape,
+             data_format=data_format,
+             dtype=dtype,
+             name="diffuser",
+         )
+         self.decoder = VAEImageDecoder(
+             vae_stackwise_num_filters,
+             vae_stackwise_num_blocks,
+             output_channels,
+             latent_shape=latent_shape,
+             data_format=data_format,
+             dtype=dtype,
+             name="decoder",
+         )
+         # Set `dtype="float32"` to ensure the high precision for the noise
+         # residual.
+         self.scheduler = FlowMatchEulerDiscreteScheduler(
+             num_train_timesteps=num_train_timesteps,
+             shift=shift,
+             dtype="float32",
+             name="scheduler",
+         )
+         self.cfg_concat = ClassifierFreeGuidanceConcatenate(
+             dtype="float32", name="classifier_free_guidance_concat"
+         )
+         self.cfg = ClassifierFreeGuidance(
+             dtype="float32", name="classifier_free_guidance"
+         )
+         self.euler_step = EulerStep(dtype="float32", name="euler_step")
+         self.latent_space_decoder = LatentSpaceDecoder(
+             scale=self.decoder.scaling_factor,
+             shift=self.decoder.shift_factor,
+             dtype="float32",
+             name="latent_space_decoder",
+         )
+
+         # === Functional Model ===
+         latent_input = keras.Input(
+             shape=latent_shape,
+             name="latents",
+         )
+         clip_l_token_id_input = keras.Input(
+             shape=(None,),
+             dtype="int32",
+             name="clip_l_token_ids",
+         )
+         clip_l_negative_token_id_input = keras.Input(
+             shape=(None,),
+             dtype="int32",
+             name="clip_l_negative_token_ids",
+         )
+         clip_g_token_id_input = keras.Input(
+             shape=(None,),
+             dtype="int32",
+             name="clip_g_token_ids",
+         )
+         clip_g_negative_token_id_input = keras.Input(
+             shape=(None,),
+             dtype="int32",
+             name="clip_g_negative_token_ids",
+         )
+         token_ids = {
+             "clip_l": clip_l_token_id_input,
+             "clip_g": clip_g_token_id_input,
+         }
+         negative_token_ids = {
+             "clip_l": clip_l_negative_token_id_input,
+             "clip_g": clip_g_negative_token_id_input,
+         }
+         if self.t5 is not None:
+             t5_token_id_input = keras.Input(
+                 shape=(None,),
+                 dtype="int32",
+                 name="t5_token_ids",
+             )
+             t5_negative_token_id_input = keras.Input(
+                 shape=(None,),
+                 dtype="int32",
+                 name="t5_negative_token_ids",
+             )
+             token_ids["t5"] = t5_token_id_input
+             negative_token_ids["t5"] = t5_negative_token_id_input
+         num_step_input = keras.Input(
+             shape=(),
+             dtype="int32",
+             name="num_steps",
+         )
+         guidance_scale_input = keras.Input(
+             shape=(),
+             dtype="float32",
+             name="guidance_scale",
+         )
+         embeddings = self.encode_step(token_ids, negative_token_ids)
+         # Use `steps=0` to define the functional model.
+         latents = self.denoise_step(
+             latent_input,
+             embeddings,
+             0,
+             num_step_input[0],
+             guidance_scale_input[0],
+         )
+         outputs = self.decode_step(latents)
+         inputs = {
+             "latents": latent_input,
+             "clip_l_token_ids": clip_l_token_id_input,
+             "clip_l_negative_token_ids": clip_l_negative_token_id_input,
+             "clip_g_token_ids": clip_g_token_id_input,
+             "clip_g_negative_token_ids": clip_g_negative_token_id_input,
+             "num_steps": num_step_input,
+             "guidance_scale": guidance_scale_input,
+         }
+         if self.t5 is not None:
+             inputs["t5_token_ids"] = t5_token_id_input
+             inputs["t5_negative_token_ids"] = t5_negative_token_id_input
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             dtype=dtype,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.mmdit_patch_size = mmdit_patch_size
+         self.mmdit_hidden_dim = mmdit_hidden_dim
+         self.mmdit_num_layers = mmdit_num_layers
+         self.mmdit_num_heads = mmdit_num_heads
+         self.mmdit_position_size = mmdit_position_size
+         self.vae_stackwise_num_filters = vae_stackwise_num_filters
+         self.vae_stackwise_num_blocks = vae_stackwise_num_blocks
+         self.latent_channels = latent_channels
+         self.output_channels = output_channels
+         self.num_train_timesteps = num_train_timesteps
+         self.shift = shift
+         self.height = height
+         self.width = width
+
+     @property
+     def latent_shape(self):
+         return (None,) + tuple(self.diffuser.latent_shape)
+
+     @property
+     def clip_hidden_dim(self):
+         return self.clip_l.hidden_dim + self.clip_g.hidden_dim
+
+     @property
+     def t5_hidden_dim(self):
+         return 4096 if self.t5 is None else self.t5.hidden_dim
+
+     def encode_step(self, token_ids, negative_token_ids):
+         clip_hidden_dim = self.clip_hidden_dim
+         t5_hidden_dim = self.t5_hidden_dim
+
+         def encode(token_ids):
+             clip_l_outputs = self.clip_l(token_ids["clip_l"], training=False)
+             clip_g_outputs = self.clip_g(token_ids["clip_g"], training=False)
+             clip_l_projection = self.clip_l_projection(
+                 clip_l_outputs["sequence_output"],
+                 token_ids["clip_l"],
+                 training=False,
+             )
+             clip_g_projection = self.clip_g_projection(
+                 clip_g_outputs["sequence_output"],
+                 token_ids["clip_g"],
+                 training=False,
+             )
+             pooled_embeddings = ops.concatenate(
+                 [clip_l_projection, clip_g_projection],
+                 axis=-1,
+             )
+             embeddings = ops.concatenate(
+                 [
+                     clip_l_outputs["intermediate_output"],
+                     clip_g_outputs["intermediate_output"],
+                 ],
+                 axis=-1,
+             )
+             embeddings = ops.pad(
+                 embeddings,
+                 [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]],
+             )
+             if self.t5 is not None:
+                 t5_outputs = self.t5(token_ids["t5"], training=False)
+                 embeddings = ops.concatenate([embeddings, t5_outputs], axis=-2)
+             else:
+                 padded_size = self.clip_l.max_sequence_length
+                 embeddings = ops.pad(
+                     embeddings, [[0, 0], [0, padded_size], [0, 0]]
+                 )
+             return embeddings, pooled_embeddings
+
+         positive_embeddings, positive_pooled_embeddings = encode(token_ids)
+         negative_embeddings, negative_pooled_embeddings = encode(
+             negative_token_ids
+         )
+         return (
+             positive_embeddings,
+             negative_embeddings,
+             positive_pooled_embeddings,
+             negative_pooled_embeddings,
+         )
+
+     def denoise_step(
+         self,
+         latents,
+         embeddings,
+         steps,
+         num_steps,
+         guidance_scale,
+     ):
+         steps = ops.convert_to_tensor(steps)
+         steps_next = ops.add(steps, 1)
+         sigma, timestep = self.scheduler(steps, num_steps)
+         sigma_next, _ = self.scheduler(steps_next, num_steps)
+
+         # Concatenation for classifier-free guidance.
+         concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat(
+             latents, *embeddings, timestep
+         )
+
+         # Diffusion.
+         predicted_noise = self.diffuser(
+             {
+                 "latent": concated_latents,
+                 "context": contexts,
+                 "pooled_projection": pooled_projs,
+                 "timestep": timesteps,
+             },
+             training=False,
+         )
+
+         # Classifier-free guidance.
+         predicted_noise = self.cfg(predicted_noise, guidance_scale)
+
+         # Euler step.
+         return self.euler_step(latents, predicted_noise, sigma, sigma_next)
+
+     def decode_step(self, latents):
+         latents = self.latent_space_decoder(latents)
+         return self.decoder(latents, training=False)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "mmdit_patch_size": self.mmdit_patch_size,
+                 "mmdit_hidden_dim": self.mmdit_hidden_dim,
+                 "mmdit_num_layers": self.mmdit_num_layers,
+                 "mmdit_num_heads": self.mmdit_num_heads,
+                 "mmdit_position_size": self.mmdit_position_size,
+                 "vae_stackwise_num_filters": self.vae_stackwise_num_filters,
+                 "vae_stackwise_num_blocks": self.vae_stackwise_num_blocks,
+                 "clip_l": layers.serialize(self.clip_l),
+                 "clip_g": layers.serialize(self.clip_g),
+                 "t5": layers.serialize(self.t5),
+                 "latent_channels": self.latent_channels,
+                 "output_channels": self.output_channels,
+                 "num_train_timesteps": self.num_train_timesteps,
+                 "shift": self.shift,
+                 "height": self.height,
+                 "width": self.width,
+             }
+         )
+         return config
+
+     @classmethod
+     def from_config(cls, config, custom_objects=None):
+         # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
+         config = config.copy()
+         config["clip_l"] = layers.deserialize(
+             config["clip_l"], custom_objects=custom_objects
+         )
+         config["clip_g"] = layers.deserialize(
+             config["clip_g"], custom_objects=custom_objects
+         )
+         if config["t5"] is not None:
+             config["t5"] = layers.deserialize(
+                 config["t5"], custom_objects=custom_objects
+             )
+         return cls(**config)
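The functional graph built in `__init__` wires up only a single denoising step; a sampler has to call `encode_step`, `denoise_step`, and `decode_step` in a loop, which is roughly what the accompanying text-to-image task does. A rough sketch of that loop, assuming `backbone` is a constructed `StableDiffusion3Backbone` and `token_ids` / `negative_token_ids` are dicts produced by the matching preprocessor (those inputs, the step count, and the guidance scale are illustrative, not preset defaults):

```python
from keras import random

# Assumed to already exist: `backbone`, plus `token_ids` and
# `negative_token_ids` dicts from the matching preprocessor.
num_steps = 28
guidance_scale = 7.0
latents = random.normal((1,) + backbone.latent_shape[1:])

# Encode the prompts once, then iterate the Euler sampler.
embeddings = backbone.encode_step(token_ids, negative_token_ids)
for step in range(num_steps):
    latents = backbone.denoise_step(
        latents, embeddings, step, num_steps, guidance_scale
    )
images = backbone.decode_step(latents)  # decoded images from the VAE
```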