keras-hub-nightly 0.16.1.dev202410080341__py3-none-any.whl → 0.16.1.dev202410100339__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
Files changed (29)
  1. keras_hub/api/layers/__init__.py +3 -0
  2. keras_hub/api/models/__init__.py +11 -0
  3. keras_hub/src/layers/preprocessing/image_converter.py +2 -1
  4. keras_hub/src/models/image_to_image.py +411 -0
  5. keras_hub/src/models/inpaint.py +513 -0
  6. keras_hub/src/models/mix_transformer/__init__.py +12 -0
  7. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +4 -0
  8. keras_hub/src/models/mix_transformer/mix_transformer_classifier_preprocessor.py +16 -0
  9. keras_hub/src/models/mix_transformer/mix_transformer_image_converter.py +8 -0
  10. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +9 -5
  11. keras_hub/src/models/mix_transformer/mix_transformer_presets.py +151 -0
  12. keras_hub/src/models/preprocessor.py +4 -4
  13. keras_hub/src/models/stable_diffusion_3/mmdit.py +308 -177
  14. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +87 -55
  15. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +171 -0
  16. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +194 -0
  17. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
  18. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +13 -8
  19. keras_hub/src/models/task.py +1 -1
  20. keras_hub/src/models/text_to_image.py +89 -36
  21. keras_hub/src/tests/test_case.py +3 -1
  22. keras_hub/src/tokenizers/tokenizer.py +7 -7
  23. keras_hub/src/utils/preset_utils.py +7 -7
  24. keras_hub/src/utils/timm/preset_loader.py +1 -3
  25. keras_hub/src/version_utils.py +1 -1
  26. {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/METADATA +1 -1
  27. {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/RECORD +29 -22
  28. {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/WHEEL +0 -0
  29. {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -51,11 +51,52 @@ class CLIPProjection(layers.Layer):
         return (inputs_shape[0], self.hidden_dim)
 
 
-class ClassifierFreeGuidanceConcatenate(layers.Layer):
-    def __init__(self, axis=0, **kwargs):
-        super().__init__(**kwargs)
-        self.axis = axis
+class CLIPConcatenate(layers.Layer):
+    def call(
+        self,
+        clip_l_projection,
+        clip_g_projection,
+        clip_l_intermediate_output,
+        clip_g_intermediate_output,
+        padding,
+    ):
+        pooled_embeddings = ops.concatenate(
+            [clip_l_projection, clip_g_projection], axis=-1
+        )
+        embeddings = ops.concatenate(
+            [clip_l_intermediate_output, clip_g_intermediate_output], axis=-1
+        )
+        embeddings = ops.pad(embeddings, [[0, 0], [0, 0], [0, padding]])
+        return pooled_embeddings, embeddings
+
+
+class ImageRescaling(layers.Rescaling):
+    """Rescales inputs from image space to latent space.
+
+    The rescaling is performed using the formula: `(inputs - offset) * scale`.
+    """
+
+    def call(self, inputs):
+        dtype = self.compute_dtype
+        scale = self.backend.cast(self.scale, dtype)
+        offset = self.backend.cast(self.offset, dtype)
+        return (self.backend.cast(inputs, dtype) - offset) * scale
+
 
+class LatentRescaling(layers.Rescaling):
+    """Rescales inputs from latent space to image space.
+
+    The rescaling is performed using the formula: `inputs / scale + offset`.
+    """
+
+    def call(self, inputs):
+        dtype = self.compute_dtype
+        scale = self.backend.cast(self.scale, dtype)
+        offset = self.backend.cast(self.offset, dtype)
+        return (self.backend.cast(inputs, dtype) / scale) + offset
+
+
+class ClassifierFreeGuidanceConcatenate(layers.Layer):
     def call(
         self,
         latents,
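
The two rescaling layers added here are exact inverses, so a latent mapped into the normalized space and back comes out unchanged. A minimal NumPy sketch of the same arithmetic; the `scale`/`offset` values below are illustrative assumptions (the backbone reads them from `vae.scale` and `vae.shift`):

```python
import numpy as np

# Illustrative values; roughly what the SD3 VAE config uses.
scale, offset = 1.5305, 0.0609

def image_to_latent(x):
    # ImageRescaling: (inputs - offset) * scale
    return (x - offset) * scale

def latent_to_image(z):
    # LatentRescaling: inputs / scale + offset
    return z / scale + offset

x = np.random.rand(2, 64, 64, 16).astype("float32")
# Round trip is the identity (up to float error).
np.testing.assert_allclose(latent_to_image(image_to_latent(x)), x, rtol=1e-5)
```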
@@ -66,20 +107,16 @@ class ClassifierFreeGuidanceConcatenate(layers.Layer):
         timestep,
     ):
         timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
-        latents = ops.concatenate([latents, latents], axis=self.axis)
+        latents = ops.concatenate([latents, latents], axis=0)
         contexts = ops.concatenate(
-            [positive_contexts, negative_contexts], axis=self.axis
+            [positive_contexts, negative_contexts], axis=0
         )
         pooled_projections = ops.concatenate(
-            [positive_pooled_projections, negative_pooled_projections],
-            axis=self.axis,
+            [positive_pooled_projections, negative_pooled_projections], axis=0
         )
-        timesteps = ops.concatenate([timestep, timestep], axis=self.axis)
+        timesteps = ops.concatenate([timestep, timestep], axis=0)
         return latents, contexts, pooled_projections, timesteps
 
-    def get_config(self):
-        return super().get_config()
-
 
 class ClassifierFreeGuidance(layers.Layer):
     """Perform classifier free guidance.
@@ -100,9 +137,6 @@ class ClassifierFreeGuidance(layers.Layer):
     - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
     def call(self, inputs, guidance_scale):
        positive_noise, negative_noise = ops.split(inputs, 2, axis=0)
        return ops.add(
@@ -112,9 +146,6 @@ class ClassifierFreeGuidance(layers.Layer):
             ),
         )
 
-    def get_config(self):
-        return super().get_config()
-
     def compute_output_shape(self, inputs_shape):
         outputs_shape = list(inputs_shape)
         if outputs_shape[0] is not None:
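
The guidance combination itself sits in unchanged lines between these two hunks. Assuming the standard formulation from the cited paper (negative noise plus `guidance_scale` times the positive-negative difference), a NumPy sketch of what the layer computes on the doubled batch:

```python
import numpy as np

def classifier_free_guidance(noise, guidance_scale):
    # The doubled batch holds positive-prompt noise first, negative second.
    positive_noise, negative_noise = np.split(noise, 2, axis=0)
    # Assumed standard CFG combination; the exact elided lines are not
    # shown in this hunk.
    return negative_noise + guidance_scale * (positive_noise - negative_noise)

noise = np.random.rand(4, 8, 8, 16).astype("float32")  # batch of 2, doubled
out = classifier_free_guidance(noise, guidance_scale=7.0)
print(out.shape)  # (2, 8, 8, 16), matching compute_output_shape below
```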
@@ -142,16 +173,10 @@ class EulerStep(layers.Layer):
         https://arxiv.org/abs/2206.00364).
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
     def call(self, latents, noise_residual, sigma, sigma_next):
         sigma_diff = ops.subtract(sigma_next, sigma)
         return ops.add(latents, ops.multiply(sigma_diff, noise_residual))
 
-    def get_config(self):
-        return super().get_config()
-
     def compute_output_shape(self, latents_shape):
         return latents_shape
 
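
`EulerStep` is now stateless apart from its `call`; the update is a single first-order Euler move along the flow, `latents + (sigma_next - sigma) * noise_residual`. A tiny worked example:

```python
import numpy as np

latents = np.zeros((1, 4), dtype="float32")
noise_residual = np.full((1, 4), 2.0, dtype="float32")
sigma, sigma_next = 1.0, 0.75

# EulerStep.call: latents + (sigma_next - sigma) * noise_residual
stepped = latents + (sigma_next - sigma) * noise_residual
print(stepped)  # [[-0.5 -0.5 -0.5 -0.5]]
```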
@@ -272,12 +297,13 @@ class StableDiffusion3Backbone(Backbone):
         self.clip_l_projection = CLIPProjection(
             clip_l.hidden_dim, dtype=dtype, name="clip_l_projection"
         )
-        self.clip_l_projection.build([None, clip_l.hidden_dim], None)
         self.clip_g = clip_g
         self.clip_g_projection = CLIPProjection(
             clip_g.hidden_dim, dtype=dtype, name="clip_g_projection"
         )
-        self.clip_g_projection.build([None, clip_g.hidden_dim], None)
+        self.clip_concatenate = CLIPConcatenate(
+            dtype=dtype, name="clip_concatenate"
+        )
         self.t5 = t5
         self.diffuser = MMDiT(
             mmdit_patch_size,
@@ -293,6 +319,12 @@
             name="diffuser",
         )
         self.vae = vae
+        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
+            dtype=dtype, name="classifier_free_guidance_concat"
+        )
+        self.cfg = ClassifierFreeGuidance(
+            dtype=dtype, name="classifier_free_guidance"
+        )
         # Set `dtype="float32"` to ensure the high precision for the noise
         # residual.
         self.scheduler = FlowMatchEulerDiscreteScheduler(
@@ -301,17 +333,17 @@
             dtype="float32",
             name="scheduler",
         )
-        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
-            dtype="float32", name="classifier_free_guidance_concat"
-        )
-        self.cfg = ClassifierFreeGuidance(
-            dtype="float32", name="classifier_free_guidance"
-        )
         self.euler_step = EulerStep(dtype="float32", name="euler_step")
-        self.latent_rescaling = layers.Rescaling(
-            scale=1.0 / self.vae.scale,
+        self.image_rescaling = ImageRescaling(
+            scale=self.vae.scale,
             offset=self.vae.shift,
-            dtype="float32",
+            dtype=dtype,
+            name="image_rescaling",
+        )
+        self.latent_rescaling = LatentRescaling(
+            scale=self.vae.scale,
+            offset=self.vae.shift,
+            dtype=dtype,
             name="latent_rescaling",
         )
 
@@ -440,8 +472,12 @@
         t5_hidden_dim = self.t5_hidden_dim
 
         def encode(token_ids):
-            clip_l_outputs = self.clip_l(token_ids["clip_l"], training=False)
-            clip_g_outputs = self.clip_g(token_ids["clip_g"], training=False)
+            clip_l_outputs = self.clip_l(
+                {"token_ids": token_ids["clip_l"]}, training=False
+            )
+            clip_g_outputs = self.clip_g(
+                {"token_ids": token_ids["clip_g"]}, training=False
+            )
             clip_l_projection = self.clip_l_projection(
                 clip_l_outputs["sequence_output"],
                 token_ids["clip_l"],
@@ -452,23 +488,21 @@
                 token_ids["clip_g"],
                 training=False,
             )
-            pooled_embeddings = ops.concatenate(
-                [clip_l_projection, clip_g_projection],
-                axis=-1,
-            )
-            embeddings = ops.concatenate(
-                [
-                    clip_l_outputs["intermediate_output"],
-                    clip_g_outputs["intermediate_output"],
-                ],
-                axis=-1,
-            )
-            embeddings = ops.pad(
-                embeddings,
-                [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]],
+            pooled_embeddings, embeddings = self.clip_concatenate(
+                clip_l_projection,
+                clip_g_projection,
+                clip_l_outputs["intermediate_output"],
+                clip_g_outputs["intermediate_output"],
+                padding=t5_hidden_dim - clip_hidden_dim,
             )
             if self.t5 is not None:
-                t5_outputs = self.t5(token_ids["t5"], training=False)
+                t5_outputs = self.t5(
+                    {
+                        "token_ids": token_ids["t5"],
+                        "padding_mask": ops.ones_like(token_ids["t5"]),
+                    },
+                    training=False,
+                )
                 embeddings = ops.concatenate([embeddings, t5_outputs], axis=-2)
             else:
                 padded_size = self.clip_l.max_sequence_length
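
To see why the refactored `CLIPConcatenate` pads on the last axis: the pooled projections concatenate to `clip_l_hidden + clip_g_hidden` features, and the per-token embeddings must be widened to the T5 hidden size before the optional T5 concatenation on the sequence axis. A shape-only NumPy sketch, using the usual SD3-medium dimensions (768 for CLIP-L, 1280 for CLIP-G, 4096 for T5-XXL) as illustrative assumptions:

```python
import numpy as np

batch, num_tokens = 2, 77
clip_l_dim, clip_g_dim, t5_dim = 768, 1280, 4096  # assumed SD3-medium sizes

clip_l_proj = np.zeros((batch, clip_l_dim))
clip_g_proj = np.zeros((batch, clip_g_dim))
clip_l_inter = np.zeros((batch, num_tokens, clip_l_dim))
clip_g_inter = np.zeros((batch, num_tokens, clip_g_dim))

# CLIPConcatenate.call, in NumPy:
pooled = np.concatenate([clip_l_proj, clip_g_proj], axis=-1)
embeddings = np.concatenate([clip_l_inter, clip_g_inter], axis=-1)
padding = t5_dim - (clip_l_dim + clip_g_dim)  # t5_hidden_dim - clip_hidden_dim
embeddings = np.pad(embeddings, [[0, 0], [0, 0], [0, padding]])

print(pooled.shape)      # (2, 2048)
print(embeddings.shape)  # (2, 77, 4096), ready to concat with T5 on axis -2
```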
@@ -490,9 +524,7 @@
 
     def encode_image_step(self, images):
         latents = self.vae.encode(images)
-        return ops.multiply(
-            ops.subtract(latents, self.vae.shift), self.vae.scale
-        )
+        return self.image_rescaling(latents)
 
     def add_noise_step(self, latents, noises, step, num_steps):
         return self.scheduler.add_noise(latents, noises, step, num_steps)
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
@@ -0,0 +1,171 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_to_image import ImageToImage
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+    StableDiffusion3TextToImagePreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.StableDiffusion3ImageToImage")
+class StableDiffusion3ImageToImage(ImageToImage):
+    """An end-to-end Stable Diffusion 3 model for image-to-image generation.
+
+    This model has a `generate()` method, which generates images based
+    on a combination of a reference image and a text prompt.
+
+    Args:
+        backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+        preprocessor: A
+            `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
+        "stable_diffusion_3_medium", height=512, width=512
+    )
+    image_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        }
+    )
+
+    # Generate with batched prompts.
+    image_to_image.generate(
+        {
+            "images": np.ones((2, 512, 512, 3), dtype="float32"),
+            "prompts": ["cute wallpaper art of a cat", "cute wallpaper art of a dog"],
+        }
+    )
+
+    # Generate with different `num_steps`, `guidance_scale` and `strength`.
+    image_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        },
+        num_steps=50,
+        guidance_scale=5.0,
+        strength=0.6,
+    )
+
+    # Generate with `negative_prompts`.
+    image_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "negative_prompts": "green color",
+        }
+    )
+    ```
+    """
+
+    backbone_cls = StableDiffusion3Backbone
+    preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for "
+            "`StableDiffusion3ImageToImage`."
+        )
+
+    def generate_step(
+        self,
+        images,
+        noises,
+        token_ids,
+        starting_step,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for a batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            images: A (batch_size, image_height, image_width, 3) tensor
+                containing the reference images.
+            noises: A (batch_size, latent_height, latent_width, channels) tensor
+                containing the noises to be added to the latents. Typically,
+                this tensor is sampled from the Gaussian distribution.
+            token_ids: A pair of (batch_size, num_tokens) tensors containing the
+                tokens based on the input prompts and negative prompts.
+            starting_step: int. The number of the starting diffusion step.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). A higher scale encourages
+                generating images more closely linked to the prompts, usually
+                at the expense of lower image quality.
+        """
+        token_ids, negative_token_ids = token_ids
+
+        # Encode images.
+        latents = self.backbone.encode_image_step(images)
+
+        # Add noises to latents.
+        latents = self.backbone.add_noise_step(
+            latents, noises, starting_step, num_steps
+        )
+
+        # Encode inputs.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            return self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+        latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+        # Decode.
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        inputs,
+        num_steps=50,
+        guidance_scale=7.0,
+        strength=0.8,
+        seed=None,
+    ):
+        return super().generate(
+            inputs,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            seed=seed,
+        )
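
`generate_step` above starts denoising from `starting_step` rather than step 0; `strength` (exposed on `generate`) decides that offset, and the computation lives in the base `ImageToImage` task (`keras_hub/src/models/image_to_image.py`, also added in this release but not shown here). A hedged sketch of the relationship, assuming the conventional `num_steps * (1 - strength)` mapping:

```python
# Hypothetical helper illustrating how `strength` picks the starting step;
# the authoritative computation is in the ImageToImage base class.
def starting_step_for(num_steps, strength):
    # strength=1.0 -> start from pure noise (step 0, full denoising);
    # strength=0.0 -> start at the end (no denoising, image kept as-is).
    return int(num_steps * (1.0 - strength))

print(starting_step_for(num_steps=50, strength=0.8))  # 10
print(starting_step_for(num_steps=50, strength=0.6))  # 20
```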
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
@@ -0,0 +1,194 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.inpaint import Inpaint
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+    StableDiffusion3TextToImagePreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.StableDiffusion3Inpaint")
+class StableDiffusion3Inpaint(Inpaint):
+    """An end-to-end Stable Diffusion 3 model for inpaint generation.
+
+    This model has a `generate()` method, which generates images based
+    on a combination of a reference image, a mask, and a text prompt.
+
+    Args:
+        backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+        preprocessor: A
+            `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    reference_image = np.ones((1024, 1024, 3), dtype="float32")
+    reference_mask = np.ones((1024, 1024), dtype="float32")
+    inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
+        "stable_diffusion_3_medium", height=512, width=512
+    )
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+
+    # Generate with batched prompts.
+    reference_images = np.ones((2, 512, 512, 3), dtype="float32")
+    reference_masks = np.ones((2, 512, 512), dtype="float32")
+    inpaint.generate(
+        reference_images,
+        reference_masks,
+        ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+    )
+
+    # Generate with different `num_steps`, `guidance_scale` and `strength`.
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        num_steps=50,
+        guidance_scale=5.0,
+        strength=0.6,
+    )
+    ```
+    """
+
+    backbone_cls = StableDiffusion3Backbone
+    preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for "
+            "`StableDiffusion3Inpaint`."
+        )
+
+    def generate_step(
+        self,
+        images,
+        masks,
+        noises,
+        token_ids,
+        starting_step,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for a batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            images: A (batch_size, image_height, image_width, 3) tensor
+                containing the reference images.
+            masks: A (batch_size, image_height, image_width) tensor
+                containing the reference masks.
+            noises: A (batch_size, latent_height, latent_width, channels) tensor
+                containing the noises to be added to the latents. Typically,
+                this tensor is sampled from the Gaussian distribution.
+            token_ids: A pair of (batch_size, num_tokens) tensors containing the
+                tokens based on the input prompts and negative prompts.
+            starting_step: int. The number of the starting diffusion step.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). A higher scale encourages
+                generating images more closely linked to the prompts, usually
+                at the expense of lower image quality.
+        """
+        token_ids, negative_token_ids = token_ids
+
+        # Get masked images.
+        masks = ops.cast(ops.expand_dims(masks, axis=-1) > 0.5, images.dtype)
+        masks_latent_size = ops.image.resize(
+            masks,
+            (self.backbone.latent_shape[1], self.backbone.latent_shape[2]),
+            interpolation="nearest",
+        )
+
+        # Encode images.
+        image_latents = self.backbone.encode_image_step(images)
+
+        # Add noises to latents.
+        latents = self.backbone.add_noise_step(
+            image_latents, noises, starting_step, num_steps
+        )
+
+        # Encode inputs.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            latents = self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+            # Compute the previous latents x_t -> x_t-1.
+            def true_fn():
+                next_step = ops.add(step, 1)
+                return self.backbone.add_noise_step(
+                    image_latents, noises, next_step, num_steps
+                )
+
+            init_latents = ops.cond(
+                step < ops.subtract(num_steps, 1),
+                true_fn,
+                lambda: ops.cast(image_latents, noises.dtype),
+            )
+            latents = ops.add(
+                ops.multiply(
+                    ops.subtract(1.0, masks_latent_size), init_latents
+                ),
+                ops.multiply(masks_latent_size, latents),
+            )
+            return latents
+
+        latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+        # Decode.
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        inputs,
+        num_steps=50,
+        guidance_scale=7.0,
+        strength=0.6,
+        seed=None,
+    ):
+        return super().generate(
+            inputs,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            seed=seed,
+        )
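
The key inpainting detail in `body_fun` above is the per-step blend: outside the mask the latents are reset to a re-noised copy of the original image (matched to the next step's noise level), and only inside the mask do the freshly denoised latents survive. A NumPy sketch of that blend:

```python
import numpy as np

mask = np.zeros((1, 4, 4, 1), dtype="float32")
mask[:, 1:3, 1:3, :] = 1.0  # inpaint only the 2x2 center region

init_latents = np.zeros((1, 4, 4, 16), dtype="float32")     # re-noised original
denoised_latents = np.ones((1, 4, 4, 16), dtype="float32")  # model output

# latents = (1 - mask) * init_latents + mask * denoised_latents
latents = (1.0 - mask) * init_latents + mask * denoised_latents
print(latents[0, :, :, 0])  # 1.0 inside the masked center, 0.0 elsewhere
```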
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
@@ -13,6 +13,6 @@ backbone_presets = {
             "path": "stablediffusion3",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/3",
+        "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/4",
    }
}
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
@@ -44,6 +44,14 @@ class StableDiffusion3TextToImage(TextToImage):
        num_steps=50,
        guidance_scale=5.0,
    )
+
+    # Generate with `negative_prompts`.
+    text_to_image.generate(
+        {
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "negative_prompts": "green color",
+        }
+    )
     ```
     """
 
@@ -79,7 +87,6 @@ class StableDiffusion3TextToImage(TextToImage):
         self,
         latents,
         token_ids,
-        negative_token_ids,
         num_steps,
         guidance_scale,
     ):
@@ -92,10 +99,8 @@ class StableDiffusion3TextToImage(TextToImage):
             latents: A (batch_size, height, width, channels) tensor
                 containing the latents to start generation from. Typically, this
                 tensor is sampled from the Gaussian distribution.
-            token_ids: A (batch_size, num_tokens) tensor containing the
-                tokens based on the input prompts.
-            negative_token_ids: A (batch_size, num_tokens) tensor
-                containing the negative tokens based on the input prompts.
+            token_ids: A pair of (batch_size, num_tokens) tensors containing the
+                tokens based on the input prompts and negative prompts.
             num_steps: int. The number of diffusion steps to take.
             guidance_scale: float. The classifier free guidance scale defined in
                 [Classifier-Free Diffusion Guidance](
@@ -103,7 +108,9 @@ class StableDiffusion3TextToImage(TextToImage):
                 generate images that are closely linked to prompts, usually at
                 the expense of lower image quality.
         """
-        # Encode inputs.
+        token_ids, negative_token_ids = token_ids
+
+        # Encode prompts.
         embeddings = self.backbone.encode_text_step(
             token_ids, negative_token_ids
         )
@@ -126,14 +133,12 @@ class StableDiffusion3TextToImage(TextToImage):
     def generate(
         self,
         inputs,
-        negative_inputs=None,
         num_steps=28,
         guidance_scale=7.0,
         seed=None,
     ):
         return super().generate(
             inputs,
-            negative_inputs=negative_inputs,
             num_steps=num_steps,
             guidance_scale=guidance_scale,
             seed=seed,
keras_hub/src/models/task.py
@@ -339,7 +339,7 @@ class Task(PipelineModel):
                 add_layer(layer, info)
             elif isinstance(layer, ImageConverter):
                 info = "Image size: "
-                info += highlight_shape(layer.image_size())
+                info += highlight_shape(layer.image_size)
                 add_layer(layer, info)
             elif isinstance(layer, AudioConverter):
                 info = "Audio shape: "
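
The one-line task.py change tracks `image_converter.py` (also touched in this diff), where `image_size` became a property rather than a method, so call sites drop the parentheses. A minimal sketch of the call-site difference, assuming the public `keras_hub.layers.ImageConverter` constructor accepts `image_size`:

```python
import keras_hub

# Assumed constructor; only `image_size` matters for this example.
converter = keras_hub.layers.ImageConverter(image_size=(512, 512))

# Before this release: highlight_shape(layer.image_size())
# Now: plain attribute access, no call.
print(converter.image_size)  # (512, 512)
```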