keras-hub-nightly 0.16.1.dev202410030339__py3-none-any.whl → 0.16.1.dev202410050339__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +9 -0
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +196 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +109 -0
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +57 -93
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +5 -3
- keras_hub/src/models/task.py +20 -15
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +172 -0
- keras_hub/src/models/vae/vae_layers.py +740 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/RECORD +23 -14
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
CHANGED
@@ -8,9 +8,6 @@ from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler
     FlowMatchEulerDiscreteScheduler,
 )
 from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
-from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
-    VAEImageDecoder,
-)
 from keras_hub.src.utils.keras_utils import standardize_data_format


@@ -159,48 +156,6 @@ class EulerStep(layers.Layer):
         return latents_shape


-class LatentSpaceDecoder(layers.Layer):
-    """Decoder to transform the latent space back to the original image space.
-
-    During decoding, the latents are transformed back to the original image
-    space using the equation: `latents / scale + shift`.
-
-    Args:
-        scale: float. The scaling factor.
-        shift: float. The shift factor.
-        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
-            including `name`, `dtype` etc.
-
-    Call arguments:
-        latents: The latent tensor to be transformed.
-
-    Reference:
-    - [High-Resolution Image Synthesis with Latent Diffusion Models](
-        https://arxiv.org/abs/2112.10752).
-    """
-
-    def __init__(self, scale, shift, **kwargs):
-        super().__init__(**kwargs)
-        self.scale = scale
-        self.shift = shift
-
-    def call(self, latents):
-        return ops.add(ops.divide(latents, self.scale), self.shift)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "scale": self.scale,
-                "shift": self.shift,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, latents_shape):
-        return latents_shape
-
-
 @keras_hub_export("keras_hub.models.StableDiffusion3Backbone")
 class StableDiffusion3Backbone(Backbone):
     """Stable Diffusion 3 core network with hyperparameters.
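The removed `LatentSpaceDecoder` only computed `latents / scale + shift`. Later in this diff the same transform is rebuilt from a stock `keras.layers.Rescaling` layer. A minimal sketch (not part of the package) checking that equivalence, using the VAE default scale and shift documented in `vae_backbone.py` below:

    import keras
    from keras import ops

    # Rescaling(scale=1/s, offset=t) computes x * (1/s) + t == x / s + t,
    # which is exactly the removed LatentSpaceDecoder math.
    scale, shift = 1.5305, 0.0609
    rescaling = keras.layers.Rescaling(scale=1.0 / scale, offset=shift)
    latents = ops.ones((1, 4, 4, 16))
    expected = ops.add(ops.divide(latents, scale), shift)
    assert ops.all(ops.abs(rescaling(latents) - expected) < 1e-6)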
@@ -222,16 +177,11 @@ class StableDiffusion3Backbone(Backbone):
             transformer in MMDiT.
         mmdit_position_size: int. The size of the height and width for the
             position embedding in MMDiT.
-        vae_stackwise_num_filters: list of ints. The number of filters for
-            each stack in VAE.
-        vae_stackwise_num_blocks: list of ints. The number of blocks for each
-            stack in VAE.
-        clip_l: `keras_hub.models.CLIPTextEncoder`. The text encoder for
-            encoding the inputs.
-        clip_g: `keras_hub.models.CLIPTextEncoder`. The text encoder for
-            encoding the inputs.
-        t5: optional `keras_hub.models.T5Encoder`. The text encoder for
-            encoding the inputs.
+        vae: The VAE used for transformations between pixel space and latent
+            space.
+        clip_l: The CLIP text encoder for encoding the inputs.
+        clip_g: The CLIP text encoder for encoding the inputs.
+        t5: optional The T5 text encoder for encoding the inputs.
         latent_channels: int. The number of channels in the latent. Defaults to
             `16`.
         output_channels: int. The number of channels in the output. Defaults to
@@ -239,7 +189,7 @@ class StableDiffusion3Backbone(Backbone):
         num_train_timesteps: int. The number of diffusion steps to train the
             model. Defaults to `1000`.
         shift: float. The shift value for the timestep schedule. Defaults to
-            `
+            `3.0`.
         height: optional int. The output height of the image.
         width: optional int. The output width of the image.
         data_format: `None` or str. If specified, either `"channels_last"` or
@@ -264,6 +214,7 @@ class StableDiffusion3Backbone(Backbone):
     )

     # Randomly initialized Stable Diffusion 3 model with custom config.
+    vae = keras_hub.models.VAEBackbone(...)
     clip_l = keras_hub.models.CLIPTextEncoder(...)
     clip_g = keras_hub.models.CLIPTextEncoder(...)
     model = keras_hub.models.StableDiffusion3Backbone(
@@ -272,8 +223,7 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_hidden_dim=256,
         mmdit_depth=4,
         mmdit_position_size=192,
-
-        vae_stackwise_num_blocks=[1, 1, 1, 1],
+        vae=vae,
         clip_l=clip_l,
         clip_g=clip_g,
     )
@@ -287,15 +237,14 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_num_layers,
         mmdit_num_heads,
         mmdit_position_size,
-        vae_stackwise_num_filters,
-        vae_stackwise_num_blocks,
+        vae,
         clip_l,
         clip_g,
         t5=None,
         latent_channels=16,
         output_channels=3,
         num_train_timesteps=1000,
-        shift=
+        shift=3.0,
         height=None,
         width=None,
         data_format=None,
@@ -312,9 +261,11 @@ class StableDiffusion3Backbone(Backbone):
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError
-
+        image_shape = (height, width, int(vae.input_channels))
+        latent_shape = (height // 8, width // 8, int(latent_channels))
         context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
         pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
+        self._latent_shape = latent_shape

         # === Layers ===
         self.clip_l = clip_l
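The backbone now derives the pixel-space input shape from the VAE's `input_channels` and keeps the 8x spatial downsampling for the latent shape. A quick arithmetic sketch with assumed values (a 1024x1024 RGB output and the default 16 latent channels):

    # Illustrative values only; the shapes follow the two lines added above.
    height, width, input_channels, latent_channels = 1024, 1024, 3, 16
    image_shape = (height, width, input_channels)              # (1024, 1024, 3)
    latent_shape = (height // 8, width // 8, latent_channels)  # (128, 128, 16)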
@@ -341,15 +292,7 @@ class StableDiffusion3Backbone(Backbone):
             dtype=dtype,
             name="diffuser",
         )
-        self.decoder = VAEImageDecoder(
-            vae_stackwise_num_filters,
-            vae_stackwise_num_blocks,
-            output_channels,
-            latent_shape=latent_shape,
-            data_format=data_format,
-            dtype=dtype,
-            name="decoder",
-        )
+        self.vae = vae
         # Set `dtype="float32"` to ensure the high precision for the noise
         # residual.
         self.scheduler = FlowMatchEulerDiscreteScheduler(
@@ -365,14 +308,18 @@ class StableDiffusion3Backbone(Backbone):
             dtype="float32", name="classifier_free_guidance"
         )
         self.euler_step = EulerStep(dtype="float32", name="euler_step")
-        self.
-            scale=self.
-
+        self.latent_rescaling = layers.Rescaling(
+            scale=1.0 / self.vae.scale,
+            offset=self.vae.shift,
             dtype="float32",
-            name="
+            name="latent_rescaling",
         )

         # === Functional Model ===
+        image_input = keras.Input(
+            shape=image_shape,
+            name="images",
+        )
         latent_input = keras.Input(
             shape=latent_shape,
             name="latents",
@@ -428,17 +375,19 @@ class StableDiffusion3Backbone(Backbone):
             dtype="float32",
             name="guidance_scale",
         )
-        embeddings = self.
+        embeddings = self.encode_text_step(token_ids, negative_token_ids)
+        latents = self.encode_image_step(image_input)
         # Use `steps=0` to define the functional model.
-
+        denoised_latents = self.denoise_step(
             latent_input,
             embeddings,
             0,
             num_step_input[0],
             guidance_scale_input[0],
         )
-
+        images = self.decode_step(denoised_latents)
         inputs = {
+            "images": image_input,
             "latents": latent_input,
             "clip_l_token_ids": clip_l_token_id_input,
             "clip_l_negative_token_ids": clip_l_negative_token_id_input,
@@ -447,6 +396,10 @@ class StableDiffusion3Backbone(Backbone):
             "num_steps": num_step_input,
             "guidance_scale": guidance_scale_input,
         }
+        outputs = {
+            "latents": latents,
+            "images": images,
+        }
         if self.t5 is not None:
             inputs["t5_token_ids"] = t5_token_id_input
             inputs["t5_negative_token_ids"] = t5_negative_token_id_input
@@ -463,8 +416,6 @@ class StableDiffusion3Backbone(Backbone):
         self.mmdit_num_layers = mmdit_num_layers
         self.mmdit_num_heads = mmdit_num_heads
         self.mmdit_position_size = mmdit_position_size
-        self.vae_stackwise_num_filters = vae_stackwise_num_filters
-        self.vae_stackwise_num_blocks = vae_stackwise_num_blocks
         self.latent_channels = latent_channels
         self.output_channels = output_channels
         self.num_train_timesteps = num_train_timesteps
@@ -474,7 +425,7 @@ class StableDiffusion3Backbone(Backbone):

     @property
     def latent_shape(self):
-        return (None,) +
+        return (None,) + self._latent_shape

     @property
     def clip_hidden_dim(self):
@@ -484,7 +435,7 @@ class StableDiffusion3Backbone(Backbone):
     def t5_hidden_dim(self):
         return 4096 if self.t5 is None else self.t5.hidden_dim

-    def
+    def encode_text_step(self, token_ids, negative_token_ids):
         clip_hidden_dim = self.clip_hidden_dim
         t5_hidden_dim = self.t5_hidden_dim

@@ -537,18 +488,27 @@ class StableDiffusion3Backbone(Backbone):
             negative_pooled_embeddings,
         )

+    def encode_image_step(self, images):
+        latents = self.vae.encode(images)
+        return ops.multiply(
+            ops.subtract(latents, self.vae.shift), self.vae.scale
+        )
+
+    def add_noise_step(self, latents, noises, step, num_steps):
+        return self.scheduler.add_noise(latents, noises, step, num_steps)
+
     def denoise_step(
         self,
         latents,
         embeddings,
-
+        step,
         num_steps,
         guidance_scale,
     ):
-
-
-        sigma, timestep = self.scheduler(
-
+        step = ops.convert_to_tensor(step)
+        next_step = ops.add(step, 1)
+        sigma, timestep = self.scheduler(step, num_steps)
+        next_sigma, _ = self.scheduler(next_step, num_steps)

         # Concatenation for classifier-free guidance.
         concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat(
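`encode_image_step` applies `(latents - shift) * scale`, and the `latent_rescaling` layer used by `decode_step` undoes it with `latents / scale + shift`. A minimal round-trip sketch with assumed latent values and the package's VAE defaults:

    from keras import ops

    scale, shift = 1.5305, 0.0609
    latents = ops.convert_to_tensor([[0.25, -0.5, 1.0]])
    scaled = ops.multiply(ops.subtract(latents, shift), scale)   # encode_image_step
    restored = ops.add(ops.divide(scaled, scale), shift)         # latent_rescaling
    assert ops.all(ops.abs(restored - latents) < 1e-6)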
@@ -570,11 +530,11 @@ class StableDiffusion3Backbone(Backbone):
         predicted_noise = self.cfg(predicted_noise, guidance_scale)

         # Euler step.
-        return self.euler_step(latents, predicted_noise, sigma,
+        return self.euler_step(latents, predicted_noise, sigma, next_sigma)

     def decode_step(self, latents):
-        latents = self.
-        return self.
+        latents = self.latent_rescaling(latents)
+        return self.vae.decode(latents, training=False)

     def get_config(self):
         config = super().get_config()
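The `cfg` layer applied above combines the conditional and unconditional noise predictions with the guidance scale. Its implementation is not shown in this diff; the sketch below is the standard classifier-free guidance formulation, not necessarily the exact layer internals:

    from keras import ops

    def classifier_free_guidance(positive_noise, negative_noise, guidance_scale):
        # Standard CFG: uncond + scale * (cond - uncond).
        return ops.add(
            negative_noise,
            ops.multiply(guidance_scale, ops.subtract(positive_noise, negative_noise)),
        )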
@@ -585,8 +545,7 @@ class StableDiffusion3Backbone(Backbone):
                 "mmdit_num_layers": self.mmdit_num_layers,
                 "mmdit_num_heads": self.mmdit_num_heads,
                 "mmdit_position_size": self.mmdit_position_size,
-                "vae_stackwise_num_filters": self.vae_stackwise_num_filters,
-                "vae_stackwise_num_blocks": self.vae_stackwise_num_blocks,
+                "vae": layers.serialize(self.vae),
                 "clip_l": layers.serialize(self.clip_l),
                 "clip_g": layers.serialize(self.clip_g),
                 "t5": layers.serialize(self.t5),
@@ -607,6 +566,8 @@ class StableDiffusion3Backbone(Backbone):
         # Propagate `dtype` to text encoders if needed.
         if "dtype" in config and config["dtype"] is not None:
             dtype_config = config["dtype"]
+            if "dtype" not in config["vae"]["config"]:
+                config["vae"]["config"]["dtype"] = dtype_config
             if "dtype" not in config["clip_l"]["config"]:
                 config["clip_l"]["config"]["dtype"] = dtype_config
             if "dtype" not in config["clip_g"]["config"]:
@@ -617,7 +578,10 @@ class StableDiffusion3Backbone(Backbone):
             ):
                 config["t5"]["config"]["dtype"] = dtype_config

-        # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
+        # We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated.
+        config["vae"] = layers.deserialize(
+            config["vae"], custom_objects=custom_objects
+        )
         config["clip_l"] = layers.deserialize(
             config["clip_l"], custom_objects=custom_objects
         )
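The VAE sub-model now goes through the same `layers.serialize` / `layers.deserialize` round trip as the text encoders. A minimal sketch of that round trip on a plain Keras layer (a stand-in for the nested `vae` config):

    from keras import layers

    dense = layers.Dense(8, dtype="float32")
    config = layers.serialize(dense)       # nested {"class_name": ..., "config": {...}} dict
    restored = layers.deserialize(config)
    assert isinstance(restored, layers.Dense)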
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
CHANGED
@@ -5,14 +5,14 @@ backbone_presets = {
         "metadata": {
             "description": (
                 "3 billion parameter, including CLIP L and CLIP G text "
-                "encoders, MMDiT generative model, and VAE
+                "encoders, MMDiT generative model, and VAE autoencoder. "
                 "Developed by Stability AI."
             ),
-            "params":
+            "params": 2987080931,
             "official_name": "StableDiffusion3",
             "path": "stablediffusion3",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/
+        "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/3",
     }
 }
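The preset metadata above now points the `stable_diffusion_3_medium` handle at version 3 on Kaggle. A hedged usage sketch: loading by preset name should pick up the new VAE-based backbone, with weights downloaded on first use.

    import keras_hub

    backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
        "stable_diffusion_3_medium"
    )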
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
CHANGED
@@ -38,11 +38,11 @@ class StableDiffusion3TextToImage(TextToImage):
         ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
     )

-    # Generate with different `num_steps` and `
+    # Generate with different `num_steps` and `guidance_scale`.
     text_to_image.generate(
         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
         num_steps=50,
-
+        guidance_scale=5.0,
     )
     ```
     """
@@ -104,7 +104,9 @@ class StableDiffusion3TextToImage(TextToImage):
             the expense of lower image quality.
         """
         # Encode inputs.
-        embeddings = self.backbone.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )

         # Denoise.
         def body_fun(step, latents):
keras_hub/src/models/task.py
CHANGED
@@ -4,8 +4,11 @@ from rich import markup
 from rich import table as rich_table

 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.tokenizers.tokenizer import Tokenizer
 from keras_hub.src.utils.keras_utils import print_msg
 from keras_hub.src.utils.pipeline_model import PipelineModel
 from keras_hub.src.utils.preset_utils import builtin_presets
@@ -324,22 +327,24 @@ class Task(PipelineModel):
                 info,
             )

+        # Since the preprocessor might be nested with multiple `Tokenizer`,
+        # `ImageConverter`, `AudioConverter` and even other `Preprocessor`
+        # instances, we should recursively iterate through them.
         preprocessor = self.preprocessor
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                add_layer(audio_converter, info)
+        if preprocessor and isinstance(preprocessor, keras.Layer):
+            for layer in preprocessor._flatten_layers(include_self=False):
+                if isinstance(layer, Tokenizer):
+                    info = "Vocab size: "
+                    info += highlight_number(layer.vocabulary_size())
+                    add_layer(layer, info)
+                elif isinstance(layer, ImageConverter):
+                    info = "Image size: "
+                    info += highlight_shape(layer.image_size())
+                    add_layer(layer, info)
+                elif isinstance(layer, AudioConverter):
+                    info = "Audio shape: "
+                    info += highlight_shape(layer.audio_shape())
+                    add_layer(layer, info)

         # Print the to the console.
         preprocessor_name = markup.escape(preprocessor.name)
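The summary now walks the preprocessor with the private `_flatten_layers` helper instead of checking fixed attributes, so converters nested anywhere inside the preprocessor are still discovered. A minimal sketch of that traversal pattern using plain Keras layers rather than a keras-hub preprocessor (the helper is private API, shown here only because the diff itself relies on it):

    import keras

    outer = keras.layers.TimeDistributed(keras.layers.Dense(4))
    for layer in outer._flatten_layers(include_self=False):
        print(type(layer).__name__)  # expected to include "Dense"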
keras_hub/src/models/vae/__init__.py
ADDED
@@ -0,0 +1 @@
+from keras_hub.src.models.vae.vae_backbone import VAEBackbone
keras_hub/src/models/vae/vae_backbone.py
ADDED
@@ -0,0 +1,172 @@
+import keras
+
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.vae.vae_layers import (
+    DiagonalGaussianDistributionSampler,
+)
+from keras_hub.src.models.vae.vae_layers import VAEDecoder
+from keras_hub.src.models.vae.vae_layers import VAEEncoder
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class VAEBackbone(Backbone):
+    """VAE backbone used in latent diffusion models.
+
+    When encoding, this model generates mean and log variance of the input
+    images. When decoding, it reconstructs images from the latent space.
+
+    Args:
+        encoder_num_filters: list of ints. The number of filters for each
+            block in encoder.
+        encoder_num_blocks: list of ints. The number of blocks for each block in
+            encoder.
+        decoder_num_filters: list of ints. The number of filters for each
+            block in decoder.
+        decoder_num_blocks: list of ints. The number of blocks for each block in
+            decoder.
+        sampler_method: str. The method of the sampler for the intermediate
+            output. Available methods are `"sample"` and `"mode"`. `"sample"`
+            draws from the distribution using both the mean and log variance.
+            `"mode"` draws from the distribution using the mean only. Defaults
+            to `sample`.
+        input_channels: int. The number of channels in the input.
+        sample_channels: int. The number of channels in the sample. Typically,
+            this indicates the intermediate output of VAE, which is mean and
+            log variance.
+        output_channels: int. The number of channels in the output.
+        scale: float. The scaling factor applied to the latent space to ensure
+            it has unit variance during training of the diffusion model.
+            Defaults to `1.5305`, which is the value used in Stable Diffusion 3.
+        shift: float. The shift factor applied to the latent space to ensure it
+            has zero mean during training of the diffusion model. Defaults to
+            `0.0609`, which is the value used in Stable Diffusion 3.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+            to use for the model's computations and weights.
+    """
+
+    def __init__(
+        self,
+        encoder_num_filters,
+        encoder_num_blocks,
+        decoder_num_filters,
+        decoder_num_blocks,
+        sampler_method="sample",
+        input_channels=3,
+        sample_channels=32,
+        output_channels=3,
+        scale=1.5305,
+        shift=0.0609,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        data_format = standardize_data_format(data_format)
+        if data_format == "channels_last":
+            image_shape = (None, None, input_channels)
+            channel_axis = -1
+        else:
+            image_shape = (input_channels, None, None)
+            channel_axis = 1
+
+        # === Layers ===
+        self.encoder = VAEEncoder(
+            encoder_num_filters,
+            encoder_num_blocks,
+            output_channels=sample_channels,
+            data_format=data_format,
+            dtype=dtype,
+            name="encoder",
+        )
+        # Use `sample()` to define the functional model.
+        self.distribution_sampler = DiagonalGaussianDistributionSampler(
+            method=sampler_method,
+            axis=channel_axis,
+            dtype=dtype,
+            name="distribution_sampler",
+        )
+        self.decoder = VAEDecoder(
+            decoder_num_filters,
+            decoder_num_blocks,
+            output_channels=output_channels,
+            data_format=data_format,
+            dtype=dtype,
+            name="decoder",
+        )
+
+        # === Functional Model ===
+        image_input = keras.Input(shape=image_shape)
+        sample = self.encoder(image_input)
+        latent = self.distribution_sampler(sample)
+        image_output = self.decoder(latent)
+        super().__init__(
+            inputs=image_input,
+            outputs=image_output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.encoder_num_filters = encoder_num_filters
+        self.encoder_num_blocks = encoder_num_blocks
+        self.decoder_num_filters = decoder_num_filters
+        self.decoder_num_blocks = decoder_num_blocks
+        self.sampler_method = sampler_method
+        self.input_channels = input_channels
+        self.sample_channels = sample_channels
+        self.output_channels = output_channels
+        self._scale = scale
+        self._shift = shift
+
+    @property
+    def scale(self):
+        """The scaling factor for the latent space.
+
+        This is used to scale the latent space to have unit variance when
+        training the diffusion model.
+        """
+        return self._scale
+
+    @property
+    def shift(self):
+        """The shift factor for the latent space.
+
+        This is used to shift the latent space to have zero mean when
+        training the diffusion model.
+        """
+        return self._shift
+
+    def encode(self, inputs, **kwargs):
+        """Encode the input images into latent space."""
+        sample = self.encoder(inputs, **kwargs)
+        return self.distribution_sampler(sample)
+
+    def decode(self, inputs, **kwargs):
+        """Decode the input latent space into images."""
+        return self.decoder(inputs, **kwargs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "encoder_num_filters": self.encoder_num_filters,
+                "encoder_num_blocks": self.encoder_num_blocks,
+                "decoder_num_filters": self.decoder_num_filters,
+                "decoder_num_blocks": self.decoder_num_blocks,
+                "sampler_method": self.sampler_method,
+                "input_channels": self.input_channels,
+                "sample_channels": self.sample_channels,
+                "output_channels": self.output_channels,
+                "scale": self.scale,
+                "shift": self.shift,
+            }
+        )
+        return config
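A hedged usage sketch of the new backbone based only on the constructor signature above; the filter and block counts are illustrative, not the Stable Diffusion 3 preset configuration, and the concrete layer behavior depends on `vae_layers.py`, which is not shown here.

    import numpy as np
    from keras_hub.src.models.vae.vae_backbone import VAEBackbone

    vae = VAEBackbone(
        encoder_num_filters=[32, 32, 64, 64],
        encoder_num_blocks=[1, 1, 2, 2],
        decoder_num_filters=[64, 64, 32, 32],
        decoder_num_blocks=[2, 2, 1, 1],
    )
    images = np.random.uniform(size=(1, 64, 64, 3)).astype("float32")
    latents = vae.encode(images)   # sampled latent tensor
    decoded = vae.decode(latents)  # reconstruction in pixel space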