keras-hub-nightly 0.16.1.dev202410080341__py3-none-any.whl → 0.16.1.dev202410100339__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +11 -0
- keras_hub/src/layers/preprocessing/image_converter.py +2 -1
- keras_hub/src/models/image_to_image.py +411 -0
- keras_hub/src/models/inpaint.py +513 -0
- keras_hub/src/models/mix_transformer/__init__.py +12 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +4 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mix_transformer/mix_transformer_image_converter.py +8 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +9 -5
- keras_hub/src/models/mix_transformer/mix_transformer_presets.py +151 -0
- keras_hub/src/models/preprocessor.py +4 -4
- keras_hub/src/models/stable_diffusion_3/mmdit.py +308 -177
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +87 -55
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +171 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +194 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +13 -8
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/models/text_to_image.py +89 -36
- keras_hub/src/tests/test_case.py +3 -1
- keras_hub/src/tokenizers/tokenizer.py +7 -7
- keras_hub/src/utils/preset_utils.py +7 -7
- keras_hub/src/utils/timm/preset_loader.py +1 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/RECORD +29 -22
- {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py CHANGED

````diff
@@ -51,11 +51,52 @@ class CLIPProjection(layers.Layer):
         return (inputs_shape[0], self.hidden_dim)
 
 
-class …
-    def …
-    …
-    …
+class CLIPConcatenate(layers.Layer):
+    def call(
+        self,
+        clip_l_projection,
+        clip_g_projection,
+        clip_l_intermediate_output,
+        clip_g_intermediate_output,
+        padding,
+    ):
+        pooled_embeddings = ops.concatenate(
+            [clip_l_projection, clip_g_projection], axis=-1
+        )
+        embeddings = ops.concatenate(
+            [clip_l_intermediate_output, clip_g_intermediate_output], axis=-1
+        )
+        embeddings = ops.pad(embeddings, [[0, 0], [0, 0], [0, padding]])
+        return pooled_embeddings, embeddings
+
+
+class ImageRescaling(layers.Rescaling):
+    """Rescales inputs from image space to latent space.
+
+    The rescaling is performed using the formula: `(inputs - offset) * scale`.
+    """
+
+    def call(self, inputs):
+        dtype = self.compute_dtype
+        scale = self.backend.cast(self.scale, dtype)
+        offset = self.backend.cast(self.offset, dtype)
+        return (self.backend.cast(inputs, dtype) - offset) * scale
+
 
+class LatentRescaling(layers.Rescaling):
+    """Rescales inputs from latent space to image space.
+
+    The rescaling is performed using the formula: `inputs / scale + offset`.
+    """
+
+    def call(self, inputs):
+        dtype = self.compute_dtype
+        scale = self.backend.cast(self.scale, dtype)
+        offset = self.backend.cast(self.offset, dtype)
+        return (self.backend.cast(inputs, dtype) / scale) + offset
+
+
+class ClassifierFreeGuidanceConcatenate(layers.Layer):
     def call(
         self,
         latents,
````
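The two `Rescaling` subclasses added above are exact inverses of each other; later in this diff, `encode_image_step` applies `ImageRescaling` to the VAE-encoded latents, and `LatentRescaling` presumably undoes it on the decode side. A minimal sketch of the arithmetic, with placeholder `scale`/`shift` values (the real layers take them from `self.vae.scale` and `self.vae.shift`):

```python
# Standalone sketch of the two rescaling formulas, using plain Python floats.
# The scale/shift values below are placeholders for illustration; the actual
# layers receive them from the VAE (`self.vae.scale`, `self.vae.shift`).
scale, shift = 1.5, 0.06  # hypothetical values

def image_to_latent(x):
    return (x - shift) * scale  # ImageRescaling: (inputs - offset) * scale

def latent_to_image(z):
    return z / scale + shift    # LatentRescaling: inputs / scale + offset

x = 0.25
assert abs(latent_to_image(image_to_latent(x)) - x) < 1e-9  # round-trips exactly
```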
````diff
@@ -66,20 +107,16 @@ class ClassifierFreeGuidanceConcatenate(layers.Layer):
         timestep,
     ):
         timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
-        latents = ops.concatenate([latents, latents], axis=…
+        latents = ops.concatenate([latents, latents], axis=0)
         contexts = ops.concatenate(
-            [positive_contexts, negative_contexts], axis=…
+            [positive_contexts, negative_contexts], axis=0
         )
         pooled_projections = ops.concatenate(
-            [positive_pooled_projections, negative_pooled_projections],
-            axis=self.axis,
+            [positive_pooled_projections, negative_pooled_projections], axis=0
         )
-        timesteps = ops.concatenate([timestep, timestep], axis=…
+        timesteps = ops.concatenate([timestep, timestep], axis=0)
         return latents, contexts, pooled_projections, timesteps
 
-    def get_config(self):
-        return super().get_config()
-
 
 class ClassifierFreeGuidance(layers.Layer):
     """Perform classifier free guidance.
@@ -100,9 +137,6 @@ class ClassifierFreeGuidance(layers.Layer):
     - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
     def call(self, inputs, guidance_scale):
         positive_noise, negative_noise = ops.split(inputs, 2, axis=0)
         return ops.add(
@@ -112,9 +146,6 @@ class ClassifierFreeGuidance(layers.Layer):
             ),
         )
 
-    def get_config(self):
-        return super().get_config()
-
     def compute_output_shape(self, inputs_shape):
         outputs_shape = list(inputs_shape)
         if outputs_shape[0] is not None:
````
````diff
@@ -142,16 +173,10 @@ class EulerStep(layers.Layer):
     https://arxiv.org/abs/2206.00364).
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
     def call(self, latents, noise_residual, sigma, sigma_next):
         sigma_diff = ops.subtract(sigma_next, sigma)
         return ops.add(latents, ops.multiply(sigma_diff, noise_residual))
 
-    def get_config(self):
-        return super().get_config()
-
     def compute_output_shape(self, latents_shape):
         return latents_shape
 
````
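`EulerStep` is equally small once unrolled: it advances the latents by one explicit Euler step along the predicted noise residual, scaled by the sigma increment between two scheduler steps. In scalar form:

```python
# One explicit Euler step, as in `EulerStep.call`: the latent moves along the
# predicted residual (velocity) by the sigma increment.
def euler_step(latents, noise_residual, sigma, sigma_next):
    return latents + (sigma_next - sigma) * noise_residual

print(euler_step(latents=1.0, noise_residual=-0.5, sigma=1.0, sigma_next=0.8))  # 1.1
```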
````diff
@@ -272,12 +297,13 @@ class StableDiffusion3Backbone(Backbone):
         self.clip_l_projection = CLIPProjection(
             clip_l.hidden_dim, dtype=dtype, name="clip_l_projection"
         )
-        self.clip_l_projection.build([None, clip_l.hidden_dim], None)
         self.clip_g = clip_g
         self.clip_g_projection = CLIPProjection(
             clip_g.hidden_dim, dtype=dtype, name="clip_g_projection"
         )
-        self.…
+        self.clip_concatenate = CLIPConcatenate(
+            dtype=dtype, name="clip_concatenate"
+        )
         self.t5 = t5
         self.diffuser = MMDiT(
             mmdit_patch_size,
@@ -293,6 +319,12 @@ class StableDiffusion3Backbone(Backbone):
             name="diffuser",
         )
         self.vae = vae
+        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
+            dtype=dtype, name="classifier_free_guidance_concat"
+        )
+        self.cfg = ClassifierFreeGuidance(
+            dtype=dtype, name="classifier_free_guidance"
+        )
         # Set `dtype="float32"` to ensure the high precision for the noise
         # residual.
         self.scheduler = FlowMatchEulerDiscreteScheduler(
@@ -301,17 +333,17 @@ class StableDiffusion3Backbone(Backbone):
             dtype="float32",
             name="scheduler",
         )
-        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
-            dtype="float32", name="classifier_free_guidance_concat"
-        )
-        self.cfg = ClassifierFreeGuidance(
-            dtype="float32", name="classifier_free_guidance"
-        )
         self.euler_step = EulerStep(dtype="float32", name="euler_step")
-        self.…
-            scale=…
+        self.image_rescaling = ImageRescaling(
+            scale=self.vae.scale,
             offset=self.vae.shift,
-            dtype=…
+            dtype=dtype,
+            name="image_rescaling",
+        )
+        self.latent_rescaling = LatentRescaling(
+            scale=self.vae.scale,
+            offset=self.vae.shift,
+            dtype=dtype,
             name="latent_rescaling",
         )
 
@@ -440,8 +472,12 @@ class StableDiffusion3Backbone(Backbone):
         t5_hidden_dim = self.t5_hidden_dim
 
         def encode(token_ids):
-            clip_l_outputs = self.clip_l(…
-            …
+            clip_l_outputs = self.clip_l(
+                {"token_ids": token_ids["clip_l"]}, training=False
+            )
+            clip_g_outputs = self.clip_g(
+                {"token_ids": token_ids["clip_g"]}, training=False
+            )
             clip_l_projection = self.clip_l_projection(
                 clip_l_outputs["sequence_output"],
                 token_ids["clip_l"],
@@ -452,23 +488,21 @@ class StableDiffusion3Backbone(Backbone):
                 token_ids["clip_g"],
                 training=False,
             )
-            pooled_embeddings = …
-            …
-            …
-            …
-            …
-            …
-                    clip_l_outputs["intermediate_output"],
-                    clip_g_outputs["intermediate_output"],
-                ],
-                axis=-1,
-            )
-            embeddings = ops.pad(
-                embeddings,
-                [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]],
+            pooled_embeddings, embeddings = self.clip_concatenate(
+                clip_l_projection,
+                clip_g_projection,
+                clip_l_outputs["intermediate_output"],
+                clip_g_outputs["intermediate_output"],
+                padding=t5_hidden_dim - clip_hidden_dim,
             )
             if self.t5 is not None:
-                t5_outputs = self.t5(…
+                t5_outputs = self.t5(
+                    {
+                        "token_ids": token_ids["t5"],
+                        "padding_mask": ops.ones_like(token_ids["t5"]),
+                    },
+                    training=False,
+                )
                 embeddings = ops.concatenate([embeddings, t5_outputs], axis=-2)
             else:
                 padded_size = self.clip_l.max_sequence_length
@@ -490,9 +524,7 @@ class StableDiffusion3Backbone(Backbone):
 
     def encode_image_step(self, images):
         latents = self.vae.encode(images)
-        return …
-            ops.subtract(latents, self.vae.shift), self.vae.scale
-        )
+        return self.image_rescaling(latents)
 
     def add_noise_step(self, latents, noises, step, num_steps):
         return self.scheduler.add_noise(latents, noises, step, num_steps)
````
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py ADDED

````diff
@@ -0,0 +1,171 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_to_image import ImageToImage
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+    StableDiffusion3TextToImagePreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.StableDiffusion3ImageToImage")
+class StableDiffusion3ImageToImage(ImageToImage):
+    """An end-to-end Stable Diffusion 3 model for image-to-image generation.
+
+    This model has a `generate()` method, which generates images based
+    on a combination of a reference image and a text prompt.
+
+    Args:
+        backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+        preprocessor: A
+            `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
+        "stable_diffusion_3_medium", height=512, width=512
+    )
+    image_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        }
+    )
+
+    # Generate with batched prompts.
+    image_to_image.generate(
+        {
+            "images": np.ones((2, 512, 512, 3), dtype="float32"),
+            "prompts": ["cute wallpaper art of a cat", "cute wallpaper art of a dog"],
+        }
+    )
+
+    # Generate with different `num_steps`, `guidance_scale` and `strength`.
+    image_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        }
+        num_steps=50,
+        guidance_scale=5.0,
+        strength=0.6,
+    )
+
+    # Generate with `negative_prompts`.
+    text_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "negative_prompts": "green color",
+        }
+    )
+    ```
+    """
+
+    backbone_cls = StableDiffusion3Backbone
+    preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for "
+            "`StableDiffusion3ImageToImage`."
+        )
+
+    def generate_step(
+        self,
+        images,
+        noises,
+        token_ids,
+        starting_step,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for batched of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            images: A (batch_size, image_height, image_width, 3) tensor
+                containing the reference images.
+            noises: A (batch_size, latent_height, latent_width, channels) tensor
+                containing the noises to be added to the latents. Typically,
+                this tensor is sampled from the Gaussian distribution.
+            token_ids: A pair of (batch_size, num_tokens) tensor containing the
+                tokens based on the input prompts and negative prompts.
+            starting_step: int. The number of the starting diffusion step.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). Higher scale encourages to
+                generate images that are closely linked to prompts, usually at
+                the expense of lower image quality.
+        """
+        token_ids, negative_token_ids = token_ids
+
+        # Encode images.
+        latents = self.backbone.encode_image_step(images)
+
+        # Add noises to latents.
+        latents = self.backbone.add_noise_step(
+            latents, noises, starting_step, num_steps
+        )
+
+        # Encode inputs.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            return self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+        latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+        # Decode.
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        inputs,
+        num_steps=50,
+        guidance_scale=7.0,
+        strength=0.8,
+        seed=None,
+    ):
+        return super().generate(
+            inputs,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            seed=seed,
+        )
````
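The new task wires image-to-image generation on top of the backbone steps shown earlier: encode the reference image, jump into the schedule at `starting_step`, then denoise to the end. The mapping from `strength` to `starting_step` lives in the base `ImageToImage` class (`keras_hub/src/models/image_to_image.py`, not shown in this excerpt); the sketch below assumes the conventional img2img mapping and is illustrative only.

```python
# Hypothetical illustration of how `strength` typically selects the starting step
# in img2img pipelines. This is an assumption; the real mapping is implemented in
# the `ImageToImage` base class, which this diff excerpt does not include.
num_steps = 50
strength = 0.8
starting_step = int(num_steps * (1.0 - strength))
print(starting_step)  # 10 -> the denoising loop then runs steps 10..49
```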
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py ADDED

````diff
@@ -0,0 +1,194 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.inpaint import Inpaint
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+    StableDiffusion3TextToImagePreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.StableDiffusion3Inpaint")
+class StableDiffusion3Inpaint(Inpaint):
+    """An end-to-end Stable Diffusion 3 model for inpaint generation.
+
+    This model has a `generate()` method, which generates images based
+    on a combination of a reference image, mask and a text prompt.
+
+    Args:
+        backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+        preprocessor: A
+            `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    reference_image = np.ones((1024, 1024, 3), dtype="float32")
+    reference_mask = np.ones((1024, 1024), dtype="float32")
+    inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
+        "stable_diffusion_3_medium", height=512, width=512
+    )
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+
+    # Generate with batched prompts.
+    reference_images = np.ones((2, 512, 512, 3), dtype="float32")
+    reference_mask = np.ones((2, 1024, 1024), dtype="float32")
+    inpaint.generate(
+        reference_images,
+        reference_mask,
+        ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+    )
+
+    # Generate with different `num_steps`, `guidance_scale` and `strength`.
+    inpaint.generate(
+        reference_image,
+        reference_mask,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        num_steps=50,
+        guidance_scale=5.0,
+        strength=0.6,
+    )
+    ```
+    """
+
+    backbone_cls = StableDiffusion3Backbone
+    preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for "
+            "`StableDiffusion3Inpaint`."
+        )
+
+    def generate_step(
+        self,
+        images,
+        masks,
+        noises,
+        token_ids,
+        starting_step,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for batched of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            images: A (batch_size, image_height, image_width, 3) tensor
+                containing the reference images.
+            masks: A (batch_size, image_height, image_width) tensor
+                containing the reference masks.
+            noises: A (batch_size, latent_height, latent_width, channels) tensor
+                containing the noises to be added to the latents. Typically,
+                this tensor is sampled from the Gaussian distribution.
+            token_ids: A pair of (batch_size, num_tokens) tensor containing the
+                tokens based on the input prompts and negative prompts.
+            starting_step: int. The number of the starting diffusion step.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). Higher scale encourages to
+                generate images that are closely linked to prompts, usually at
+                the expense of lower image quality.
+        """
+        token_ids, negative_token_ids = token_ids
+
+        # Get masked images.
+        masks = ops.cast(ops.expand_dims(masks, axis=-1) > 0.5, images.dtype)
+        masks_latent_size = ops.image.resize(
+            masks,
+            (self.backbone.latent_shape[1], self.backbone.latent_shape[2]),
+            interpolation="nearest",
+        )
+
+        # Encode images.
+        image_latents = self.backbone.encode_image_step(images)
+
+        # Add noises to latents.
+        latents = self.backbone.add_noise_step(
+            image_latents, noises, starting_step, num_steps
+        )
+
+        # Encode inputs.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            latents = self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+            # Compute the previous latents x_t -> x_t-1.
+            def true_fn():
+                next_step = ops.add(step, 1)
+                return self.backbone.add_noise_step(
+                    image_latents, noises, next_step, num_steps
+                )
+
+            init_latents = ops.cond(
+                step < ops.subtract(num_steps, 1),
+                true_fn,
+                lambda: ops.cast(image_latents, noises.dtype),
+            )
+            latents = ops.add(
+                ops.multiply(
+                    ops.subtract(1.0, masks_latent_size), init_latents
+                ),
+                ops.multiply(masks_latent_size, latents),
+            )
+            return latents
+
+        latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+        # Decode.
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        inputs,
+        num_steps=50,
+        guidance_scale=7.0,
+        strength=0.6,
+        seed=None,
+    ):
+        return super().generate(
+            inputs,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            seed=seed,
+        )
````
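The inpainting variant differs from image-to-image in one place: inside `body_fun`, after each denoise step the latents outside the mask are snapped back to a re-noised copy of the reference image, so only the masked region is actually repainted. In NumPy terms the blending at the end of the loop body is:

```python
import numpy as np

# Shapes are illustrative only. `init_latents` stands in for the re-noised
# reference latents (`true_fn` above), `denoised` for the output of `denoise_step`.
mask = np.zeros((1, 8, 8, 1), dtype="float32")
mask[:, 2:6, 2:6, :] = 1.0  # region to repaint
init_latents = np.random.randn(1, 8, 8, 16).astype("float32")
denoised = np.random.randn(1, 8, 8, 16).astype("float32")

blended = (1.0 - mask) * init_latents + mask * denoised  # keep outside, repaint inside
```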
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py CHANGED

````diff
@@ -13,6 +13,6 @@ backbone_presets = {
             "path": "stablediffusion3",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/…
+        "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/4",
     }
 }
````
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py CHANGED

````diff
@@ -44,6 +44,14 @@ class StableDiffusion3TextToImage(TextToImage):
         num_steps=50,
         guidance_scale=5.0,
     )
+
+    # Generate with `negative_prompts`.
+    text_to_image.generate(
+        {
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "negative_prompts": "green color",
+        }
+    )
     ```
     """
 
@@ -79,7 +87,6 @@ class StableDiffusion3TextToImage(TextToImage):
         self,
         latents,
         token_ids,
-        negative_token_ids,
         num_steps,
         guidance_scale,
     ):
@@ -92,10 +99,8 @@ class StableDiffusion3TextToImage(TextToImage):
             latents: A (batch_size, height, width, channels) tensor
                 containing the latents to start generation from. Typically, this
                 tensor is sampled from the Gaussian distribution.
-            token_ids: A (batch_size, num_tokens) tensor containing the
-                tokens based on the input prompts.
-            negative_token_ids: A (batch_size, num_tokens) tensor
-                containing the negative tokens based on the input prompts.
+            token_ids: A pair of (batch_size, num_tokens) tensor containing the
+                tokens based on the input prompts and negative prompts.
             num_steps: int. The number of diffusion steps to take.
             guidance_scale: float. The classifier free guidance scale defined in
                 [Classifier-Free Diffusion Guidance](
@@ -103,7 +108,9 @@ class StableDiffusion3TextToImage(TextToImage):
                 generate images that are closely linked to prompts, usually at
                 the expense of lower image quality.
         """
-
+        token_ids, negative_token_ids = token_ids
+
+        # Encode prompts.
         embeddings = self.backbone.encode_text_step(
             token_ids, negative_token_ids
         )
@@ -126,14 +133,12 @@ class StableDiffusion3TextToImage(TextToImage):
     def generate(
         self,
         inputs,
-        negative_inputs=None,
         num_steps=28,
         guidance_scale=7.0,
         seed=None,
     ):
         return super().generate(
             inputs,
-            negative_inputs=negative_inputs,
             num_steps=num_steps,
             guidance_scale=guidance_scale,
             seed=seed,
````
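The net effect of the text-to-image changes is an API shift: `generate()` no longer takes a `negative_inputs` keyword, and `generate_step` now receives `token_ids` as a `(token_ids, negative_token_ids)` pair. Negative prompts travel inside the input dict instead, as the updated docstring example shows. A minimal sketch, assuming `text_to_image` is an already-constructed `StableDiffusion3TextToImage`:

```python
# Before (0.16.1.dev202410080341): a separate keyword argument.
# text_to_image.generate(
#     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
#     negative_inputs="green color",
# )

# After (0.16.1.dev202410100339): negative prompts ride along in the input dict.
text_to_image.generate(
    {
        "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
        "negative_prompts": "green color",
    }
)
```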
keras_hub/src/models/task.py CHANGED

````diff
@@ -339,7 +339,7 @@ class Task(PipelineModel):
             add_layer(layer, info)
         elif isinstance(layer, ImageConverter):
             info = "Image size: "
-            info += highlight_shape(layer.image_size…
+            info += highlight_shape(layer.image_size)
             add_layer(layer, info)
         elif isinstance(layer, AudioConverter):
             info = "Audio shape: "
````