keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +16 -0
- keras_hub/api/tokenizers/__init__.py +1 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
- keras_hub/src/models/clip/clip_preprocessor.py +147 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
- keras_hub/src/models/densenet/__init__.py +6 -0
- keras_hub/src/models/densenet/densenet_backbone.py +11 -8
- keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
- keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
- keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
- keras_hub/src/models/densenet/densenet_presets.py +56 -0
- keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
- keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
- keras_hub/src/models/text_to_image.py +295 -0
- keras_hub/src/utils/timm/convert_densenet.py +107 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +31 -23
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
- /keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
ADDED
@@ -0,0 +1,151 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+    StableDiffusion3TextToImagePreprocessor,
+)
+from keras_hub.src.models.text_to_image import TextToImage
+
+
+@keras_hub_export("keras_hub.models.StableDiffusion3TextToImage")
+class StableDiffusion3TextToImage(TextToImage):
+    """An end-to-end Stable Diffusion 3 model for text-to-image generation.
+
+    This model has a `generate()` method, which generates image based on a
+    prompt.
+
+    Args:
+        backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+        preprocessor: A
+            `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
+        "stable_diffusion_3_medium", height=512, width=512
+    )
+    text_to_image.generate(
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+    )
+
+    # Generate with batched prompts.
+    text_to_image.generate(
+        ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+    )
+
+    # Generate with different `num_steps` and `classifier_free_guidance_scale`.
+    text_to_image.generate(
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        num_steps=50,
+        classifier_free_guidance_scale=5.0,
+    )
+    ```
+    """
+
+    backbone_cls = StableDiffusion3Backbone
+    preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for "
+            "`StableDiffusion3TextToImage`."
+        )
+
+    def generate_step(
+        self,
+        latents,
+        token_ids,
+        negative_token_ids,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for batched of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            latents: A (batch_size, height, width, channels) tensor
+                containing the latents to start generation from. Typically, this
+                tensor is sampled from the Gaussian distribution.
+            token_ids: A (batch_size, num_tokens) tensor containing the
+                tokens based on the input prompts.
+            negative_token_ids: A (batch_size, num_tokens) tensor
+                containing the negative tokens based on the input prompts.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). Higher scale encourages to
+                generate images that are closely linked to prompts, usually at
+                the expense of lower image quality.
+        """
+        # Encode inputs.
+        embeddings = self.backbone.encode_step(token_ids, negative_token_ids)
+
+        # Denoise.
+        def body_fun(step, latents):
+            return self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+        latents = ops.fori_loop(0, num_steps, body_fun, latents)
+
+        # Decode.
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        inputs,
+        negative_inputs=None,
+        num_steps=28,
+        guidance_scale=7.0,
+        seed=None,
+    ):
+        return super().generate(
+            inputs,
+            negative_inputs=negative_inputs,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            seed=seed,
+        )
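The `generate_step()` above threads the latents through the backbone's `denoise_step()` with `keras.ops.fori_loop`. As a rough illustration of that loop structure only, here is a self-contained sketch that swaps in made-up stand-ins for the real `StableDiffusion3Backbone` methods; the shapes, step count, and stand-in math are invented for the example.

```python
import numpy as np
from keras import ops


# Invented stand-ins; the real denoise/decode steps live on
# StableDiffusion3Backbone and are not reproduced here.
def fake_denoise_step(latents, embeddings, step, num_steps, guidance_scale):
    # Nudge the latents toward the "embeddings" a little each step.
    return latents + (embeddings - latents) / float(num_steps)


def fake_decode_step(latents):
    return ops.clip(latents, -1.0, 1.0)


latents = ops.convert_to_tensor(
    np.random.normal(size=(1, 8, 8, 16)).astype("float32")
)
embeddings = ops.zeros((1, 8, 8, 16), dtype="float32")
num_steps = 4


def body_fun(step, latents):
    return fake_denoise_step(
        latents, embeddings, step, num_steps, guidance_scale=7.0
    )


# Same control-flow primitive used by generate_step(): run the body
# num_steps times, carrying the latents through each iteration.
latents = ops.fori_loop(0, num_steps, body_fun, latents)
images = fake_decode_step(latents)
print(images.shape)  # (1, 8, 8, 16)
```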
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py
ADDED
@@ -0,0 +1,77 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras import layers
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.preprocessor import Preprocessor
+
+
+@keras_hub_export("keras_hub.models.StableDiffusion3TextToImagePreprocessor")
+class StableDiffusion3TextToImagePreprocessor(Preprocessor):
+    """Stable Diffusion 3 text-to-image model preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_hub.models.StableDiffusion3TextToImage`.
+
+    For use with generation, the layer exposes one methods
+    `generate_preprocess()`.
+
+    Args:
+        clip_l_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance.
+        clip_g_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance.
+        t5_preprocessor: A optional `keras_hub.models.T5Preprocessor` instance.
+    """
+
+    def __init__(
+        self,
+        clip_l_preprocessor,
+        clip_g_preprocessor,
+        t5_preprocessor=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.clip_l_preprocessor = clip_l_preprocessor
+        self.clip_g_preprocessor = clip_g_preprocessor
+        self.t5_preprocessor = t5_preprocessor
+
+    def build(self, input_shape):
+        self.built = True
+
+    def generate_preprocess(self, x):
+        token_ids = {}
+        token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
+        token_ids["clip_g"] = self.clip_g_preprocessor(x)["token_ids"]
+        if self.t5_preprocessor is not None:
+            token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
+        return token_ids
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "clip_l_preprocessor": layers.serialize(
+                    self.clip_l_preprocessor
+                ),
+                "clip_g_preprocessor": layers.serialize(
+                    self.clip_g_preprocessor
+                ),
+                "t5_preprocessor": layers.serialize(self.t5_preprocessor),
+            }
+        )
+        return config
+
+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self.clip_l_preprocessor.sequence_length
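For orientation, the sketch below mimics what `generate_preprocess()` returns: a dict of `"token_ids"` keyed by branch (`"clip_l"`, `"clip_g"`, and optionally `"t5"`). The `FakePreprocessor` is a made-up stand-in for `keras_hub.models.CLIPPreprocessor`, which needs a real vocabulary to construct; only the fan-out pattern matches the class above.

```python
from keras import ops


class FakePreprocessor:
    """Made-up stand-in returning {"token_ids": ...} like CLIPPreprocessor."""

    def __init__(self, sequence_length=77):
        self.sequence_length = sequence_length

    def __call__(self, x):
        if isinstance(x, str):
            x = [x]
        # Pretend tokenization: one id per character, padded/truncated.
        token_ids = [
            [ord(c) % 100 for c in s][: self.sequence_length]
            + [0] * max(0, self.sequence_length - len(s))
            for s in x
        ]
        return {"token_ids": ops.convert_to_tensor(token_ids)}


clip_l_preprocessor = FakePreprocessor()
clip_g_preprocessor = FakePreprocessor()
t5_preprocessor = None  # optional branch, skipped when None

token_ids = {}
token_ids["clip_l"] = clip_l_preprocessor(["a photo of a cat"])["token_ids"]
token_ids["clip_g"] = clip_g_preprocessor(["a photo of a cat"])["token_ids"]
if t5_preprocessor is not None:
    token_ids["t5"] = t5_preprocessor(["a photo of a cat"])["token_ids"]

print({k: tuple(v.shape) for k, v in token_ids.items()})
# {'clip_l': (1, 77), 'clip_g': (1, 77)}
```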
keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py}
RENAMED
@@ -20,7 +20,7 @@ from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
 from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer


-class T5XXLTextEncoder(keras.Model):
+class T5Encoder(keras.Model):
     def __init__(
         self,
         vocabulary_size,
@@ -81,10 +81,10 @@ class T5XXLTextEncoder(keras.Model):

         # === Functional Model ===
         encoder_token_id_input = keras.Input(
-            shape=(None,), dtype="int32", name="
+            shape=(None,), dtype="int32", name="token_ids"
         )
         encoder_padding_mask_input = keras.Input(
-            shape=(None,), dtype="int32", name="
+            shape=(None,), dtype="int32", name="padding_mask"
         )
         # Encoder.
         x = self.token_embedding(encoder_token_id_input)
@@ -102,14 +102,14 @@ class T5XXLTextEncoder(keras.Model):
             x, position_bias = output
         x = self.encoder_layer_norm(x)
         x = self.encoder_dropout(x)
-
+        sequence_output = x

         super().__init__(
             {
-                "
-                "
+                "token_ids": encoder_token_id_input,
+                "padding_mask": encoder_padding_mask_input,
             },
-            outputs=
+            outputs=sequence_output,
             **kwargs,
         )

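The substance of this rename is that the encoder now uses the standard KerasHub input names, `"token_ids"` and `"padding_mask"`, and returns the raw sequence output. The toy functional model below only illustrates that dict-input convention; the vocabulary size, embedding width, and masking are placeholders, not the real T5 configuration.

```python
import numpy as np
import keras
from keras import ops

# Placeholder sizes; the real T5Encoder is built from its T5 config.
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(shape=(None,), dtype="int32", name="padding_mask")

x = keras.layers.Embedding(32000, 64)(token_ids)
mask = ops.cast(ops.expand_dims(padding_mask, axis=-1), x.dtype)
sequence_output = x * mask  # zero out padded positions

encoder = keras.Model(
    inputs={"token_ids": token_ids, "padding_mask": padding_mask},
    outputs=sequence_output,
)

out = encoder(
    {
        "token_ids": np.array([[5, 9, 2, 0]], dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 0]], dtype="int32"),
    }
)
print(out.shape)  # (1, 4, 64)
```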
keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py
ADDED
@@ -0,0 +1,333 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+from keras import layers
+from keras import ops
+
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class VAEAttention(layers.Layer):
+    def __init__(self, filters, groups=32, data_format=None, **kwargs):
+        super().__init__(**kwargs)
+        self.filters = filters
+        self.data_format = standardize_data_format(data_format)
+        gn_axis = -1 if self.data_format == "channels_last" else 1
+
+        self.group_norm = layers.GroupNormalization(
+            groups=groups,
+            axis=gn_axis,
+            epsilon=1e-6,
+            dtype="float32",
+            name="group_norm",
+        )
+        self.query_conv2d = layers.Conv2D(
+            filters,
+            1,
+            1,
+            data_format=self.data_format,
+            dtype=self.dtype_policy,
+            name="query_conv2d",
+        )
+        self.key_conv2d = layers.Conv2D(
+            filters,
+            1,
+            1,
+            data_format=self.data_format,
+            dtype=self.dtype_policy,
+            name="key_conv2d",
+        )
+        self.value_conv2d = layers.Conv2D(
+            filters,
+            1,
+            1,
+            data_format=self.data_format,
+            dtype=self.dtype_policy,
+            name="value_conv2d",
+        )
+        self.softmax = layers.Softmax(dtype="float32")
+        self.output_conv2d = layers.Conv2D(
+            filters,
+            1,
+            1,
+            data_format=self.data_format,
+            dtype=self.dtype_policy,
+            name="output_conv2d",
+        )
+
+        self.groups = groups
+        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
+
+    def build(self, input_shape):
+        self.group_norm.build(input_shape)
+        self.query_conv2d.build(input_shape)
+        self.key_conv2d.build(input_shape)
+        self.value_conv2d.build(input_shape)
+        self.output_conv2d.build(input_shape)
+
+    def call(self, inputs, training=None):
+        x = self.group_norm(inputs)
+        query = self.query_conv2d(x)
+        key = self.key_conv2d(x)
+        value = self.value_conv2d(x)
+
+        if self.data_format == "channels_first":
+            query = ops.transpose(query, (0, 2, 3, 1))
+            key = ops.transpose(key, (0, 2, 3, 1))
+            value = ops.transpose(value, (0, 2, 3, 1))
+        shape = ops.shape(inputs)
+        b = shape[0]
+        query = ops.reshape(query, (b, -1, self.filters))
+        key = ops.reshape(key, (b, -1, self.filters))
+        value = ops.reshape(value, (b, -1, self.filters))
+
+        # Compute attention.
+        query = ops.multiply(
+            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
+        )
+        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
+        attention_scores = ops.einsum("abc,adc->abd", query, key)
+        attention_scores = ops.cast(
+            self.softmax(attention_scores), self.compute_dtype
+        )
+        # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C]
+        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
+        x = ops.reshape(attention_output, shape)
+
+        x = self.output_conv2d(x)
+        if self.data_format == "channels_first":
+            x = ops.transpose(x, (0, 3, 1, 2))
+        x = ops.add(x, inputs)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "filters": self.filters,
+                "groups": self.groups,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
+def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
+    data_format = standardize_data_format(data_format)
+    gn_axis = -1 if data_format == "channels_last" else 1
+    input_filters = x.shape[gn_axis]
+
+    residual = x
+    x = layers.GroupNormalization(
+        groups=32,
+        axis=gn_axis,
+        epsilon=1e-6,
+        dtype="float32",
+        name=f"{name}_norm1",
+    )(x)
+    x = layers.Activation("swish", dtype=dtype)(x)
+    x = layers.Conv2D(
+        filters,
+        3,
+        1,
+        padding="same",
+        data_format=data_format,
+        dtype=dtype,
+        name=f"{name}_conv1",
+    )(x)
+    x = layers.GroupNormalization(
+        groups=32,
+        axis=gn_axis,
+        epsilon=1e-6,
+        dtype="float32",
+        name=f"{name}_norm2",
+    )(x)
+    x = layers.Activation("swish", dtype=dtype)(x)
+    x = layers.Conv2D(
+        filters,
+        3,
+        1,
+        padding="same",
+        data_format=data_format,
+        dtype=dtype,
+        name=f"{name}_conv2",
+    )(x)
+    if input_filters != filters:
+        residual = layers.Conv2D(
+            filters,
+            1,
+            1,
+            data_format=data_format,
+            dtype=dtype,
+            name=f"{name}_residual_projection",
+        )(residual)
+    x = layers.Add(dtype=dtype)([residual, x])
+    return x
+
+
+class VAEImageDecoder(Backbone):
+    """Decoder for the VAE model used in Stable Diffusion 3.
+
+    Args:
+        stackwise_num_filters: list of ints. The number of filters for each
+            stack.
+        stackwise_num_blocks: list of ints. The number of blocks for each stack.
+        output_channels: int. The number of channels in the output.
+        latent_shape: tuple. The shape of the latent image.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+            to use for the model's computations and weights.
+    """
+
+    def __init__(
+        self,
+        stackwise_num_filters,
+        stackwise_num_blocks,
+        output_channels=3,
+        latent_shape=(None, None, 16),
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        data_format = standardize_data_format(data_format)
+        gn_axis = -1 if data_format == "channels_last" else 1
+
+        # === Functional Model ===
+        latent_inputs = layers.Input(shape=latent_shape)
+
+        x = layers.Conv2D(
+            stackwise_num_filters[0],
+            3,
+            1,
+            padding="same",
+            data_format=data_format,
+            dtype=dtype,
+            name="input_projection",
+        )(latent_inputs)
+        x = apply_resnet_block(
+            x,
+            stackwise_num_filters[0],
+            data_format=data_format,
+            dtype=dtype,
+            name="input_block0",
+        )
+        x = VAEAttention(
+            stackwise_num_filters[0],
+            data_format=data_format,
+            dtype=dtype,
+            name="input_attention",
+        )(x)
+        x = apply_resnet_block(
+            x,
+            stackwise_num_filters[0],
+            data_format=data_format,
+            dtype=dtype,
+            name="input_block1",
+        )
+
+        # Stacks.
+        for i, filters in enumerate(stackwise_num_filters):
+            for j in range(stackwise_num_blocks[i]):
+                x = apply_resnet_block(
+                    x,
+                    filters,
+                    data_format=data_format,
+                    dtype=dtype,
+                    name=f"block{i}_{j}",
+                )
+            if i != len(stackwise_num_filters) - 1:
+                # No upsamling in the last blcok.
+                x = layers.UpSampling2D(
+                    2,
+                    data_format=data_format,
+                    dtype=dtype,
+                    name=f"upsample_{i}",
+                )(x)
+                x = layers.Conv2D(
+                    filters,
+                    3,
+                    1,
+                    padding="same",
+                    data_format=data_format,
+                    dtype=dtype,
+                    name=f"upsample_{i}_conv",
+                )(x)
+
+        # Ouput block.
+        x = layers.GroupNormalization(
+            groups=32,
+            axis=gn_axis,
+            epsilon=1e-6,
+            dtype="float32",
+            name="output_norm",
+        )(x)
+        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
+        image_outputs = layers.Conv2D(
+            output_channels,
+            3,
+            1,
+            padding="same",
+            data_format=data_format,
+            dtype=dtype,
+            name="output_projection",
+        )(x)
+        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)
+
+        # === Config ===
+        self.stackwise_num_filters = stackwise_num_filters
+        self.stackwise_num_blocks = stackwise_num_blocks
+        self.output_channels = output_channels
+        self.latent_shape = latent_shape
+
+    @property
+    def scaling_factor(self):
+        """The scaling factor for the latent space.
+
+        This is used to scale the latent space to have unit variance when
+        training the diffusion model.
+        """
+        return 1.5305
+
+    @property
+    def shift_factor(self):
+        """The shift factor for the latent space.
+
+        This is used to shift the latent space to have zero mean when
+        training the diffusion model.
+        """
+        return 0.0609
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "stackwise_num_filters": self.stackwise_num_filters,
+                "stackwise_num_blocks": self.stackwise_num_blocks,
+                "output_channels": self.output_channels,
+                "image_shape": self.latent_shape,
+            }
+        )
+        return config
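Assuming a nightly build that ships this file, the decoder can be exercised on its own. The stack sizes below are deliberately tiny toy values (the real SD3 decoder is much larger), and applying the factors as `latents / scaling_factor + shift_factor` is the conventional latent denormalization, stated here as an assumption rather than something this diff shows.

```python
import numpy as np

from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
    VAEImageDecoder,
)

# Toy configuration for illustration only.
decoder = VAEImageDecoder(
    stackwise_num_filters=[64, 32],
    stackwise_num_blocks=[1, 1],
    output_channels=3,
    latent_shape=(None, None, 16),
)

latents = np.random.normal(size=(1, 8, 8, 16)).astype("float32")
# Assumed denormalization before decoding (z / scale + shift).
latents = latents / decoder.scaling_factor + decoder.shift_factor
images = decoder(latents)
print(images.shape)  # (1, 16, 16, 3): one upsampling stage doubles H and W
```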
keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py}
RENAMED
@@ -13,13 +13,15 @@
 # limitations under the License.
 import keras

+from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
 from keras_hub.src.models.preprocessor import Preprocessor
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
 from keras_hub.src.utils.tensor_utils import preprocessing_function


-class T5XXLPreprocessor(Preprocessor):
+@keras_hub_export("keras_hub.models.T5Preprocessor")
+class T5Preprocessor(Preprocessor):
     tokenizer_cls = T5Tokenizer

     def __init__(
@@ -49,10 +51,17 @@ class T5XXLPreprocessor(Preprocessor):
         self.built = True

     @preprocessing_function
-    def call(
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        sequence_length=None,
+    ):
+        sequence_length = sequence_length or self.sequence_length
         token_ids, padding_mask = self.packer(
             self.tokenizer(x),
-            sequence_length=sequence_length
+            sequence_length=sequence_length,
             add_start_value=self.add_start_token,
             add_end_value=self.add_end_token,
         )
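The practical change to `call()` is the per-call `sequence_length` override that is forwarded to the packer. The snippet below shows that override on `StartEndPacker` directly, since building a full `T5Preprocessor` needs a SentencePiece vocabulary; the token ids and special-token values are made up for the example.

```python
from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker

packer = StartEndPacker(
    sequence_length=8,
    start_value=1,
    end_value=2,
    pad_value=0,
    return_padding_mask=True,
)

token_ids = [[5, 6, 7], [8, 9]]  # made-up token ids

# Default padding length taken from the layer...
default_ids, default_mask = packer(token_ids)
# ...and a per-call override, mirroring `sequence_length or self.sequence_length`.
short_ids, short_mask = packer(token_ids, sequence_length=6)

print(default_ids.shape, short_ids.shape)  # (2, 8) (2, 6)
```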