keras-hub-nightly 0.19.0.dev202412120352__py3-none-any.whl → 0.19.0.dev202412140350__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +1 -0
- keras_hub/api/models/__init__.py +11 -6
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/converters.py +2 -2
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/rms_normalization.py +8 -6
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +3 -1
- keras_hub/src/metrics/bleu.py +1 -1
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/bert/bert_presets.py +4 -2
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/causal_lm.py +19 -15
- keras_hub/src/models/clip/clip_vision_embedding.py +1 -1
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +17 -13
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -3
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +1 -1
- keras_hub/src/models/densenet/densenet_backbone.py +3 -1
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -1
- keras_hub/src/models/densenet/densenet_presets.py +6 -6
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +2 -1
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/cba.py +1 -1
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +20 -8
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +1 -1
- keras_hub/src/models/efficientnet/efficientnet_presets.py +12 -11
- keras_hub/src/models/efficientnet/fusedmbconv.py +3 -5
- keras_hub/src/models/efficientnet/mbconv.py +1 -1
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/flux/flux_layers.py +46 -44
- keras_hub/src/models/flux/flux_maths.py +24 -17
- keras_hub/src/models/flux/flux_model.py +24 -19
- keras_hub/src/models/flux/flux_presets.py +2 -1
- keras_hub/src/models/flux/flux_text_to_image.py +7 -3
- keras_hub/src/models/gemma/gemma_backbone.py +27 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +9 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier_preprocessor.py +4 -1
- keras_hub/src/models/image_object_detector.py +2 -2
- keras_hub/src/models/image_object_detector_preprocessor.py +4 -4
- keras_hub/src/models/image_segmenter_preprocessor.py +2 -2
- keras_hub/src/models/llama/llama_backbone.py +34 -26
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +3 -3
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/mit_backbone.py +4 -3
- keras_hub/src/models/mit/mit_layers.py +2 -1
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +7 -7
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +5 -3
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +2 -2
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +2 -2
- keras_hub/src/models/retinanet/feature_pyramid.py +3 -2
- keras_hub/src/models/retinanet/prediction_head.py +2 -2
- keras_hub/src/models/retinanet/retinanet_backbone.py +2 -2
- keras_hub/src/models/retinanet/retinanet_image_converter.py +1 -1
- keras_hub/src/models/retinanet/retinanet_object_detector.py +5 -6
- keras_hub/src/models/retinanet/retinanet_presets.py +2 -1
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +4 -2
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/sam_backbone.py +2 -2
- keras_hub/src/models/sam/sam_image_segmenter.py +6 -5
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/segformer_backbone.py +18 -14
- keras_hub/src/models/segformer/segformer_image_segmenter.py +51 -38
- keras_hub/src/models/segformer/segformer_presets.py +24 -12
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +20 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +1 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +13 -6
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +2 -2
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +7 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/task.py +4 -2
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +5 -1
- keras_hub/src/models/vae/vae_layers.py +0 -1
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +49 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +4 -2
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +1 -3
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +2 -2
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +2 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +2 -8
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +7 -12
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +8 -5
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +7 -3
- keras_hub/src/utils/preset_utils.py +25 -18
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +2 -4
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/RECORD +148 -140
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/top_level.txt +0 -0
keras_hub/src/models/flux/flux_layers.py
@@ -9,11 +9,10 @@ from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors
 
 
 class EmbedND(keras.Model):
-    """
-    Embedding layer for N-dimensional inputs using Rotary Positional Embedding (RoPE).
+    """Embedding layer for N-dimensional inputs using RoPE.
 
-    This layer applies RoPE embeddings across multiple axes of the input tensor
-    concatenates the embeddings along a specified axis.
+    This layer applies RoPE embeddings across multiple axes of the input tensor
+    and concatenates the embeddings along a specified axis.
 
     Args:
         theta. Rotational angle parameter for RoPE.
@@ -32,14 +31,14 @@ class EmbedND(keras.Model):
             self.rope.build((input_shape[:-1] + (self.axes_dim[i],)))
 
     def call(self, ids):
-        """
-        Computes the positional embeddings for each axis and concatenates them.
+        """Computes the positional embeddings for each axis and concatenates.
 
         Args:
            ids: KerasTensor. Input tensor of shape (..., num_axes).
 
         Returns:
-            KerasTensor: Positional embeddings of shape
+            KerasTensor: Positional embeddings of shape
+                (..., concatenated_dim, 1, ...).
         """
         n_axes = ids.shape[-1]
         emb = ops.concatenate(
@@ -54,8 +53,7 @@ class EmbedND(keras.Model):
 
 
 class MLPEmbedder(keras.Model):
-    """
-    A simple multi-layer perceptron (MLP) embedder model.
+    """A simple multi-layer perceptron (MLP) embedder model.
 
     This model applies a linear transformation followed by the SiLU activation
     function and another linear transformation to the input tensor.
@@ -76,15 +74,14 @@ class MLPEmbedder(keras.Model):
         self.output_layer.build((input_shape[0], self.input_layer.units))
 
     def call(self, x):
-        """
-        Applies the MLP embedding to the input tensor.
+        """Applies the MLP embedding to the input tensor.
 
         Args:
-            x:
+            x: Input tensor of shape (batch_size, in_dim).
 
         Returns:
-
-
+            Output tensor of shape (batch_size, hidden_dim) after applying the
+            MLP transformations.
         """
         x = self.input_layer(x)
         x = self.silu(x)
@@ -92,11 +89,10 @@ class MLPEmbedder(keras.Model):
 
 
 class QKNorm(keras.layers.Layer):
-    """
-    A layer that applies RMS normalization to query and key tensors.
+    """A layer that applies RMS normalization to query and key tensors.
 
-    This layer normalizes the input query and key tensors using separate
-    layers for each.
+    This layer normalizes the input query and key tensors using separate
+    RMSNormalization layers for each.
 
     Args:
         input_dim. The dimensionality of the input query and key tensors.
@@ -120,7 +116,8 @@ class QKNorm(keras.layers.Layer):
             k: KerasTensor. The key tensor of shape (batch_size, input_dim).
 
         Returns:
-            tuple[KerasTensor, KerasTensor]: A tuple containing the normalized
+            tuple[KerasTensor, KerasTensor]: A tuple containing the normalized
+                query and key tensors.
         """
         q = self.query_norm(q)
         k = self.key_norm(k)
@@ -128,17 +125,17 @@ class QKNorm(keras.layers.Layer):
 
 
 class SelfAttention(keras.Model):
-    """
-    Multi-head self-attention layer with RoPE embeddings and RMS normalization.
+    """Multi-head self-attention layer with RoPE and RMS normalization.
 
     This layer performs self-attention over the input sequence and applies RMS
-    normalization to the query and key tensors before computing the attention
+    normalization to the query and key tensors before computing the attention
+    scores.
 
     Args:
         dim: int. Dimensionality of the input tensor.
         num_heads: int. Number of attention heads. Default is 8.
-        use_bias: bool. Whether to use bias in the query, key, value projection
-            Default is False.
+        use_bias: bool. Whether to use bias in the query, key, value projection
+            layers. Default is False.
     """
 
     def __init__(self, dim, num_heads=8, use_bias=False):
@@ -159,12 +156,12 @@ class SelfAttention(keras.Model):
         self.proj.build((None, input_shape[1], input_shape[-1]))
 
     def call(self, x, positional_encoding):
-        """
-        Applies self-attention with RoPE embeddings.
+        """Applies self-attention with RoPE embeddings.
 
         Args:
             x: KerasTensor. Input tensor of shape (batch_size, seq_len, dim).
-            positional_encoding: KerasTensor. Positional encoding tensor for
+            positional_encoding: KerasTensor. Positional encoding tensor for
+                RoPE.
 
         Returns:
             KerasTensor: Output tensor after self-attention and projection.
@@ -180,12 +177,11 @@ class SelfAttention(keras.Model):
 
 
 class Modulation(keras.Model):
-    """
-    Modulation layer that produces shift, scale, and gate tensors.
+    """Modulation layer that produces shift, scale, and gate tensors.
 
-    This layer applies a SiLU activation to the input tensor followed by a
-    transformation to generate modulation parameters. It can optionally
-    sets of modulation parameters.
+    This layer applies a SiLU activation to the input tensor followed by a
+    linear transformation to generate modulation parameters. It can optionally
+    generate two sets of modulation parameters.
 
     Args:
         dim: int. Dimensionality of the modulation output.
@@ -212,8 +208,9 @@ class Modulation(keras.Model):
             x: KerasTensor. Input tensor.
 
         Returns:
-            tuple[ModulationOut, ModulationOut | None]: A tuple containing
-                scale, and gate tensors. If `double` is True, returns two
+            tuple[ModulationOut, ModulationOut | None]: A tuple containing the
+                shift, scale, and gate tensors. If `double` is True, returns two
+                sets of modulation parameters.
         """
         x = keras.layers.Activation("silu")(x)
         out = self.linear_projection(x)
@@ -239,8 +236,10 @@ class DoubleStreamBlock(keras.Model):
     Args:
         hidden_size: int. The hidden dimension size for the model.
         num_heads: int. The number of attention heads.
-        mlp_ratio: float. The ratio of the MLP hidden dimension to the
-
+        mlp_ratio: float. The ratio of the MLP hidden dimension to the hidden
+            size.
+        use_bias: bool, optional. Whether to include bias in QKV projection.
+            Default is False.
     """
 
     def __init__(
@@ -292,13 +291,13 @@ class DoubleStreamBlock(keras.Model):
         Forward pass for the DoubleStreamBlock.
 
         Args:
-            image:
-            text:
-            modulation_encoding:
-            positional_encoding:
+            image: Input image tensor.
+            text: Input text tensor.
+            modulation_encoding: Modulation vector.
+            positional_encoding: Positional encoding tensor.
 
         Returns:
-
+            A `(image, text)` tuple of modified image and text tensors.
         """
         image_mod1, image_mod2 = self.image_mod(modulation_encoding)
         text_mod1, text_mod2 = self.text_mod(modulation_encoding)
@@ -367,8 +366,10 @@ class SingleStreamBlock(keras.Model):
     Args:
         hidden_size: int. The hidden dimension size for the model.
         num_heads: int. The number of attention heads.
-        mlp_ratio: float, optional. The ratio of the MLP hidden dimension to the
-
+        mlp_ratio: float, optional. The ratio of the MLP hidden dimension to the
+            hidden size. Default is 4.0.
+        qk_scale: float, optional. Scaling factor for the query-key product.
+            Default is None.
     """
 
     def __init__(
@@ -443,7 +444,8 @@ class SingleStreamBlock(keras.Model):
         attn = self.attention(
             q, k=k, v=v, positional_encoding=positional_encoding
         )
-        # compute activation in mlp stream, cat again and run second linear
+        # compute activation in mlp stream, cat again and run second linear
+        # layer
         output = self.linear2(
             ops.concatenate(
                 (attn, keras.activations.gelu(mlp, approximate=True)), 2
keras_hub/src/models/flux/flux_maths.py
@@ -3,19 +3,21 @@ from keras import ops
 
 
 class TimestepEmbedding(keras.layers.Layer):
-    """
-    Creates sinusoidal timestep embeddings.
+    """Creates sinusoidal timestep embeddings.
 
     Call arguments:
-        t:
+        t: Tensor of shape (N,), representing N indices, one per batch element.
            These values may be fractional.
        dim: int. The dimension of the output.
-        max_period: int, optional. Controls the minimum frequency of the
-
+        max_period: int, optional. Controls the minimum frequency of the
+            embeddings. Defaults to 10000.
+        time_factor: float, optional. A scaling factor applied to `t`. Defaults
+            to 1000.0.
 
     Returns:
-
-
+        A tensor of shape (N, D) representing the positional embeddings,
+        where N is the number of batch elements and D is the specified
+        dimension `dim`.
     """
 
     def call(self, t, dim, max_period=10000, time_factor=1000.0):
@@ -68,7 +70,8 @@ class ApplyRoPE(keras.layers.Layer):
     Call arguments:
         xq: KerasTensor. The query tensor of shape (..., L, D).
         xk: KerasTensor. The key tensor of shape (..., L, D).
-        freqs_cis: KerasTensor. The frequency complex numbers tensor with shape
+        freqs_cis: KerasTensor. The frequency complex numbers tensor with shape
+            `(..., 2)`.
 
     Returns:
         tuple[KerasTensor, KerasTensor]: The transformed query and key tensors.
@@ -91,12 +94,12 @@
 
 
 class FluxRoPEAttention(keras.layers.Layer):
-    """
-    Computes the attention mechanism with the RoPE transformation applied to the query and key tensors.
+    """Computes the attention mechanism with RoPE.
 
     Args:
         dropout_p: float, optional. Dropout probability. Defaults to 0.0.
-        is_causal: bool, optional. If True, applies causal masking. Defaults to
+        is_causal: bool, optional. If True, applies causal masking. Defaults to
+            False.
 
     Call arguments:
         q: KerasTensor. Query tensor of shape (..., L, D).
@@ -122,12 +125,14 @@ class FluxRoPEAttention(keras.layers.Layer):
             q, k, v, dropout_p=self.dropout_p, is_causal=self.is_causal
         )
         x = ops.transpose(x, (0, 2, 1, 3))
-        b,
-        return ops.reshape(x, (b,
+        b, s, h, d = ops.shape(x)
+        return ops.reshape(x, (b, s, h * d))
 
 
-# TODO: This is probably already implemented in several places, but is needed to
-#
+# TODO: This is probably already implemented in several places, but is needed to
+# ensure numeric equivalence to the original implementation. It uses
+# torch.functional.scaled_dot_product_attention() - do we have an equivalent
+# already in Keras?
 def scaled_dot_product_attention(
     query,
     key,
@@ -144,9 +149,11 @@ def scaled_dot_product_attention(
         query: KerasTensor. Query tensor of shape (..., L, D).
         key: KerasTensor. Key tensor of shape (..., S, D).
         value: KerasTensor. Value tensor of shape (..., S, D).
-        attn_mask: KerasTensor, optional. Attention mask tensor. Defaults to
+        attn_mask: KerasTensor, optional. Attention mask tensor. Defaults to
+            None.
         dropout_p: float, optional. Dropout probability. Defaults to 0.0.
-        is_causal: bool, optional. If True, applies causal masking. Defaults to
+        is_causal: bool, optional. If True, applies causal masking. Defaults to
+            False.
         scale: float, optional. Scale factor for attention. Defaults to None.
 
     Returns:
keras_hub/src/models/flux/flux_model.py
@@ -12,41 +12,47 @@ from keras_hub.src.models.flux.flux_maths import TimestepEmbedding
 
 @keras_hub_export("keras_hub.models.FluxBackbone")
 class FluxBackbone(Backbone):
-    """
-    Transformer model for flow matching on sequences.
+    """Transformer model for flow matching on sequences.
+
+    The model processes image and text data with associated positional and
+    timestep embeddings, and optionally applies guidance embedding.
+    Double-stream blocks handle separate image and text streams, while
+    single-stream blocks combine these streams. Ported from:
+    https://github.com/black-forest-labs/flux
 
-    The model processes image and text data with associated positional and timestep
-    embeddings, and optionally applies guidance embedding. Double-stream blocks
-    handle separate image and text streams, while single-stream blocks combine
-    these streams. Ported from: https://github.com/black-forest-labs/flux
 
     Args:
         input_channels: int. The number of input channels.
-        hidden_size: int. The hidden size of the transformer, must be divisible
+        hidden_size: int. The hidden size of the transformer, must be divisible
+            by `num_heads`.
         mlp_ratio: float. The ratio of the MLP dimension to the hidden size.
         num_heads: int. The number of attention heads.
         depth: int. The number of double-stream blocks.
         depth_single_blocks: int. The number of single-stream blocks.
-        axes_dim: list[int]. A list of dimensions for the positional embedding
+        axes_dim: list[int]. A list of dimensions for the positional embedding
+            axes.
         theta: int. The base frequency for positional embeddings.
-        use_bias: bool. Whether to apply bias to the query, key, and value
+        use_bias: bool. Whether to apply bias to the query, key, and value
+            projections.
         guidance_embed: bool. If True, applies guidance embedding in the model.
 
     Call arguments:
-        image: KerasTensor. Image input tensor of shape (N, L, D) where N is the
-
-
-
+        image: KerasTensor. Image input tensor of shape (N, L, D) where N is the
+            batch size, L is the sequence length, and D is the feature
+            dimension.
+        image_ids: KerasTensor. Image ID input tensor of shape (N, L, D)
+            corresponding to the image sequences.
         text: KerasTensor. Text input tensor of shape (N, L, D).
-        text_ids: KerasTensor. Text ID input tensor of shape (N, L, D)
-            to the text sequences.
-        timesteps: KerasTensor. Timestep tensor used to compute positional
+        text_ids: KerasTensor. Text ID input tensor of shape (N, L, D)
+            corresponding to the text sequences.
+        timesteps: KerasTensor. Timestep tensor used to compute positional
+            embeddings.
         y: KerasTensor. Additional vector input, such as target values.
         guidance: KerasTensor, optional. Guidance input tensor used
             in guidance-embedded models.
     Raises:
-        ValueError: If `hidden_size` is not divisible by `num_heads`, or if
-
+        ValueError: If `hidden_size` is not divisible by `num_heads`, or if
+            `sum(axes_dim)` is not equal to the positional embedding dimension.
     """
 
     def __init__(
@@ -69,7 +75,6 @@ class FluxBackbone(Backbone):
         y_shape=(None, 128),
         **kwargs,
     ):
-
         # === Layers ===
         self.positional_embedder = EmbedND(theta=theta, axes_dim=axes_dim)
         self.image_input_embedder = keras.layers.Dense(
keras_hub/src/models/flux/flux_presets.py
@@ -4,7 +4,8 @@ presets = {
     "schnell": {
         "metadata": {
             "description": (
-                "A 12 billion parameter rectified flow transformer capable of generating images from text descriptions."
+                "A 12 billion parameter rectified flow transformer capable of "
+                "generating images from text descriptions."
             ),
             "params": 124439808,
             "path": "flux",
keras_hub/src/models/flux/flux_text_to_image.py
@@ -24,11 +24,15 @@ class FluxTextToImage(TextToImage):
 
     Use `generate()` to do image generation.
     ```python
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
     text_to_image = keras_hub.models.FluxTextToImage.from_preset(
         "TBA", height=512, width=512
     )
     text_to_image.generate(
-        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+        prompt
     )
 
     # Generate with batched prompts.
@@ -38,7 +42,7 @@ class FluxTextToImage(TextToImage):
 
     # Generate with different `num_steps` and `guidance_scale`.
     text_to_image.generate(
-        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        prompt,
         num_steps=50,
         guidance_scale=5.0,
     )
@@ -46,7 +50,7 @@ class FluxTextToImage(TextToImage):
     # Generate with `negative_prompts`.
     text_to_image.generate(
         {
-            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "prompts": prompt,
             "negative_prompts": "green color",
         }
     )
keras_hub/src/models/gemma/gemma_backbone.py
@@ -44,10 +44,10 @@ class GemmaBackbone(Backbone):
             `hidden_dim / num_query_heads`. Defaults to True.
         use_post_ffw_norm: boolean. Whether to normalize after the feedforward
             block. Defaults to False.
-        use_post_attention_norm: boolean. Whether to normalize after the
-            block. Defaults to False.
-        attention_logit_soft_cap: None or int. Soft cap for the attention
-            Defaults to None.
+        use_post_attention_norm: boolean. Whether to normalize after the
+            attention block. Defaults to False.
+        attention_logit_soft_cap: None or int. Soft cap for the attention
+            logits. Defaults to None.
         final_logit_soft_cap: None or int. Soft cap for the final logits.
             Defaults to None.
         use_sliding_window_attention boolean. Whether to use sliding local
@@ -205,7 +205,9 @@ class GemmaBackbone(Backbone):
                 "final_logit_soft_cap": self.final_logit_soft_cap,
                 "attention_logit_soft_cap": self.attention_logit_soft_cap,
                 "sliding_window_size": self.sliding_window_size,
-                "use_sliding_window_attention": self.use_sliding_window_attention,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
             }
         )
         return config
@@ -224,7 +226,8 @@ class GemmaBackbone(Backbone):
 
         Example:
         ```
-        # Feel free to change the mesh shape to balance data and model parallelism
+        # Feel free to change the mesh shape to balance data and model
+        # parallelism
         mesh = keras.distribution.DeviceMesh(
             shape=(1, 8), axis_names=('batch', 'model'),
             devices=keras.distribution.list_devices())
@@ -237,12 +240,16 @@ class GemmaBackbone(Backbone):
         gemma_model = keras_hub.models.GemmaCausalLM.from_preset()
         ```
 
-        To see how the layout map was applied, load the model then run (for one decoder block):
+        To see how the layout map was applied, load the model then run (for one
+        decoder block):
         ```
         embedding_layer = gemma_model.backbone.get_layer("token_embedding")
         decoder_block_1 = gemma_model.backbone.get_layer('decoder_block_1')
         for variable in embedding_layer.weights + decoder_block_1.weights:
-            print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
+            print(
+                f'{variable.path:<58} {str(variable.shape):<16} '
+                f'{str(variable.value.sharding.spec)}'
+            )
         ```
 
         Args:
@@ -257,22 +264,22 @@ class GemmaBackbone(Backbone):
                 for all the model weights.
         """
         # The weight path and shape of the Gemma backbone is like below (for 2G)
-        # token_embedding/embeddings, (256128, 2048)
+        # token_embedding/embeddings, (256128, 2048)
         # repeat block for decoder
         # ...
-        # decoder_block_17/pre_attention_norm/scale, (2048,)
-        # decoder_block_17/attention/query/kernel, (8, 2048, 256)
-        # decoder_block_17/attention/key/kernel, (8, 2048, 256)
-        # decoder_block_17/attention/value/kernel, (8, 2048, 256)
-        # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048)
-        # decoder_block_17/pre_ffw_norm/scale, (2048,)
-        # decoder_block_17/ffw_gating/kernel, (2048, 16384)
-        # decoder_block_17/ffw_gating_2/kernel, (2048, 16384)
-        # decoder_block_17/ffw_linear/kernel, (16384, 2048)
+        # decoder_block_17/pre_attention_norm/scale, (2048,)
+        # decoder_block_17/attention/query/kernel, (8, 2048, 256)
+        # decoder_block_17/attention/key/kernel, (8, 2048, 256)
+        # decoder_block_17/attention/value/kernel, (8, 2048, 256)
+        # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048)
+        # decoder_block_17/pre_ffw_norm/scale, (2048,)
+        # decoder_block_17/ffw_gating/kernel, (2048, 16384)
+        # decoder_block_17/ffw_gating_2/kernel, (2048, 16384)
+        # decoder_block_17/ffw_linear/kernel, (16384, 2048)
         if not isinstance(device_mesh, keras.distribution.DeviceMesh):
             raise ValueError(
-                "Invalid device_mesh type. Expected `keras.distribution.Device`,"
-                f" got {type(device_mesh)}"
+                "Invalid device_mesh type. Expected "
+                f"`keras.distribution.Device`, got {type(device_mesh)}"
             )
         if model_parallel_dim_name not in device_mesh.axis_names:
             raise ValueError(
keras_hub/src/models/gemma/gemma_causal_lm.py
@@ -187,8 +187,8 @@ class GemmaCausalLM(CausalLM):
         Args:
             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
             cache: a dense float Tensor, the cache of key and value.
-            cache_update_index: int, or int Tensor. The index of current inputs
-                whole sequence.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.
 
         Returns:
             A (logits, hidden_states, cache) tuple. Where `logits` is the
keras_hub/src/models/gemma/gemma_decoder_block.py
@@ -220,7 +220,9 @@ class GemmaDecoderBlock(keras.layers.Layer):
                 "use_post_ffw_norm": self.use_post_ffw_norm,
                 "use_post_attention_norm": self.use_post_attention_norm,
                 "logit_soft_cap": self.logit_soft_cap,
-                "use_sliding_window_attention": self.use_sliding_window_attention,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
                 "sliding_window_size": self.sliding_window_size,
                 "query_head_dim_normalize": self.query_head_dim_normalize,
             }
keras_hub/src/models/gemma/gemma_presets.py
@@ -130,7 +130,9 @@ backbone_presets = {
     },
     "gemma2_instruct_2b_en": {
         "metadata": {
-            "description": "2 billion parameter, 26-layer, instruction tuned Gemma model.",
+            "description": (
+                "2 billion parameter, 26-layer, instruction tuned Gemma model."
+            ),
             "params": 2614341888,
             "path": "gemma",
         },
@@ -146,7 +148,9 @@
     },
     "gemma2_instruct_9b_en": {
         "metadata": {
-            "description": "9 billion parameter, 42-layer, instruction tuned Gemma model.",
+            "description": (
+                "9 billion parameter, 42-layer, instruction tuned Gemma model."
+            ),
             "params": 9241705984,
             "path": "gemma",
         },
@@ -162,7 +166,9 @@
     },
     "gemma2_instruct_27b_en": {
         "metadata": {
-            "description": "27 billion parameter, 42-layer, instruction tuned Gemma model.",
+            "description": (
+                "27 billion parameter, 42-layer, instruction tuned Gemma model."
+            ),
             "params": 27227128320,
             "path": "gemma",
         },
keras_hub/src/models/gpt2/gpt2_causal_lm.py
@@ -172,8 +172,8 @@ class GPT2CausalLM(CausalLM):
         Args:
             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
             cache: a dense float Tensor, the cache of key and value.
-            cache_update_index: int, or int Tensor. The index of current inputs
-                whole sequence.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.
 
         Returns:
             A (logits, hidden_states, cache) tuple. Where `logits` is the
keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py
@@ -202,7 +202,8 @@ class GPTNeoXAttention(keras.layers.Layer):
             training=training,
         )
 
-        # Reshape `attention_output` to `(batch_size, sequence_length, hidden_dim)`.
+        # Reshape `attention_output` to
+        # `(batch_size, sequence_length, hidden_dim)`.
         attention_output = ops.reshape(
             attention_output,
             [
keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py
@@ -27,9 +27,9 @@ class GPTNeoXCausalLM(CausalLM):
 
     Args:
         backbone: A `keras_hub.models.GPTNeoXBackbone` instance.
-        preprocessor: A `keras_hub.models.GPTNeoXCausalLMPreprocessor` or
-            If `None`, this model will not apply preprocessing, and
-            should be preprocessed before calling the model.
+        preprocessor: A `keras_hub.models.GPTNeoXCausalLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
     """
 
     backbone_cls = GPTNeoXBackbone
keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py
@@ -16,7 +16,8 @@ class GPTNeoXDecoder(keras.layers.Layer):
 
     This class follows the architecture of the GPT-NeoX decoder layer in the
     paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745).
-    Users can instantiate multiple instances of this class to stack up a
+    Users can instantiate multiple instances of this class to stack up a
+    decoder.
 
     This layer will always apply a causal mask to the decoder attention layer.
keras_hub/src/models/image_classifier_preprocessor.py
@@ -46,7 +46,10 @@ class ImageClassifierPreprocessor(Preprocessor):
     x, y = preprocessor(x, y)
 
     # Resize a batch of labeled images.
-    x, y = [np.random.randint(0, 256, (512, 512, 3)), np.zeros((512, 512, 3))], [1, 0]
+    x, y = [
+        np.random.randint(0, 256, (512, 512, 3)),
+        np.zeros((512, 512, 3))
+    ], [1, 0]
     x, y = preprocessor(x, y)
 
     # Use a `tf.data.Dataset`.
keras_hub/src/models/image_object_detector.py
@@ -31,8 +31,8 @@ class ImageObjectDetector(Task):
     ):
         """Configures the `ImageObjectDetector` task for training.
 
-        The `ImageObjectDetector` task extends the default compilation signature
-        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        The `ImageObjectDetector` task extends the default compilation signature
+        of `keras.Model.compile` with defaults for `optimizer`, `loss`, and
         `metrics`. To override these defaults, pass any value
         to these arguments during compilation.
 
keras_hub/src/models/image_object_detector_preprocessor.py
@@ -21,10 +21,10 @@ class ImageObjectDetectorPreprocessor(Preprocessor):
         be the a dict of `{"boxes": Tensor(batch_size, num_boxes, 4),
         "classes": (batch_size, num_boxes)}.
 
-    The layer will returns either `x`, an `(x, y)` tuple if labels were
-    or an `(x, y, sample_weight)` tuple if labels and sample weight
-    provided. `x` will be the input images after all model preprocessing
-    been applied.
+    The layer will returns either `x`, an `(x, y)` tuple if labels were
+    provided, or an `(x, y, sample_weight)` tuple if labels and sample weight
+    were provided. `x` will be the input images after all model preprocessing
+    has been applied.
 
     All `ImageObjectDetectorPreprocessor` tasks include a `from_preset()`
     constructor which can be used to load a pre-trained config and vocabularies.