keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl
This diff shows the changes between two publicly available package versions released to one of the supported registries, as they appear in that registry. It is provided for informational purposes only.
- keras_hub/api/layers/__init__.py +12 -0
- keras_hub/api/models/__init__.py +32 -0
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/rms_normalization.py +34 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
- keras_hub/src/layers/preprocessing/image_converter.py +5 -0
- keras_hub/src/models/albert/albert_presets.py +0 -8
- keras_hub/src/models/bart/bart_presets.py +0 -6
- keras_hub/src/models/bert/bert_presets.py +0 -20
- keras_hub/src/models/bloom/bloom_presets.py +0 -16
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
- keras_hub/src/models/densenet/densenet_backbone.py +1 -1
- keras_hub/src/models/densenet/densenet_presets.py +0 -6
- keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
- keras_hub/src/models/efficientnet/mbconv.py +52 -21
- keras_hub/src/models/electra/electra_presets.py +0 -12
- keras_hub/src/models/f_net/f_net_presets.py +0 -4
- keras_hub/src/models/falcon/falcon_presets.py +0 -2
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +494 -0
- keras_hub/src/models/flux/flux_maths.py +218 -0
- keras_hub/src/models/flux/flux_model.py +231 -0
- keras_hub/src/models/flux/flux_presets.py +14 -0
- keras_hub/src/models/flux/flux_text_to_image.py +142 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_presets.py +0 -40
- keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_to_image.py +16 -10
- keras_hub/src/models/inpaint.py +20 -13
- keras_hub/src/models/llama/llama_backbone.py +1 -1
- keras_hub/src/models/llama/llama_presets.py +5 -15
- keras_hub/src/models/llama3/llama3_presets.py +0 -8
- keras_hub/src/models/mistral/mistral_presets.py +0 -6
- keras_hub/src/models/mit/mit_backbone.py +41 -27
- keras_hub/src/models/mit/mit_layers.py +9 -7
- keras_hub/src/models/mit/mit_presets.py +12 -24
- keras_hub/src/models/opt/opt_presets.py +0 -8
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
- keras_hub/src/models/phi3/phi3_presets.py +0 -4
- keras_hub/src/models/resnet/resnet_presets.py +10 -42
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
- keras_hub/src/models/roberta/roberta_presets.py +0 -4
- keras_hub/src/models/sam/sam_backbone.py +0 -1
- keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
- keras_hub/src/models/sam/sam_presets.py +0 -6
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +163 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +124 -0
- keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +41 -13
- keras_hub/src/models/text_to_image.py +13 -5
- keras_hub/src/models/vgg/vgg_backbone.py +1 -1
- keras_hub/src/models/vgg/vgg_presets.py +0 -8
- keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
- keras_hub/src/models/whisper/whisper_presets.py +0 -20
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
- keras_hub/src/tests/test_case.py +25 -0
- keras_hub/src/utils/preset_utils.py +17 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/src/models/electra/electra_presets.py
@@ -8,9 +8,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 13548800,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_small_discriminator_uncased_en/1",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 13548800,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_small_generator_uncased_en/1",
     },
@@ -34,9 +30,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 109482240,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_base_discriminator_uncased_en/1",
     },
@@ -47,9 +41,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 33576960,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_base_generator_uncased_en/1",
     },
@@ -60,9 +52,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 335141888,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_large_discriminator_uncased_en/1",
     },
@@ -73,9 +63,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 51065344,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_large_generator_uncased_en/1",
     },
keras_hub/src/models/f_net/f_net_presets.py
@@ -8,9 +8,7 @@ backbone_presets = {
                 "Trained on the C4 dataset."
             ),
             "params": 82861056,
-            "official_name": "FNet",
             "path": "f_net",
-            "model_card": "https://github.com/google-research/google-research/blob/master/f_net/README.md",
         },
         "kaggle_handle": "kaggle://keras/f_net/keras/f_net_base_en/2",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "Trained on the C4 dataset."
             ),
             "params": 236945408,
-            "official_name": "FNet",
             "path": "f_net",
-            "model_card": "https://github.com/google-research/google-research/blob/master/f_net/README.md",
         },
         "kaggle_handle": "kaggle://keras/f_net/keras/f_net_large_en/2",
     },
keras_hub/src/models/falcon/falcon_presets.py
@@ -8,9 +8,7 @@ backbone_presets = {
                 "350B tokens of RefinedWeb dataset."
             ),
             "params": 1311625216,
-            "official_name": "Falcon",
             "path": "falcon",
-            "model_card": "https://huggingface.co/tiiuae/falcon-rw-1b",
         },
         "kaggle_handle": "kaggle://keras/falcon/keras/falcon_refinedweb_1b_en/1",
     },
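Across the electra, f_net, and falcon preset files above, the change is uniform: the "official_name" and "model_card" keys are dropped from each preset's "metadata" dict (the same removal appears in the other *_presets.py files listed at the top). As a minimal sketch, the trimmed falcon entry would look roughly like this; the preset key name is an assumption inferred from the kaggle handle, and the opening of the description sits outside the hunk context shown above:

backbone_presets = {
    "falcon_refinedweb_1b_en": {  # key name inferred from the kaggle handle
        "metadata": {
            # The start of the description is outside the hunk shown above.
            "description": "350B tokens of RefinedWeb dataset.",
            "params": 1311625216,
            "path": "falcon",
        },
        "kaggle_handle": "kaggle://keras/falcon/keras/falcon_refinedweb_1b_en/1",
    },
}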
keras_hub/src/models/flux/flux_layers.py
@@ -0,0 +1,494 @@
+import keras
+from keras import layers
+from keras import ops
+
+from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization
+from keras_hub.src.models.flux.flux_maths import FluxRoPEAttention
+from keras_hub.src.models.flux.flux_maths import RotaryPositionalEmbedding
+from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors
+
+
+class EmbedND(keras.Model):
+    """
+    Embedding layer for N-dimensional inputs using Rotary Positional Embedding (RoPE).
+
+    This layer applies RoPE embeddings across multiple axes of the input tensor and
+    concatenates the embeddings along a specified axis.
+
+    Args:
+        theta. Rotational angle parameter for RoPE.
+        axes_dim. Dimensionality for each axis of the input tensor.
+    """
+
+    def __init__(self, theta, axes_dim):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.rope = RotaryPositionalEmbedding()
+
+    def build(self, input_shape):
+        n_axes = input_shape[-1]
+        for i in range(n_axes):
+            self.rope.build((input_shape[:-1] + (self.axes_dim[i],)))
+
+    def call(self, ids):
+        """
+        Computes the positional embeddings for each axis and concatenates them.
+
+        Args:
+            ids: KerasTensor. Input tensor of shape (..., num_axes).
+
+        Returns:
+            KerasTensor: Positional embeddings of shape (..., concatenated_dim, 1, ...).
+        """
+        n_axes = ids.shape[-1]
+        emb = ops.concatenate(
+            [
+                self.rope(ids[..., i], dim=self.axes_dim[i], theta=self.theta)
+                for i in range(n_axes)
+            ],
+            axis=-3,
+        )
+
+        return ops.expand_dims(emb, axis=1)
+
+
+class MLPEmbedder(keras.Model):
+    """
+    A simple multi-layer perceptron (MLP) embedder model.
+
+    This model applies a linear transformation followed by the SiLU activation
+    function and another linear transformation to the input tensor.
+
+    Args:
+        hidden_dim. The dimensionality of the hidden layer.
+    """
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.input_layer = layers.Dense(hidden_dim, use_bias=True)
+        self.silu = layers.Activation("silu")
+        self.output_layer = layers.Dense(hidden_dim, use_bias=True)
+
+    def build(self, input_shape):
+        self.input_layer.build(input_shape)
+        self.output_layer.build((input_shape[0], self.input_layer.units))
+
+    def call(self, x):
+        """
+        Applies the MLP embedding to the input tensor.
+
+        Args:
+            x: KerasTensor. Input tensor of shape (batch_size, in_dim).
+
+        Returns:
+            KerasTensor: Output tensor of shape (batch_size, hidden_dim) after applying
+                the MLP transformations.
+        """
+        x = self.input_layer(x)
+        x = self.silu(x)
+        return self.output_layer(x)
+
+
+class QKNorm(keras.layers.Layer):
+    """
+    A layer that applies RMS normalization to query and key tensors.
+
+    This layer normalizes the input query and key tensors using separate RMSNormalization
+    layers for each.
+
+    Args:
+        input_dim. The dimensionality of the input query and key tensors.
+    """
+
+    def __init__(self, input_dim):
+        super().__init__()
+        self.query_norm = RMSNormalization(input_dim)
+        self.key_norm = RMSNormalization(input_dim)
+
+    def build(self, input_shape):
+        self.query_norm.build(input_shape)
+        self.key_norm.build(input_shape)
+
+    def call(self, q, k):
+        """
+        Applies RMS normalization to the query and key tensors.
+
+        Args:
+            q: KerasTensor. The query tensor of shape (batch_size, input_dim).
+            k: KerasTensor. The key tensor of shape (batch_size, input_dim).
+
+        Returns:
+            tuple[KerasTensor, KerasTensor]: A tuple containing the normalized query and key tensors.
+        """
+        q = self.query_norm(q)
+        k = self.key_norm(k)
+        return q, k
+
+
+class SelfAttention(keras.Model):
+    """
+    Multi-head self-attention layer with RoPE embeddings and RMS normalization.
+
+    This layer performs self-attention over the input sequence and applies RMS
+    normalization to the query and key tensors before computing the attention scores.
+
+    Args:
+        dim: int. Dimensionality of the input tensor.
+        num_heads: int. Number of attention heads. Default is 8.
+        use_bias: bool. Whether to use bias in the query, key, value projection layers.
+            Default is False.
+    """
+
+    def __init__(self, dim, num_heads=8, use_bias=False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.dim = dim
+
+        self.qkv = layers.Dense(dim * 3, use_bias=use_bias)
+        self.norm = QKNorm(head_dim)
+        self.proj = layers.Dense(dim)
+        self.attention = FluxRoPEAttention()
+
+    def build(self, input_shape):
+        self.qkv.build(input_shape)
+        head_dim = input_shape[-1] // self.num_heads
+        self.norm.build((None, input_shape[1], head_dim))
+        self.proj.build((None, input_shape[1], input_shape[-1]))
+
+    def call(self, x, positional_encoding):
+        """
+        Applies self-attention with RoPE embeddings.
+
+        Args:
+            x: KerasTensor. Input tensor of shape (batch_size, seq_len, dim).
+            positional_encoding: KerasTensor. Positional encoding tensor for RoPE.
+
+        Returns:
+            KerasTensor: Output tensor after self-attention and projection.
+        """
+        qkv = self.qkv(x)
+        q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads)
+        q, k = self.norm(q, k)
+        x = self.attention(
+            q=q, k=k, v=v, positional_encoding=positional_encoding
+        )
+        x = self.proj(x)
+        return x
+
+
+class Modulation(keras.Model):
+    """
+    Modulation layer that produces shift, scale, and gate tensors.
+
+    This layer applies a SiLU activation to the input tensor followed by a linear
+    transformation to generate modulation parameters. It can optionally generate two
+    sets of modulation parameters.
+
+    Args:
+        dim: int. Dimensionality of the modulation output.
+        double: bool. Whether to generate two sets of modulation parameters.
+    """
+
+    def __init__(self, dim, double):
+        super().__init__()
+        self.dim = dim
+        self.is_double = double
+        self.multiplier = 6 if double else 3
+        self.linear_projection = keras.layers.Dense(
+            self.multiplier * dim, use_bias=True
+        )
+
+    def build(self, input_shape):
+        self.linear_projection.build(input_shape)
+
+    def call(self, x):
+        """
+        Generates modulation parameters from the input tensor.
+
+        Args:
+            x: KerasTensor. Input tensor.
+
+        Returns:
+            tuple[ModulationOut, ModulationOut | None]: A tuple containing the shift,
+                scale, and gate tensors. If `double` is True, returns two sets of modulation parameters.
+        """
+        x = keras.layers.Activation("silu")(x)
+        out = self.linear_projection(x)
+        out = ops.split(
+            out[:, None, :], indices_or_sections=self.multiplier, axis=-1
+        )
+
+        first_output = {"shift": out[0], "scale": out[1], "gate": out[2]}
+        second_output = (
+            {"shift": out[3], "scale": out[4], "gate": out[5]}
+            if self.is_double
+            else None
+        )
+
+        return first_output, second_output
+
+
+class DoubleStreamBlock(keras.Model):
+    """
+    A block that processes image and text inputs in parallel using
+    self-attention and MLP layers, with modulation.
+
+    Args:
+        hidden_size: int. The hidden dimension size for the model.
+        num_heads: int. The number of attention heads.
+        mlp_ratio: float. The ratio of the MLP hidden dimension to the hidden size.
+        use_bias: bool, optional. Whether to include bias in QKV projection. Default is False.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        mlp_ratio,
+        use_bias=False,
+    ):
+        super().__init__()
+
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+
+        self.image_mod = Modulation(hidden_size, double=True)
+        self.image_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.image_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, use_bias=use_bias
+        )
+
+        self.image_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.image_mlp = keras.Sequential(
+            [
+                keras.layers.Dense(mlp_hidden_dim, use_bias=True),
+                keras.layers.Activation("gelu"),
+                keras.layers.Dense(hidden_size, use_bias=True),
+            ]
+        )
+
+        self.text_mod = Modulation(hidden_size, double=True)
+        self.text_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.text_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, use_bias=use_bias
+        )
+
+        self.text_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.text_mlp = keras.Sequential(
+            [
+                keras.layers.Dense(mlp_hidden_dim, use_bias=True),
+                keras.layers.Activation("gelu"),
+                keras.layers.Dense(hidden_size, use_bias=True),
+            ]
+        )
+        self.attention = FluxRoPEAttention()
+
+    def call(self, image, text, modulation_encoding, positional_encoding):
+        """
+        Forward pass for the DoubleStreamBlock.
+
+        Args:
+            image: KerasTensor. Input image tensor.
+            text: KerasTensor. Input text tensor.
+            modulation_encoding: KerasTensor. Modulation vector.
+            positional_encoding: KerasTensor. Positional encoding tensor.
+
+        Returns:
+            Tuple[KerasTensor, KerasTensor]: The modified image and text tensors.
+        """
+        image_mod1, image_mod2 = self.image_mod(modulation_encoding)
+        text_mod1, text_mod2 = self.text_mod(modulation_encoding)
+
+        # prepare image for attention
+        image_modulated = self.image_norm1(image)
+        image_modulated = (
+            1 + image_mod1["scale"]
+        ) * image_modulated + image_mod1["shift"]
+        image_qkv = self.image_attn.qkv(image_modulated)
+
+        image_q, image_k, image_v = rearrange_symbolic_tensors(
+            image_qkv, K=3, H=self.num_heads
+        )
+        image_q, image_k = self.image_attn.norm(image_q, image_k)
+
+        # prepare text for attention
+        text_modulated = self.text_norm1(text)
+        text_modulated = (1 + text_mod1["scale"]) * text_modulated + text_mod1[
+            "shift"
+        ]
+        text_qkv = self.text_attn.qkv(text_modulated)
+
+        text_q, text_k, text_v = rearrange_symbolic_tensors(
+            text_qkv, K=3, H=self.num_heads
+        )
+
+        text_q, text_k = self.text_attn.norm(text_q, text_k)
+
+        # run actual attention
+        q = ops.concatenate((text_q, image_q), axis=2)
+        k = ops.concatenate((text_k, image_k), axis=2)
+        v = ops.concatenate((text_v, image_v), axis=2)
+
+        attn = self.attention(
+            q=q, k=k, v=v, positional_encoding=positional_encoding
+        )
+        text_attn, image_attn = (
+            attn[:, : text.shape[1]],
+            attn[:, text.shape[1] :],
+        )
+
+        # calculate the image blocks
+        image = image + image_mod1["gate"] * self.image_attn.proj(image_attn)
+        image = image + image_mod2["gate"] * self.image_mlp(
+            (1 + image_mod2["scale"]) * self.image_norm2(image)
+            + image_mod2["shift"]
+        )
+
+        # calculate the text blocks
+        text = text + text_mod1["gate"] * self.text_attn.proj(text_attn)
+        text = text + text_mod2["gate"] * self.text_mlp(
+            (1 + text_mod2["scale"]) * self.text_norm2(text)
+            + text_mod2["shift"]
+        )
+        return image, text
+
+
+class SingleStreamBlock(keras.Model):
+    """
+    A DiT block with parallel linear layers.
+
+    As described in https://arxiv.org/abs/2302.05442 and
+    adapted for the modulation interface.
+
+    Args:
+        hidden_size: int. The hidden dimension size for the model.
+        num_heads: int. The number of attention heads.
+        mlp_ratio: float, optional. The ratio of the MLP hidden dimension to the hidden size. Default is 4.0.
+        qk_scale: float, optional. Scaling factor for the query-key product. Default is None.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        mlp_ratio=4.0,
+        qk_scale=None,
+    ):
+        super().__init__()
+        self.hidden_dim = hidden_size
+        self.num_heads = num_heads
+        head_dim = hidden_size // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        # qkv and mlp_in
+        self.linear1 = keras.layers.Dense(hidden_size * 3 + self.mlp_hidden_dim)
+        # proj and mlp_out
+        self.linear2 = keras.layers.Dense(hidden_size)
+
+        self.norm = QKNorm(head_dim)
+
+        self.hidden_size = hidden_size
+        self.pre_norm = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.modulation = Modulation(hidden_size, double=False)
+        self.attention = FluxRoPEAttention()
+
+    def build(
+        self, x_shape, modulation_encoding_shape, positional_encoding_shape
+    ):
+        self.linear1.build(x_shape)
+        self.linear2.build(
+            (x_shape[0], x_shape[1], self.hidden_size + self.mlp_hidden_dim)
+        )
+
+        self.modulation.build(
+            modulation_encoding_shape
+        )  # Build the modulation layer
+
+        self.norm.build(
+            (
+                x_shape[0],
+                self.num_heads,
+                x_shape[1],
+                x_shape[-1] // self.num_heads,
+            )
+        )
+
+    def call(self, x, modulation_encoding, positional_encoding):
+        """
+        Forward pass for the SingleStreamBlock.
+
+        Args:
+            x: KerasTensor. Input tensor.
+            modulation_encoding: KerasTensor. Modulation vector.
+            positional_encoding: KerasTensor. Positional encoding tensor.
+
+        Returns:
+            KerasTensor: The modified input tensor after processing.
+        """
+        mod, _ = self.modulation(modulation_encoding)
+        x_mod = (1 + mod["scale"]) * self.pre_norm(x) + mod["shift"]
+        qkv, mlp = ops.split(
+            self.linear1(x_mod), [3 * self.hidden_size], axis=-1
+        )
+
+        q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads)
+        q, k = self.norm(q, k)
+
+        # compute attention
+        attn = self.attention(
+            q, k=k, v=v, positional_encoding=positional_encoding
+        )
+        # compute activation in mlp stream, cat again and run second linear layer
+        output = self.linear2(
+            ops.concatenate(
+                (attn, keras.activations.gelu(mlp, approximate=True)), 2
+            )
+        )
+        return x + mod["gate"] * output
+
+
+class LastLayer(keras.Model):
+    """
+    Final layer for processing output tensors with adaptive normalization.
+
+    Args:
+        hidden_size: int. The hidden dimension size for the model.
+        patch_size: int. The size of each patch.
+        output_channels: int. The number of output channels.
+    """
+
+    def __init__(self, hidden_size, patch_size, output_channels):
+        super().__init__()
+        self.norm_final = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.linear = keras.layers.Dense(
+            patch_size * patch_size * output_channels, use_bias=True
+        )
+        self.adaLN_modulation = keras.Sequential(
+            [
+                keras.layers.Activation("silu"),
+                keras.layers.Dense(2 * hidden_size, use_bias=True),
+            ]
+        )
+
+    def call(self, x, modulation_encoding):
+        """
+        Forward pass for the LastLayer.
+
+        Args:
+            x: KerasTensor. Input tensor.
+            modulation_encoding: KerasTensor. Modulation vector.
+
+        Returns:
+            KerasTensor: The output tensor after final processing.
+        """
+        shift, scale = ops.split(
+            self.adaLN_modulation(modulation_encoding), 2, axis=1
+        )
+        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = self.linear(x)
+        return x
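For orientation, here is a minimal usage sketch of two of the self-contained layers added above. The import path follows the file list at the top of this diff; the dimensions and batch sizes are illustrative assumptions, not values taken from the diff.

from keras import ops

from keras_hub.src.models.flux.flux_layers import MLPEmbedder
from keras_hub.src.models.flux.flux_layers import Modulation

# MLPEmbedder: Dense -> SiLU -> Dense, with both projections sized hidden_dim.
embedder = MLPEmbedder(hidden_dim=128)
y = embedder(ops.ones((4, 32)))
print(y.shape)  # (4, 128)

# Modulation(double=True) projects a (batch, dim) vector to 6 * dim and splits
# the result into two {"shift", "scale", "gate"} dicts of (batch, 1, dim) each.
mod = Modulation(dim=64, double=True)
first, second = mod(ops.ones((2, 64)))
print(first["shift"].shape, second["gate"].shape)  # (2, 1, 64) (2, 1, 64)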