keras-hub-nightly 0.16.1.dev202410030339__py3-none-any.whl → 0.16.1.dev202410050339__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (24)
  1. keras_hub/api/layers/__init__.py +3 -0
  2. keras_hub/api/models/__init__.py +9 -0
  3. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  4. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +196 -0
  5. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  6. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  7. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  8. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -0
  9. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +109 -0
  10. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  11. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  12. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +57 -93
  13. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  14. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +5 -3
  15. keras_hub/src/models/task.py +20 -15
  16. keras_hub/src/models/vae/__init__.py +1 -0
  17. keras_hub/src/models/vae/vae_backbone.py +172 -0
  18. keras_hub/src/models/vae/vae_layers.py +740 -0
  19. keras_hub/src/version_utils.py +1 -1
  20. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/METADATA +1 -1
  21. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/RECORD +23 -14
  22. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  23. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/WHEEL +0 -0
  24. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/top_level.txt +0 -0
@@ -8,9 +8,6 @@ from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler
  FlowMatchEulerDiscreteScheduler,
  )
  from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
- from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
- VAEImageDecoder,
- )
  from keras_hub.src.utils.keras_utils import standardize_data_format


@@ -159,48 +156,6 @@ class EulerStep(layers.Layer):
  return latents_shape


- class LatentSpaceDecoder(layers.Layer):
- """Decoder to transform the latent space back to the original image space.
-
- During decoding, the latents are transformed back to the original image
- space using the equation: `latents / scale + shift`.
-
- Args:
- scale: float. The scaling factor.
- shift: float. The shift factor.
- **kwargs: other keyword arguments passed to `keras.layers.Layer`,
- including `name`, `dtype` etc.
-
- Call arguments:
- latents: The latent tensor to be transformed.
-
- Reference:
- - [High-Resolution Image Synthesis with Latent Diffusion Models](
- https://arxiv.org/abs/2112.10752).
- """
-
- def __init__(self, scale, shift, **kwargs):
- super().__init__(**kwargs)
- self.scale = scale
- self.shift = shift
-
- def call(self, latents):
- return ops.add(ops.divide(latents, self.scale), self.shift)
-
- def get_config(self):
- config = super().get_config()
- config.update(
- {
- "scale": self.scale,
- "shift": self.shift,
- }
- )
- return config
-
- def compute_output_shape(self, latents_shape):
- return latents_shape
-
-
  @keras_hub_export("keras_hub.models.StableDiffusion3Backbone")
  class StableDiffusion3Backbone(Backbone):
  """Stable Diffusion 3 core network with hyperparameters.
@@ -222,16 +177,11 @@ class StableDiffusion3Backbone(Backbone):
  transformer in MMDiT.
  mmdit_position_size: int. The size of the height and width for the
  position embedding in MMDiT.
- vae_stackwise_num_filters: list of ints. The number of filters for each
- stack in VAE.
- vae_stackwise_num_blocks: list of ints. The number of blocks for each
- stack in VAE.
- clip_l: `keras_hub.models.CLIPTextEncoder`. The text encoder for
- encoding the inputs.
- clip_g: `keras_hub.models.CLIPTextEncoder`. The text encoder for
- encoding the inputs.
- t5: optional `keras_hub.models.T5Encoder`. The text encoder for
- encoding the inputs.
+ vae: The VAE used for transformations between pixel space and latent
+ space.
+ clip_l: The CLIP text encoder for encoding the inputs.
+ clip_g: The CLIP text encoder for encoding the inputs.
+ t5: optional The T5 text encoder for encoding the inputs.
  latent_channels: int. The number of channels in the latent. Defaults to
  `16`.
  output_channels: int. The number of channels in the output. Defaults to
@@ -239,7 +189,7 @@ class StableDiffusion3Backbone(Backbone):
  num_train_timesteps: int. The number of diffusion steps to train the
  model. Defaults to `1000`.
  shift: float. The shift value for the timestep schedule. Defaults to
- `1.0`.
+ `3.0`.
  height: optional int. The output height of the image.
  width: optional int. The output width of the image.
  data_format: `None` or str. If specified, either `"channels_last"` or
@@ -264,6 +214,7 @@ class StableDiffusion3Backbone(Backbone):
  )

  # Randomly initialized Stable Diffusion 3 model with custom config.
+ vae = keras_hub.models.VAEBackbone(...)
  clip_l = keras_hub.models.CLIPTextEncoder(...)
  clip_g = keras_hub.models.CLIPTextEncoder(...)
  model = keras_hub.models.StableDiffusion3Backbone(
@@ -272,8 +223,7 @@ class StableDiffusion3Backbone(Backbone):
  mmdit_hidden_dim=256,
  mmdit_depth=4,
  mmdit_position_size=192,
- vae_stackwise_num_filters=[128, 128, 64, 32],
- vae_stackwise_num_blocks=[1, 1, 1, 1],
+ vae=vae,
  clip_l=clip_l,
  clip_g=clip_g,
  )
@@ -287,15 +237,14 @@ class StableDiffusion3Backbone(Backbone):
  mmdit_num_layers,
  mmdit_num_heads,
  mmdit_position_size,
- vae_stackwise_num_filters,
- vae_stackwise_num_blocks,
+ vae,
  clip_l,
  clip_g,
  t5=None,
  latent_channels=16,
  output_channels=3,
  num_train_timesteps=1000,
- shift=1.0,
+ shift=3.0,
  height=None,
  width=None,
  data_format=None,
@@ -312,9 +261,11 @@ class StableDiffusion3Backbone(Backbone):
  data_format = standardize_data_format(data_format)
  if data_format != "channels_last":
  raise NotImplementedError
- latent_shape = (height // 8, width // 8, latent_channels)
+ image_shape = (height, width, int(vae.input_channels))
+ latent_shape = (height // 8, width // 8, int(latent_channels))
  context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
  pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
+ self._latent_shape = latent_shape

  # === Layers ===
  self.clip_l = clip_l
@@ -341,15 +292,7 @@ class StableDiffusion3Backbone(Backbone):
  dtype=dtype,
  name="diffuser",
  )
- self.decoder = VAEImageDecoder(
- vae_stackwise_num_filters,
- vae_stackwise_num_blocks,
- output_channels,
- latent_shape=latent_shape,
- data_format=data_format,
- dtype=dtype,
- name="decoder",
- )
+ self.vae = vae
  # Set `dtype="float32"` to ensure the high precision for the noise
  # residual.
  self.scheduler = FlowMatchEulerDiscreteScheduler(
@@ -365,14 +308,18 @@ class StableDiffusion3Backbone(Backbone):
  dtype="float32", name="classifier_free_guidance"
  )
  self.euler_step = EulerStep(dtype="float32", name="euler_step")
- self.latent_space_decoder = LatentSpaceDecoder(
- scale=self.decoder.scaling_factor,
- shift=self.decoder.shift_factor,
+ self.latent_rescaling = layers.Rescaling(
+ scale=1.0 / self.vae.scale,
+ offset=self.vae.shift,
  dtype="float32",
- name="latent_space_decoder",
+ name="latent_rescaling",
  )

  # === Functional Model ===
+ image_input = keras.Input(
+ shape=image_shape,
+ name="images",
+ )
  latent_input = keras.Input(
  shape=latent_shape,
  name="latents",
@@ -428,17 +375,19 @@ class StableDiffusion3Backbone(Backbone):
  dtype="float32",
  name="guidance_scale",
  )
- embeddings = self.encode_step(token_ids, negative_token_ids)
+ embeddings = self.encode_text_step(token_ids, negative_token_ids)
+ latents = self.encode_image_step(image_input)
  # Use `steps=0` to define the functional model.
- latents = self.denoise_step(
+ denoised_latents = self.denoise_step(
  latent_input,
  embeddings,
  0,
  num_step_input[0],
  guidance_scale_input[0],
  )
- outputs = self.decode_step(latents)
+ images = self.decode_step(denoised_latents)
  inputs = {
+ "images": image_input,
  "latents": latent_input,
  "clip_l_token_ids": clip_l_token_id_input,
  "clip_l_negative_token_ids": clip_l_negative_token_id_input,
@@ -447,6 +396,10 @@ class StableDiffusion3Backbone(Backbone):
  "num_steps": num_step_input,
  "guidance_scale": guidance_scale_input,
  }
+ outputs = {
+ "latents": latents,
+ "images": images,
+ }
  if self.t5 is not None:
  inputs["t5_token_ids"] = t5_token_id_input
  inputs["t5_negative_token_ids"] = t5_negative_token_id_input
@@ -463,8 +416,6 @@ class StableDiffusion3Backbone(Backbone):
  self.mmdit_num_layers = mmdit_num_layers
  self.mmdit_num_heads = mmdit_num_heads
  self.mmdit_position_size = mmdit_position_size
- self.vae_stackwise_num_filters = vae_stackwise_num_filters
- self.vae_stackwise_num_blocks = vae_stackwise_num_blocks
  self.latent_channels = latent_channels
  self.output_channels = output_channels
  self.num_train_timesteps = num_train_timesteps
@@ -474,7 +425,7 @@ class StableDiffusion3Backbone(Backbone):

  @property
  def latent_shape(self):
- return (None,) + tuple(self.diffuser.latent_shape)
+ return (None,) + self._latent_shape

  @property
  def clip_hidden_dim(self):
@@ -484,7 +435,7 @@ class StableDiffusion3Backbone(Backbone):
  def t5_hidden_dim(self):
  return 4096 if self.t5 is None else self.t5.hidden_dim

- def encode_step(self, token_ids, negative_token_ids):
+ def encode_text_step(self, token_ids, negative_token_ids):
  clip_hidden_dim = self.clip_hidden_dim
  t5_hidden_dim = self.t5_hidden_dim

@@ -537,18 +488,27 @@ class StableDiffusion3Backbone(Backbone):
  negative_pooled_embeddings,
  )

+ def encode_image_step(self, images):
+ latents = self.vae.encode(images)
+ return ops.multiply(
+ ops.subtract(latents, self.vae.shift), self.vae.scale
+ )
+
+ def add_noise_step(self, latents, noises, step, num_steps):
+ return self.scheduler.add_noise(latents, noises, step, num_steps)
+
  def denoise_step(
  self,
  latents,
  embeddings,
- steps,
+ step,
  num_steps,
  guidance_scale,
  ):
- steps = ops.convert_to_tensor(steps)
- steps_next = ops.add(steps, 1)
- sigma, timestep = self.scheduler(steps, num_steps)
- sigma_next, _ = self.scheduler(steps_next, num_steps)
+ step = ops.convert_to_tensor(step)
+ next_step = ops.add(step, 1)
+ sigma, timestep = self.scheduler(step, num_steps)
+ next_sigma, _ = self.scheduler(next_step, num_steps)

  # Concatenation for classifier-free guidance.
  concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat(
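The new `encode_image_step` normalizes VAE latents with `(latents - shift) * scale`, the exact inverse of the `latents / scale + shift` rescaling applied before decoding, so the two cancel on a round trip. A minimal numeric check using the Stable Diffusion 3 defaults from `VAEBackbone` below (plain Python floats, illustrative value for the latent):

scale, shift = 1.5305, 0.0609    # VAEBackbone defaults for Stable Diffusion 3
z = 0.25                         # some raw VAE latent value
z_norm = (z - shift) * scale     # encode_image_step normalization
z_back = z_norm / scale + shift  # latent_rescaling applied before decoding
assert abs(z_back - z) < 1e-9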
@@ -570,11 +530,11 @@ class StableDiffusion3Backbone(Backbone):
  predicted_noise = self.cfg(predicted_noise, guidance_scale)

  # Euler step.
- return self.euler_step(latents, predicted_noise, sigma, sigma_next)
+ return self.euler_step(latents, predicted_noise, sigma, next_sigma)

  def decode_step(self, latents):
- latents = self.latent_space_decoder(latents)
- return self.decoder(latents, training=False)
+ latents = self.latent_rescaling(latents)
+ return self.vae.decode(latents, training=False)

  def get_config(self):
  config = super().get_config()
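For orientation, `denoise_step` feeds two consecutive sigmas into `EulerStep`; assuming the standard first-order Euler update used in flow matching (a sketch of the presumed behavior, the layer itself is not shown in this diff), the update is:

next_latents = latents + (next_sigma - sigma) * predicted_noise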
@@ -585,8 +545,7 @@ class StableDiffusion3Backbone(Backbone):
  "mmdit_num_layers": self.mmdit_num_layers,
  "mmdit_num_heads": self.mmdit_num_heads,
  "mmdit_position_size": self.mmdit_position_size,
- "vae_stackwise_num_filters": self.vae_stackwise_num_filters,
- "vae_stackwise_num_blocks": self.vae_stackwise_num_blocks,
+ "vae": layers.serialize(self.vae),
  "clip_l": layers.serialize(self.clip_l),
  "clip_g": layers.serialize(self.clip_g),
  "t5": layers.serialize(self.t5),
@@ -607,6 +566,8 @@ class StableDiffusion3Backbone(Backbone):
  # Propagate `dtype` to text encoders if needed.
  if "dtype" in config and config["dtype"] is not None:
  dtype_config = config["dtype"]
+ if "dtype" not in config["vae"]["config"]:
+ config["vae"]["config"]["dtype"] = dtype_config
  if "dtype" not in config["clip_l"]["config"]:
  config["clip_l"]["config"]["dtype"] = dtype_config
  if "dtype" not in config["clip_g"]["config"]:
@@ -617,7 +578,10 @@ class StableDiffusion3Backbone(Backbone):
  ):
  config["t5"]["config"]["dtype"] = dtype_config

- # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
+ # We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated.
+ config["vae"] = layers.deserialize(
+ config["vae"], custom_objects=custom_objects
+ )
  config["clip_l"] = layers.deserialize(
  config["clip_l"], custom_objects=custom_objects
  )
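A brief sketch of the config round trip this enables (assuming `backbone` is an existing `StableDiffusion3Backbone`; only the architecture is rebuilt, weights are not carried over):

config = backbone.get_config()
# "vae", "clip_l", "clip_g" and "t5" are stored as nested serialized layer
# configs, so from_config re-instantiates each sub-model before building
# the backbone itself.
restored = StableDiffusion3Backbone.from_config(config)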
@@ -5,14 +5,14 @@ backbone_presets = {
  "metadata": {
  "description": (
  "3 billion parameter, including CLIP L and CLIP G text "
- "encoders, MMDiT generative model, and VAE decoder. "
+ "encoders, MMDiT generative model, and VAE autoencoder. "
  "Developed by Stability AI."
  ),
- "params": 2952806723,
+ "params": 2987080931,
  "official_name": "StableDiffusion3",
  "path": "stablediffusion3",
  "model_card": "https://arxiv.org/abs/2110.00476",
  },
- "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/1",
+ "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/3",
  }
  }
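The preset's parameter count grows by 2,987,080,931 - 2,952,806,723 = 34,274,208 (about 34.3M), consistent with the preset now bundling the full VAE autoencoder rather than only its decoder.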
@@ -38,11 +38,11 @@ class StableDiffusion3TextToImage(TextToImage):
  ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
  )

- # Generate with different `num_steps` and `classifier_free_guidance_scale`.
+ # Generate with different `num_steps` and `guidance_scale`.
  text_to_image.generate(
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
  num_steps=50,
- classifier_free_guidance_scale=5.0,
+ guidance_scale=5.0,
  )
  ```
  """
@@ -104,7 +104,9 @@
  the expense of lower image quality.
  """
  # Encode inputs.
- embeddings = self.backbone.encode_step(token_ids, negative_token_ids)
+ embeddings = self.backbone.encode_text_step(
+ token_ids, negative_token_ids
+ )

  # Denoise.
  def body_fun(step, latents):
@@ -4,8 +4,11 @@ from rich import markup
  from rich import table as rich_table

  from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
  from keras_hub.src.models.backbone import Backbone
  from keras_hub.src.models.preprocessor import Preprocessor
+ from keras_hub.src.tokenizers.tokenizer import Tokenizer
  from keras_hub.src.utils.keras_utils import print_msg
  from keras_hub.src.utils.pipeline_model import PipelineModel
  from keras_hub.src.utils.preset_utils import builtin_presets
@@ -324,22 +327,24 @@ class Task(PipelineModel):
  info,
  )

+ # Since the preprocessor might be nested with multiple `Tokenizer`,
+ # `ImageConverter`, `AudioConverter` and even other `Preprocessor`
+ # instances, we should recursively iterate through them.
  preprocessor = self.preprocessor
- tokenizer = getattr(preprocessor, "tokenizer", None)
- if tokenizer:
- info = "Vocab size: "
- info += highlight_number(tokenizer.vocabulary_size())
- add_layer(tokenizer, info)
- image_converter = getattr(preprocessor, "image_converter", None)
- if image_converter:
- info = "Image size: "
- info += highlight_shape(image_converter.image_size)
- add_layer(image_converter, info)
- audio_converter = getattr(preprocessor, "audio_converter", None)
- if audio_converter:
- info = "Audio shape: "
- info += highlight_shape(audio_converter.audio_shape())
- add_layer(audio_converter, info)
+ if preprocessor and isinstance(preprocessor, keras.Layer):
+ for layer in preprocessor._flatten_layers(include_self=False):
+ if isinstance(layer, Tokenizer):
+ info = "Vocab size: "
+ info += highlight_number(layer.vocabulary_size())
+ add_layer(layer, info)
+ elif isinstance(layer, ImageConverter):
+ info = "Image size: "
+ info += highlight_shape(layer.image_size())
+ add_layer(layer, info)
+ elif isinstance(layer, AudioConverter):
+ info = "Audio shape: "
+ info += highlight_shape(layer.audio_shape())
+ add_layer(layer, info)

  # Print the to the console.
  preprocessor_name = markup.escape(preprocessor.name)
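A small illustration of why the recursive walk matters (the preprocessor class names here are hypothetical): when the tokenizer lives on a nested preprocessor rather than directly on `self.preprocessor`, the old `getattr(preprocessor, "tokenizer", None)` lookup returns `None`, while `_flatten_layers()` still reaches it:

outer = OuterPreprocessor(inner=InnerPreprocessor(tokenizer=tokenizer))  # hypothetical nesting
getattr(outer, "tokenizer", None)  # -> None; the tokenizer is one level down
[l for l in outer._flatten_layers(include_self=False) if isinstance(l, Tokenizer)]  # -> [tokenizer]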
@@ -0,0 +1 @@
+ from keras_hub.src.models.vae.vae_backbone import VAEBackbone
@@ -0,0 +1,172 @@
+ import keras
+
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.vae.vae_layers import (
+ DiagonalGaussianDistributionSampler,
+ )
+ from keras_hub.src.models.vae.vae_layers import VAEDecoder
+ from keras_hub.src.models.vae.vae_layers import VAEEncoder
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ class VAEBackbone(Backbone):
+ """VAE backbone used in latent diffusion models.
+
+ When encoding, this model generates mean and log variance of the input
+ images. When decoding, it reconstructs images from the latent space.
+
+ Args:
+ encoder_num_filters: list of ints. The number of filters for each
+ block in encoder.
+ encoder_num_blocks: list of ints. The number of blocks for each block in
+ encoder.
+ decoder_num_filters: list of ints. The number of filters for each
+ block in decoder.
+ decoder_num_blocks: list of ints. The number of blocks for each block in
+ decoder.
+ sampler_method: str. The method of the sampler for the intermediate
+ output. Available methods are `"sample"` and `"mode"`. `"sample"`
+ draws from the distribution using both the mean and log variance.
+ `"mode"` draws from the distribution using the mean only. Defaults
+ to `sample`.
+ input_channels: int. The number of channels in the input.
+ sample_channels: int. The number of channels in the sample. Typically,
+ this indicates the intermediate output of VAE, which is mean and
+ log variance.
+ output_channels: int. The number of channels in the output.
+ scale: float. The scaling factor applied to the latent space to ensure
+ it has unit variance during training of the diffusion model.
+ Defaults to `1.5305`, which is the value used in Stable Diffusion 3.
+ shift: float. The shift factor applied to the latent space to ensure it
+ has zero mean during training of the diffusion model. Defaults to
+ `0.0609`, which is the value used in Stable Diffusion 3.
+ data_format: `None` or str. If specified, either `"channels_last"` or
+ `"channels_first"`. The ordering of the dimensions in the
+ inputs. `"channels_last"` corresponds to inputs with shape
+ `(batch_size, height, width, channels)`
+ while `"channels_first"` corresponds to inputs with shape
+ `(batch_size, channels, height, width)`. It defaults to the
+ `image_data_format` value found in your Keras config file at
+ `~/.keras/keras.json`. If you never set it, then it will be
+ `"channels_last"`.
+ dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+ to use for the model's computations and weights.
+ """
+
+ def __init__(
+ self,
+ encoder_num_filters,
+ encoder_num_blocks,
+ decoder_num_filters,
+ decoder_num_blocks,
+ sampler_method="sample",
+ input_channels=3,
+ sample_channels=32,
+ output_channels=3,
+ scale=1.5305,
+ shift=0.0609,
+ data_format=None,
+ dtype=None,
+ **kwargs,
+ ):
+ data_format = standardize_data_format(data_format)
+ if data_format == "channels_last":
+ image_shape = (None, None, input_channels)
+ channel_axis = -1
+ else:
+ image_shape = (input_channels, None, None)
+ channel_axis = 1
+
+ # === Layers ===
+ self.encoder = VAEEncoder(
+ encoder_num_filters,
+ encoder_num_blocks,
+ output_channels=sample_channels,
+ data_format=data_format,
+ dtype=dtype,
+ name="encoder",
+ )
+ # Use `sample()` to define the functional model.
+ self.distribution_sampler = DiagonalGaussianDistributionSampler(
+ method=sampler_method,
+ axis=channel_axis,
+ dtype=dtype,
+ name="distribution_sampler",
+ )
+ self.decoder = VAEDecoder(
+ decoder_num_filters,
+ decoder_num_blocks,
+ output_channels=output_channels,
+ data_format=data_format,
+ dtype=dtype,
+ name="decoder",
+ )
+
+ # === Functional Model ===
+ image_input = keras.Input(shape=image_shape)
+ sample = self.encoder(image_input)
+ latent = self.distribution_sampler(sample)
+ image_output = self.decoder(latent)
+ super().__init__(
+ inputs=image_input,
+ outputs=image_output,
+ dtype=dtype,
+ **kwargs,
+ )
+
+ # === Config ===
+ self.encoder_num_filters = encoder_num_filters
+ self.encoder_num_blocks = encoder_num_blocks
+ self.decoder_num_filters = decoder_num_filters
+ self.decoder_num_blocks = decoder_num_blocks
+ self.sampler_method = sampler_method
+ self.input_channels = input_channels
+ self.sample_channels = sample_channels
+ self.output_channels = output_channels
+ self._scale = scale
+ self._shift = shift
+
+ @property
+ def scale(self):
+ """The scaling factor for the latent space.
+
+ This is used to scale the latent space to have unit variance when
+ training the diffusion model.
+ """
+ return self._scale
+
+ @property
+ def shift(self):
+ """The shift factor for the latent space.
+
+ This is used to shift the latent space to have zero mean when
+ training the diffusion model.
+ """
+ return self._shift
+
+ def encode(self, inputs, **kwargs):
+ """Encode the input images into latent space."""
+ sample = self.encoder(inputs, **kwargs)
+ return self.distribution_sampler(sample)
+
+ def decode(self, inputs, **kwargs):
+ """Decode the input latent space into images."""
+ return self.decoder(inputs, **kwargs)
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "encoder_num_filters": self.encoder_num_filters,
+ "encoder_num_blocks": self.encoder_num_blocks,
+ "decoder_num_filters": self.decoder_num_filters,
+ "decoder_num_blocks": self.decoder_num_blocks,
+ "sampler_method": self.sampler_method,
+ "input_channels": self.input_channels,
+ "sample_channels": self.sample_channels,
+ "output_channels": self.output_channels,
+ "scale": self.scale,
+ "shift": self.shift,
+ }
+ )
+ return config
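A usage sketch for the new backbone (the filter and block counts below are illustrative, not a released preset; by construction the functional model chains encoder, sampler and decoder, so decoding an encoded batch yields an image-shaped reconstruction):

import numpy as np
from keras_hub.src.models.vae.vae_backbone import VAEBackbone

vae = VAEBackbone(
    encoder_num_filters=[32, 32, 32, 32],  # illustrative values
    encoder_num_blocks=[1, 1, 1, 1],
    decoder_num_filters=[32, 32, 32, 32],
    decoder_num_blocks=[1, 1, 1, 1],
)
images = np.random.uniform(size=(1, 64, 64, 3)).astype("float32")
latents = vae.encode(images)          # sampled latent from the diagonal Gaussian
reconstruction = vae.decode(latents)  # back to pixel space, 3 output channels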