keras-hub-nightly 0.16.1.dev202410200345-py3-none-any.whl → 0.19.0.dev202412070351-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/src/models/electra/electra_presets.py

@@ -8,9 +8,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 13548800,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_small_discriminator_uncased_en/1",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 13548800,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_small_generator_uncased_en/1",
     },
@@ -34,9 +30,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 109482240,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_base_discriminator_uncased_en/1",
     },
@@ -47,9 +41,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 33576960,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_base_generator_uncased_en/1",
     },
@@ -60,9 +52,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 335141888,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_large_discriminator_uncased_en/1",
     },
@@ -73,9 +63,7 @@ backbone_presets = {
                 "lowercased. Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 51065344,
-            "official_name": "ELECTRA",
             "path": "electra",
-            "model_card": "https://github.com/google-research/electra",
         },
         "kaggle_handle": "kaggle://keras/electra/keras/electra_large_generator_uncased_en/1",
     },
keras_hub/src/models/f_net/f_net_presets.py

@@ -8,9 +8,7 @@ backbone_presets = {
                 "Trained on the C4 dataset."
             ),
             "params": 82861056,
-            "official_name": "FNet",
             "path": "f_net",
-            "model_card": "https://github.com/google-research/google-research/blob/master/f_net/README.md",
         },
         "kaggle_handle": "kaggle://keras/f_net/keras/f_net_base_en/2",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "Trained on the C4 dataset."
             ),
             "params": 236945408,
-            "official_name": "FNet",
             "path": "f_net",
-            "model_card": "https://github.com/google-research/google-research/blob/master/f_net/README.md",
         },
         "kaggle_handle": "kaggle://keras/f_net/keras/f_net_large_en/2",
     },
keras_hub/src/models/falcon/falcon_presets.py

@@ -8,9 +8,7 @@ backbone_presets = {
                 "350B tokens of RefinedWeb dataset."
             ),
             "params": 1311625216,
-            "official_name": "Falcon",
             "path": "falcon",
-            "model_card": "https://huggingface.co/tiiuae/falcon-rw-1b",
         },
         "kaggle_handle": "kaggle://keras/falcon/keras/falcon_refinedweb_1b_en/1",
     },
keras_hub/src/models/flux/__init__.py (new file)

@@ -0,0 +1,5 @@
+from keras_hub.src.models.flux.flux_model import FluxBackbone
+from keras_hub.src.models.flux.flux_presets import presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(presets, FluxBackbone)
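This `__init__.py` is what wires the new Flux presets into the standard preset machinery: `register_presets` attaches every entry of `flux_presets.presets` to `FluxBackbone`, making them loadable through `from_preset`. A minimal sketch of the resulting user-facing call, assuming `FluxBackbone` is re-exported under `keras_hub.models` (the additions to `keras_hub/api/models/__init__.py` listed above suggest this, but those lines are not shown in this diff) and using a placeholder preset name, since the contents of `flux_presets.py` are also not shown here:

    import keras_hub

    # "flux_placeholder" is a hypothetical name; the real preset names are
    # defined in keras_hub/src/models/flux/flux_presets.py.
    backbone = keras_hub.models.FluxBackbone.from_preset("flux_placeholder")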
keras_hub/src/models/flux/flux_layers.py (new file)

@@ -0,0 +1,494 @@
+import keras
+from keras import layers
+from keras import ops
+
+from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization
+from keras_hub.src.models.flux.flux_maths import FluxRoPEAttention
+from keras_hub.src.models.flux.flux_maths import RotaryPositionalEmbedding
+from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors
+
+
+class EmbedND(keras.Model):
+    """
+    Embedding layer for N-dimensional inputs using Rotary Positional Embedding (RoPE).
+
+    This layer applies RoPE embeddings across multiple axes of the input tensor and
+    concatenates the embeddings along a specified axis.
+
+    Args:
+        theta: Rotational angle parameter for RoPE.
+        axes_dim: Dimensionality for each axis of the input tensor.
+    """
+
+    def __init__(self, theta, axes_dim):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.rope = RotaryPositionalEmbedding()
+
+    def build(self, input_shape):
+        n_axes = input_shape[-1]
+        for i in range(n_axes):
+            self.rope.build((input_shape[:-1] + (self.axes_dim[i],)))
+
+    def call(self, ids):
+        """
+        Computes the positional embeddings for each axis and concatenates them.
+
+        Args:
+            ids: KerasTensor. Input tensor of shape (..., num_axes).
+
+        Returns:
+            KerasTensor: Positional embeddings of shape (..., concatenated_dim, 1, ...).
+        """
+        n_axes = ids.shape[-1]
+        emb = ops.concatenate(
+            [
+                self.rope(ids[..., i], dim=self.axes_dim[i], theta=self.theta)
+                for i in range(n_axes)
+            ],
+            axis=-3,
+        )
+
+        return ops.expand_dims(emb, axis=1)
+
+
+class MLPEmbedder(keras.Model):
+    """
+    A simple multi-layer perceptron (MLP) embedder model.
+
+    This model applies a linear transformation followed by the SiLU activation
+    function and another linear transformation to the input tensor.
+
+    Args:
+        hidden_dim: The dimensionality of the hidden layer.
+    """
+
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.input_layer = layers.Dense(hidden_dim, use_bias=True)
+        self.silu = layers.Activation("silu")
+        self.output_layer = layers.Dense(hidden_dim, use_bias=True)
+
+    def build(self, input_shape):
+        self.input_layer.build(input_shape)
+        self.output_layer.build((input_shape[0], self.input_layer.units))
+
+    def call(self, x):
+        """
+        Applies the MLP embedding to the input tensor.
+
+        Args:
+            x: KerasTensor. Input tensor of shape (batch_size, in_dim).
+
+        Returns:
+            KerasTensor: Output tensor of shape (batch_size, hidden_dim) after applying
+            the MLP transformations.
+        """
+        x = self.input_layer(x)
+        x = self.silu(x)
+        return self.output_layer(x)
+
+
+class QKNorm(keras.layers.Layer):
+    """
+    A layer that applies RMS normalization to query and key tensors.
+
+    This layer normalizes the input query and key tensors using separate RMSNormalization
+    layers for each.
+
+    Args:
+        input_dim: The dimensionality of the input query and key tensors.
+    """
+
+    def __init__(self, input_dim):
+        super().__init__()
+        self.query_norm = RMSNormalization(input_dim)
+        self.key_norm = RMSNormalization(input_dim)
+
+    def build(self, input_shape):
+        self.query_norm.build(input_shape)
+        self.key_norm.build(input_shape)
+
+    def call(self, q, k):
+        """
+        Applies RMS normalization to the query and key tensors.
+
+        Args:
+            q: KerasTensor. The query tensor of shape (batch_size, input_dim).
+            k: KerasTensor. The key tensor of shape (batch_size, input_dim).
+
+        Returns:
+            tuple[KerasTensor, KerasTensor]: A tuple containing the normalized query and key tensors.
+        """
+        q = self.query_norm(q)
+        k = self.key_norm(k)
+        return q, k
+
+
+class SelfAttention(keras.Model):
+    """
+    Multi-head self-attention layer with RoPE embeddings and RMS normalization.
+
+    This layer performs self-attention over the input sequence and applies RMS
+    normalization to the query and key tensors before computing the attention scores.
+
+    Args:
+        dim: int. Dimensionality of the input tensor.
+        num_heads: int. Number of attention heads. Default is 8.
+        use_bias: bool. Whether to use bias in the query, key, value projection layers.
+            Default is False.
+    """
+
+    def __init__(self, dim, num_heads=8, use_bias=False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.dim = dim
+
+        self.qkv = layers.Dense(dim * 3, use_bias=use_bias)
+        self.norm = QKNorm(head_dim)
+        self.proj = layers.Dense(dim)
+        self.attention = FluxRoPEAttention()
+
+    def build(self, input_shape):
+        self.qkv.build(input_shape)
+        head_dim = input_shape[-1] // self.num_heads
+        self.norm.build((None, input_shape[1], head_dim))
+        self.proj.build((None, input_shape[1], input_shape[-1]))
+
+    def call(self, x, positional_encoding):
+        """
+        Applies self-attention with RoPE embeddings.
+
+        Args:
+            x: KerasTensor. Input tensor of shape (batch_size, seq_len, dim).
+            positional_encoding: KerasTensor. Positional encoding tensor for RoPE.
+
+        Returns:
+            KerasTensor: Output tensor after self-attention and projection.
+        """
+        qkv = self.qkv(x)
+        q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads)
+        q, k = self.norm(q, k)
+        x = self.attention(
+            q=q, k=k, v=v, positional_encoding=positional_encoding
+        )
+        x = self.proj(x)
+        return x
+
+
+class Modulation(keras.Model):
+    """
+    Modulation layer that produces shift, scale, and gate tensors.
+
+    This layer applies a SiLU activation to the input tensor followed by a linear
+    transformation to generate modulation parameters. It can optionally generate two
+    sets of modulation parameters.
+
+    Args:
+        dim: int. Dimensionality of the modulation output.
+        double: bool. Whether to generate two sets of modulation parameters.
+    """
+
+    def __init__(self, dim, double):
+        super().__init__()
+        self.dim = dim
+        self.is_double = double
+        self.multiplier = 6 if double else 3
+        self.linear_projection = keras.layers.Dense(
+            self.multiplier * dim, use_bias=True
+        )
+
+    def build(self, input_shape):
+        self.linear_projection.build(input_shape)
+
+    def call(self, x):
+        """
+        Generates modulation parameters from the input tensor.
+
+        Args:
+            x: KerasTensor. Input tensor.
+
+        Returns:
+            tuple[dict, dict | None]: A tuple containing the shift,
+            scale, and gate tensors. If `double` is True, returns two sets of modulation parameters.
+        """
+        x = keras.layers.Activation("silu")(x)
+        out = self.linear_projection(x)
+        out = ops.split(
+            out[:, None, :], indices_or_sections=self.multiplier, axis=-1
+        )
+
+        first_output = {"shift": out[0], "scale": out[1], "gate": out[2]}
+        second_output = (
+            {"shift": out[3], "scale": out[4], "gate": out[5]}
+            if self.is_double
+            else None
+        )
+
+        return first_output, second_output
+
+
+class DoubleStreamBlock(keras.Model):
+    """
+    A block that processes image and text inputs in parallel using
+    self-attention and MLP layers, with modulation.
+
+    Args:
+        hidden_size: int. The hidden dimension size for the model.
+        num_heads: int. The number of attention heads.
+        mlp_ratio: float. The ratio of the MLP hidden dimension to the hidden size.
+        use_bias: bool, optional. Whether to include bias in QKV projection. Default is False.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        mlp_ratio,
+        use_bias=False,
+    ):
+        super().__init__()
+
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+
+        self.image_mod = Modulation(hidden_size, double=True)
+        self.image_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.image_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, use_bias=use_bias
+        )
+
+        self.image_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.image_mlp = keras.Sequential(
+            [
+                keras.layers.Dense(mlp_hidden_dim, use_bias=True),
+                keras.layers.Activation("gelu"),
+                keras.layers.Dense(hidden_size, use_bias=True),
+            ]
+        )
+
+        self.text_mod = Modulation(hidden_size, double=True)
+        self.text_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.text_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, use_bias=use_bias
+        )
+
+        self.text_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.text_mlp = keras.Sequential(
+            [
+                keras.layers.Dense(mlp_hidden_dim, use_bias=True),
+                keras.layers.Activation("gelu"),
+                keras.layers.Dense(hidden_size, use_bias=True),
+            ]
+        )
+        self.attention = FluxRoPEAttention()
+
+    def call(self, image, text, modulation_encoding, positional_encoding):
+        """
+        Forward pass for the DoubleStreamBlock.
+
+        Args:
+            image: KerasTensor. Input image tensor.
+            text: KerasTensor. Input text tensor.
+            modulation_encoding: KerasTensor. Modulation vector.
+            positional_encoding: KerasTensor. Positional encoding tensor.
+
+        Returns:
+            Tuple[KerasTensor, KerasTensor]: The modified image and text tensors.
+        """
+        image_mod1, image_mod2 = self.image_mod(modulation_encoding)
+        text_mod1, text_mod2 = self.text_mod(modulation_encoding)
+
+        # prepare image for attention
+        image_modulated = self.image_norm1(image)
+        image_modulated = (
+            1 + image_mod1["scale"]
+        ) * image_modulated + image_mod1["shift"]
+        image_qkv = self.image_attn.qkv(image_modulated)
+
+        image_q, image_k, image_v = rearrange_symbolic_tensors(
+            image_qkv, K=3, H=self.num_heads
+        )
+        image_q, image_k = self.image_attn.norm(image_q, image_k)
+
+        # prepare text for attention
+        text_modulated = self.text_norm1(text)
+        text_modulated = (1 + text_mod1["scale"]) * text_modulated + text_mod1[
+            "shift"
+        ]
+        text_qkv = self.text_attn.qkv(text_modulated)
+
+        text_q, text_k, text_v = rearrange_symbolic_tensors(
+            text_qkv, K=3, H=self.num_heads
+        )
+
+        text_q, text_k = self.text_attn.norm(text_q, text_k)
+
+        # run actual attention
+        q = ops.concatenate((text_q, image_q), axis=2)
+        k = ops.concatenate((text_k, image_k), axis=2)
+        v = ops.concatenate((text_v, image_v), axis=2)
+
+        attn = self.attention(
+            q=q, k=k, v=v, positional_encoding=positional_encoding
+        )
+        text_attn, image_attn = (
+            attn[:, : text.shape[1]],
+            attn[:, text.shape[1] :],
+        )
+
+        # calculate the image blocks
+        image = image + image_mod1["gate"] * self.image_attn.proj(image_attn)
+        image = image + image_mod2["gate"] * self.image_mlp(
+            (1 + image_mod2["scale"]) * self.image_norm2(image)
+            + image_mod2["shift"]
+        )
+
+        # calculate the text blocks
+        text = text + text_mod1["gate"] * self.text_attn.proj(text_attn)
+        text = text + text_mod2["gate"] * self.text_mlp(
+            (1 + text_mod2["scale"]) * self.text_norm2(text)
+            + text_mod2["shift"]
+        )
+        return image, text
+
+
+class SingleStreamBlock(keras.Model):
+    """
+    A DiT block with parallel linear layers.
+
+    As described in https://arxiv.org/abs/2302.05442 and
+    adapted for the modulation interface.
+
+    Args:
+        hidden_size: int. The hidden dimension size for the model.
+        num_heads: int. The number of attention heads.
+        mlp_ratio: float, optional. The ratio of the MLP hidden dimension to the hidden size. Default is 4.0.
+        qk_scale: float, optional. Scaling factor for the query-key product. Default is None.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        mlp_ratio=4.0,
+        qk_scale=None,
+    ):
+        super().__init__()
+        self.hidden_dim = hidden_size
+        self.num_heads = num_heads
+        head_dim = hidden_size // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        # qkv and mlp_in
+        self.linear1 = keras.layers.Dense(hidden_size * 3 + self.mlp_hidden_dim)
+        # proj and mlp_out
+        self.linear2 = keras.layers.Dense(hidden_size)
+
+        self.norm = QKNorm(head_dim)
+
+        self.hidden_size = hidden_size
+        self.pre_norm = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.modulation = Modulation(hidden_size, double=False)
+        self.attention = FluxRoPEAttention()
+
+    def build(
+        self, x_shape, modulation_encoding_shape, positional_encoding_shape
+    ):
+        self.linear1.build(x_shape)
+        self.linear2.build(
+            (x_shape[0], x_shape[1], self.hidden_size + self.mlp_hidden_dim)
+        )
+
+        self.modulation.build(
+            modulation_encoding_shape
+        )  # Build the modulation layer
+
+        self.norm.build(
+            (
+                x_shape[0],
+                self.num_heads,
+                x_shape[1],
+                x_shape[-1] // self.num_heads,
+            )
+        )
+
+    def call(self, x, modulation_encoding, positional_encoding):
+        """
+        Forward pass for the SingleStreamBlock.
+
+        Args:
+            x: KerasTensor. Input tensor.
+            modulation_encoding: KerasTensor. Modulation vector.
+            positional_encoding: KerasTensor. Positional encoding tensor.
+
+        Returns:
+            KerasTensor: The modified input tensor after processing.
+        """
+        mod, _ = self.modulation(modulation_encoding)
+        x_mod = (1 + mod["scale"]) * self.pre_norm(x) + mod["shift"]
+        qkv, mlp = ops.split(
+            self.linear1(x_mod), [3 * self.hidden_size], axis=-1
+        )
+
+        q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads)
+        q, k = self.norm(q, k)
+
+        # compute attention
+        attn = self.attention(
+            q, k=k, v=v, positional_encoding=positional_encoding
+        )
+        # compute activation in mlp stream, cat again and run second linear layer
+        output = self.linear2(
+            ops.concatenate(
+                (attn, keras.activations.gelu(mlp, approximate=True)), 2
+            )
+        )
+        return x + mod["gate"] * output
+
+
+class LastLayer(keras.Model):
+    """
+    Final layer for processing output tensors with adaptive normalization.
+
+    Args:
+        hidden_size: int. The hidden dimension size for the model.
+        patch_size: int. The size of each patch.
+        output_channels: int. The number of output channels.
+    """
+
+    def __init__(self, hidden_size, patch_size, output_channels):
+        super().__init__()
+        self.norm_final = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.linear = keras.layers.Dense(
+            patch_size * patch_size * output_channels, use_bias=True
+        )
+        self.adaLN_modulation = keras.Sequential(
+            [
+                keras.layers.Activation("silu"),
+                keras.layers.Dense(2 * hidden_size, use_bias=True),
+            ]
+        )
+
+    def call(self, x, modulation_encoding):
+        """
+        Forward pass for the LastLayer.
+
+        Args:
+            x: KerasTensor. Input tensor.
+            modulation_encoding: KerasTensor. Modulation vector.
+
+        Returns:
+            KerasTensor: The output tensor after final processing.
+        """
+        shift, scale = ops.split(
+            self.adaLN_modulation(modulation_encoding), 2, axis=1
+        )
+        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = self.linear(x)
+        return x
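To make the shapes in the new file concrete, here is a short, self-contained smoke test of `MLPEmbedder` and `Modulation` (the two blocks above with no dependency on `flux_maths`); this is a minimal sketch, and the batch size and dimensions are arbitrary choices for illustration, not values used by any preset:

    import keras
    from keras_hub.src.models.flux.flux_layers import MLPEmbedder
    from keras_hub.src.models.flux.flux_layers import Modulation

    x = keras.random.normal((2, 256))  # (batch_size, in_dim)

    # Dense -> SiLU -> Dense, with both projections sized to hidden_dim.
    embedder = MLPEmbedder(hidden_dim=512)
    y = embedder(x)  # shape (2, 512)

    # double=True yields two {"shift", "scale", "gate"} dicts; each entry has
    # shape (2, 1, 512) because of the out[:, None, :] expansion before split.
    mod = Modulation(dim=512, double=True)
    first, second = mod(y)

The same shift/scale/gate dicts are what `DoubleStreamBlock` and `SingleStreamBlock` consume from `modulation_encoding` to modulate their normalized activations.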