keras-hub-nightly 0.19.0.dev202412120352__py3-none-any.whl → 0.19.0.dev202412140350__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. keras_hub/api/layers/__init__.py +1 -0
  2. keras_hub/api/models/__init__.py +11 -6
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/converters.py +2 -2
  5. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  6. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  7. keras_hub/src/layers/modeling/rms_normalization.py +8 -6
  8. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  9. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  10. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  11. keras_hub/src/layers/modeling/transformer_encoder.py +3 -1
  12. keras_hub/src/metrics/bleu.py +1 -1
  13. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  14. keras_hub/src/models/bart/bart_backbone.py +4 -4
  15. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  16. keras_hub/src/models/bert/bert_presets.py +4 -2
  17. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  18. keras_hub/src/models/causal_lm.py +19 -15
  19. keras_hub/src/models/clip/clip_vision_embedding.py +1 -1
  20. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  21. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  22. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  23. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  24. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  25. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  26. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +17 -13
  27. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -3
  28. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +1 -1
  29. keras_hub/src/models/densenet/densenet_backbone.py +3 -1
  30. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -1
  31. keras_hub/src/models/densenet/densenet_presets.py +6 -6
  32. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  33. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  34. keras_hub/src/models/distil_bert/distil_bert_presets.py +2 -1
  35. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  36. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  37. keras_hub/src/models/efficientnet/cba.py +1 -1
  38. keras_hub/src/models/efficientnet/efficientnet_backbone.py +20 -8
  39. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +1 -1
  40. keras_hub/src/models/efficientnet/efficientnet_presets.py +12 -11
  41. keras_hub/src/models/efficientnet/fusedmbconv.py +3 -5
  42. keras_hub/src/models/efficientnet/mbconv.py +1 -1
  43. keras_hub/src/models/electra/electra_backbone.py +2 -2
  44. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  45. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  46. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  47. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  48. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  49. keras_hub/src/models/flux/flux_layers.py +46 -44
  50. keras_hub/src/models/flux/flux_maths.py +24 -17
  51. keras_hub/src/models/flux/flux_model.py +24 -19
  52. keras_hub/src/models/flux/flux_presets.py +2 -1
  53. keras_hub/src/models/flux/flux_text_to_image.py +7 -3
  54. keras_hub/src/models/gemma/gemma_backbone.py +27 -20
  55. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  56. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  57. keras_hub/src/models/gemma/gemma_presets.py +9 -3
  58. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  59. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  60. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  61. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  62. keras_hub/src/models/image_classifier_preprocessor.py +4 -1
  63. keras_hub/src/models/image_object_detector.py +2 -2
  64. keras_hub/src/models/image_object_detector_preprocessor.py +4 -4
  65. keras_hub/src/models/image_segmenter_preprocessor.py +2 -2
  66. keras_hub/src/models/llama/llama_backbone.py +34 -26
  67. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  68. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  69. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  70. keras_hub/src/models/mistral/mistral_causal_lm.py +3 -3
  71. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  72. keras_hub/src/models/mit/mit_backbone.py +4 -3
  73. keras_hub/src/models/mit/mit_layers.py +2 -1
  74. keras_hub/src/models/mobilenet/mobilenet_backbone.py +7 -7
  75. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  76. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +5 -3
  77. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +2 -2
  78. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  79. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  80. keras_hub/src/models/preprocessor.py +2 -2
  81. keras_hub/src/models/retinanet/feature_pyramid.py +3 -2
  82. keras_hub/src/models/retinanet/prediction_head.py +2 -2
  83. keras_hub/src/models/retinanet/retinanet_backbone.py +2 -2
  84. keras_hub/src/models/retinanet/retinanet_image_converter.py +1 -1
  85. keras_hub/src/models/retinanet/retinanet_object_detector.py +5 -6
  86. keras_hub/src/models/retinanet/retinanet_presets.py +2 -1
  87. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  88. keras_hub/src/models/roberta/roberta_presets.py +4 -2
  89. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  90. keras_hub/src/models/sam/sam_backbone.py +2 -2
  91. keras_hub/src/models/sam/sam_image_segmenter.py +6 -5
  92. keras_hub/src/models/sam/sam_layers.py +5 -3
  93. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  94. keras_hub/src/models/sam/sam_transformer.py +5 -4
  95. keras_hub/src/models/segformer/segformer_backbone.py +18 -14
  96. keras_hub/src/models/segformer/segformer_image_segmenter.py +51 -38
  97. keras_hub/src/models/segformer/segformer_presets.py +24 -12
  98. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  99. keras_hub/src/models/stable_diffusion_3/mmdit.py +20 -1
  100. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +1 -1
  101. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +13 -6
  102. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +2 -2
  103. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +7 -3
  104. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  105. keras_hub/src/models/task.py +4 -2
  106. keras_hub/src/models/text_classifier.py +2 -2
  107. keras_hub/src/models/text_to_image.py +5 -1
  108. keras_hub/src/models/vae/vae_layers.py +0 -1
  109. keras_hub/src/models/vit/__init__.py +5 -0
  110. keras_hub/src/models/vit/vit_backbone.py +152 -0
  111. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  112. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  113. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  114. keras_hub/src/models/vit/vit_layers.py +391 -0
  115. keras_hub/src/models/vit/vit_presets.py +49 -0
  116. keras_hub/src/models/vit_det/vit_det_backbone.py +4 -2
  117. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  118. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -3
  119. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  120. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  121. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  122. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  123. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  124. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  125. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  126. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  127. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  128. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  129. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  130. keras_hub/src/samplers/sampler.py +2 -1
  131. keras_hub/src/tests/test_case.py +2 -2
  132. keras_hub/src/tokenizers/byte_pair_tokenizer.py +2 -2
  133. keras_hub/src/tokenizers/byte_tokenizer.py +2 -8
  134. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  135. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +7 -12
  136. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +8 -5
  137. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +7 -3
  138. keras_hub/src/utils/preset_utils.py +25 -18
  139. keras_hub/src/utils/tensor_utils.py +4 -4
  140. keras_hub/src/utils/timm/convert_efficientnet.py +2 -4
  141. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  142. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  143. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  144. keras_hub/src/version_utils.py +1 -1
  145. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/METADATA +1 -1
  146. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/RECORD +148 -140
  147. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/WHEEL +0 -0
  148. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/top_level.txt +0 -0
keras_hub/src/models/flux/flux_layers.py
@@ -9,11 +9,10 @@ from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors
 
 
 class EmbedND(keras.Model):
-    """
-    Embedding layer for N-dimensional inputs using Rotary Positional Embedding (RoPE).
+    """Embedding layer for N-dimensional inputs using RoPE.
 
-    This layer applies RoPE embeddings across multiple axes of the input tensor and
-    concatenates the embeddings along a specified axis.
+    This layer applies RoPE embeddings across multiple axes of the input tensor
+    and concatenates the embeddings along a specified axis.
 
     Args:
         theta. Rotational angle parameter for RoPE.
@@ -32,14 +31,14 @@ class EmbedND(keras.Model):
         self.rope.build((input_shape[:-1] + (self.axes_dim[i],)))
 
     def call(self, ids):
-        """
-        Computes the positional embeddings for each axis and concatenates them.
+        """Computes the positional embeddings for each axis and concatenates.
 
        Args:
            ids: KerasTensor. Input tensor of shape (..., num_axes).
 
        Returns:
-            KerasTensor: Positional embeddings of shape (..., concatenated_dim, 1, ...).
+            KerasTensor: Positional embeddings of shape
+            (..., concatenated_dim, 1, ...).
        """
        n_axes = ids.shape[-1]
        emb = ops.concatenate(
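The `EmbedND` hunks above only reflow its docstring; for readers who want the mechanics, here is a minimal NumPy sketch of the idea the docstring describes: build one rotary table per positional axis, then concatenate. The `rope` helper is a simplified stand-in, not the keras-hub implementation.

```python
import numpy as np

def rope(pos, dim, theta):
    # Standard RoPE frequency table: one inverse frequency per rotation plane.
    omega = 1.0 / theta ** (np.arange(0, dim, 2) / dim)   # (dim // 2,)
    angles = np.einsum("...n,d->...nd", pos, omega)       # (..., n, dim // 2)
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1)

def embed_nd(ids, axes_dim, theta=10_000):
    # ids: (batch, seq_len, num_axes). One rotary table per axis, concatenated.
    n_axes = ids.shape[-1]
    return np.concatenate(
        [rope(ids[..., i], axes_dim[i], theta) for i in range(n_axes)],
        axis=-2,
    )

ids = np.zeros((1, 16, 3))
print(embed_nd(ids, axes_dim=[16, 56, 56]).shape)  # (1, 16, 64, 2)
```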
@@ -54,8 +53,7 @@ class EmbedND(keras.Model):
 
 
 class MLPEmbedder(keras.Model):
-    """
-    A simple multi-layer perceptron (MLP) embedder model.
+    """A simple multi-layer perceptron (MLP) embedder model.
 
     This model applies a linear transformation followed by the SiLU activation
     function and another linear transformation to the input tensor.
@@ -76,15 +74,14 @@ class MLPEmbedder(keras.Model):
         self.output_layer.build((input_shape[0], self.input_layer.units))
 
     def call(self, x):
-        """
-        Applies the MLP embedding to the input tensor.
+        """Applies the MLP embedding to the input tensor.
 
         Args:
-            x: KerasTensor. Input tensor of shape (batch_size, in_dim).
+            x: Input tensor of shape (batch_size, in_dim).
 
         Returns:
-            KerasTensor: Output tensor of shape (batch_size, hidden_dim) after applying
-            the MLP transformations.
+            Output tensor of shape (batch_size, hidden_dim) after applying the
+            MLP transformations.
         """
         x = self.input_layer(x)
         x = self.silu(x)
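`MLPEmbedder` itself is unchanged apart from its docstring; the pattern it describes (a linear layer, SiLU, another linear layer) is small enough to sketch in full. Names below are illustrative, not keras-hub's.

```python
import keras
import numpy as np

class TinyMLPEmbedder(keras.layers.Layer):
    """Dense -> SiLU -> Dense, as the MLPEmbedder docstring describes."""

    def __init__(self, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.input_layer = keras.layers.Dense(hidden_dim)
        self.silu = keras.layers.Activation("silu")
        self.output_layer = keras.layers.Dense(hidden_dim)

    def call(self, x):
        return self.output_layer(self.silu(self.input_layer(x)))

print(TinyMLPEmbedder(64)(np.zeros((2, 32), "float32")).shape)  # (2, 64)
```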
@@ -92,11 +89,10 @@ class MLPEmbedder(keras.Model):
 
 
 class QKNorm(keras.layers.Layer):
-    """
-    A layer that applies RMS normalization to query and key tensors.
+    """A layer that applies RMS normalization to query and key tensors.
 
-    This layer normalizes the input query and key tensors using separate RMSNormalization
-    layers for each.
+    This layer normalizes the input query and key tensors using separate
+    RMSNormalization layers for each.
 
     Args:
         input_dim. The dimensionality of the input query and key tensors.
@@ -120,7 +116,8 @@ class QKNorm(keras.layers.Layer):
             k: KerasTensor. The key tensor of shape (batch_size, input_dim).
 
         Returns:
-            tuple[KerasTensor, KerasTensor]: A tuple containing the normalized query and key tensors.
+            tuple[KerasTensor, KerasTensor]: A tuple containing the normalized
+            query and key tensors.
         """
         q = self.query_norm(q)
         k = self.key_norm(k)
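For reference, the RMS normalization `QKNorm` applies reduces to a couple of `keras.ops` calls. A sketch assuming no learned scale (the real `RMSNormalization` layers carry a trainable scale weight):

```python
from keras import ops

def rms_norm(x, eps=1e-6):
    # Divide each feature vector by its root mean square.
    return x * ops.rsqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + eps)

def qk_norm(q, k):
    # Normalize query and key independently, as QKNorm does.
    return rms_norm(q), rms_norm(k)
```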
@@ -128,17 +125,17 @@ class QKNorm(keras.layers.Layer):
 
 
 class SelfAttention(keras.Model):
-    """
-    Multi-head self-attention layer with RoPE embeddings and RMS normalization.
+    """Multi-head self-attention layer with RoPE and RMS normalization.
 
     This layer performs self-attention over the input sequence and applies RMS
-    normalization to the query and key tensors before computing the attention scores.
+    normalization to the query and key tensors before computing the attention
+    scores.
 
     Args:
         dim: int. Dimensionality of the input tensor.
         num_heads: int. Number of attention heads. Default is 8.
-        use_bias: bool. Whether to use bias in the query, key, value projection layers.
-            Default is False.
+        use_bias: bool. Whether to use bias in the query, key, value projection
+            layers. Default is False.
     """
 
     def __init__(self, dim, num_heads=8, use_bias=False):
@@ -159,12 +156,12 @@ class SelfAttention(keras.Model):
         self.proj.build((None, input_shape[1], input_shape[-1]))
 
     def call(self, x, positional_encoding):
-        """
-        Applies self-attention with RoPE embeddings.
+        """Applies self-attention with RoPE embeddings.
 
         Args:
             x: KerasTensor. Input tensor of shape (batch_size, seq_len, dim).
-            positional_encoding: KerasTensor. Positional encoding tensor for RoPE.
+            positional_encoding: KerasTensor. Positional encoding tensor for
+                RoPE.
 
         Returns:
             KerasTensor: Output tensor after self-attention and projection.
@@ -180,12 +177,11 @@ class SelfAttention(keras.Model):
 
 
 class Modulation(keras.Model):
-    """
-    Modulation layer that produces shift, scale, and gate tensors.
+    """Modulation layer that produces shift, scale, and gate tensors.
 
-    This layer applies a SiLU activation to the input tensor followed by a linear
-    transformation to generate modulation parameters. It can optionally generate two
-    sets of modulation parameters.
+    This layer applies a SiLU activation to the input tensor followed by a
+    linear transformation to generate modulation parameters. It can optionally
+    generate two sets of modulation parameters.
 
     Args:
         dim: int. Dimensionality of the modulation output.
@@ -212,8 +208,9 @@ class Modulation(keras.Model):
             x: KerasTensor. Input tensor.
 
         Returns:
-            tuple[ModulationOut, ModulationOut | None]: A tuple containing the shift,
-            scale, and gate tensors. If `double` is True, returns two sets of modulation parameters.
+            tuple[ModulationOut, ModulationOut | None]: A tuple containing the
+            shift, scale, and gate tensors. If `double` is True, returns two
+            sets of modulation parameters.
         """
         x = keras.layers.Activation("silu")(x)
         out = self.linear_projection(x)
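The shift/scale/gate pattern in `Modulation` is one SiLU-plus-linear projection split into three (or six) chunks. A hedged functional sketch (the real layer builds its `Dense` once in `__init__` rather than per call):

```python
import keras
from keras import ops

def modulation(x, dim, double=False):
    # One projection wide enough for all parameter tensors, then a split.
    multiplier = 6 if double else 3
    out = keras.layers.Dense(dim * multiplier)(ops.silu(x))
    chunks = ops.split(out, multiplier, axis=-1)
    first = chunks[:3]                       # (shift, scale, gate)
    second = chunks[3:] if double else None  # second set when double=True
    return first, second
```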
@@ -239,8 +236,10 @@ class DoubleStreamBlock(keras.Model):
     Args:
         hidden_size: int. The hidden dimension size for the model.
         num_heads: int. The number of attention heads.
-        mlp_ratio: float. The ratio of the MLP hidden dimension to the hidden size.
-        use_bias: bool, optional. Whether to include bias in QKV projection. Default is False.
+        mlp_ratio: float. The ratio of the MLP hidden dimension to the hidden
+            size.
+        use_bias: bool, optional. Whether to include bias in QKV projection.
+            Default is False.
     """
 
     def __init__(
@@ -292,13 +291,13 @@ class DoubleStreamBlock(keras.Model):
         Forward pass for the DoubleStreamBlock.
 
         Args:
-            image: KerasTensor. Input image tensor.
-            text: KerasTensor. Input text tensor.
-            modulation_encoding: KerasTensor. Modulation vector.
-            positional_encoding: KerasTensor. Positional encoding tensor.
+            image: Input image tensor.
+            text: Input text tensor.
+            modulation_encoding: Modulation vector.
+            positional_encoding: Positional encoding tensor.
 
         Returns:
-            Tuple[KerasTensor, KerasTensor]: The modified image and text tensors.
+            A `(image, text)` tuple of modified image and text tensors.
         """
         image_mod1, image_mod2 = self.image_mod(modulation_encoding)
         text_mod1, text_mod2 = self.text_mod(modulation_encoding)
@@ -367,8 +366,10 @@ class SingleStreamBlock(keras.Model):
     Args:
         hidden_size: int. The hidden dimension size for the model.
         num_heads: int. The number of attention heads.
-        mlp_ratio: float, optional. The ratio of the MLP hidden dimension to the hidden size. Default is 4.0.
-        qk_scale: float, optional. Scaling factor for the query-key product. Default is None.
+        mlp_ratio: float, optional. The ratio of the MLP hidden dimension to the
+            hidden size. Default is 4.0.
+        qk_scale: float, optional. Scaling factor for the query-key product.
+            Default is None.
     """
 
     def __init__(
@@ -443,7 +444,8 @@ class SingleStreamBlock(keras.Model):
         attn = self.attention(
             q, k=k, v=v, positional_encoding=positional_encoding
         )
-        # compute activation in mlp stream, cat again and run second linear layer
+        # compute activation in mlp stream, cat again and run second linear
+        # layer
         output = self.linear2(
             ops.concatenate(
                 (attn, keras.activations.gelu(mlp, approximate=True)), 2
keras_hub/src/models/flux/flux_maths.py
@@ -3,19 +3,21 @@ from keras import ops
 
 
 class TimestepEmbedding(keras.layers.Layer):
-    """
-    Creates sinusoidal timestep embeddings.
+    """Creates sinusoidal timestep embeddings.
 
     Call arguments:
-        t: KerasTensor of shape (N,), representing N indices, one per batch element.
+        t: Tensor of shape (N,), representing N indices, one per batch element.
             These values may be fractional.
         dim: int. The dimension of the output.
-        max_period: int, optional. Controls the minimum frequency of the embeddings. Defaults to 10000.
-        time_factor: float, optional. A scaling factor applied to `t`. Defaults to 1000.0.
+        max_period: int, optional. Controls the minimum frequency of the
+            embeddings. Defaults to 10000.
+        time_factor: float, optional. A scaling factor applied to `t`. Defaults
+            to 1000.0.
 
     Returns:
-        KerasTensor: A tensor of shape (N, D) representing the positional embeddings,
-        where N is the number of batch elements and D is the specified dimension `dim`.
+        A tensor of shape (N, D) representing the positional embeddings,
+        where N is the number of batch elements and D is the specified
+        dimension `dim`.
     """
 
     def call(self, t, dim, max_period=10000, time_factor=1000.0):
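The sinusoidal embedding described above is compact enough to restate. A NumPy sketch assuming an even `dim`:

```python
import numpy as np

def timestep_embedding(t, dim, max_period=10_000, time_factor=1000.0):
    # Frequencies are geometrically spaced from 1 down to ~1/max_period.
    t = time_factor * np.asarray(t, dtype="float32")
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = t[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)  # (N, dim)

print(timestep_embedding([0.0, 1.0, 2.5], dim=8).shape)  # (3, 8)
```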
@@ -68,7 +70,8 @@ class ApplyRoPE(keras.layers.Layer):
     Call arguments:
         xq: KerasTensor. The query tensor of shape (..., L, D).
         xk: KerasTensor. The key tensor of shape (..., L, D).
-        freqs_cis: KerasTensor. The frequency complex numbers tensor with shape (..., 2).
+        freqs_cis: KerasTensor. The frequency complex numbers tensor with shape
+            `(..., 2)`.
 
     Returns:
         tuple[KerasTensor, KerasTensor]: The transformed query and key tensors.
@@ -91,12 +94,12 @@ class ApplyRoPE(keras.layers.Layer):
 
 
 class FluxRoPEAttention(keras.layers.Layer):
-    """
-    Computes the attention mechanism with the RoPE transformation applied to the query and key tensors.
+    """Computes the attention mechanism with RoPE.
 
     Args:
         dropout_p: float, optional. Dropout probability. Defaults to 0.0.
-        is_causal: bool, optional. If True, applies causal masking. Defaults to False.
+        is_causal: bool, optional. If True, applies causal masking. Defaults to
+            False.
 
     Call arguments:
         q: KerasTensor. Query tensor of shape (..., L, D).
@@ -122,12 +125,14 @@ class FluxRoPEAttention(keras.layers.Layer):
             q, k, v, dropout_p=self.dropout_p, is_causal=self.is_causal
         )
         x = ops.transpose(x, (0, 2, 1, 3))
-        b, l, h, d = ops.shape(x)
-        return ops.reshape(x, (b, l, h * d))
+        b, s, h, d = ops.shape(x)
+        return ops.reshape(x, (b, s, h * d))
 
 
-# TODO: This is probably already implemented in several places, but is needed to ensure numeric equivalence to the original
-# implementation. It uses torch.functional.scaled_dot_product_attention() - do we have an equivalent already in Keras?
+# TODO: This is probably already implemented in several places, but is needed to
+# ensure numeric equivalence to the original implementation. It uses
+# torch.functional.scaled_dot_product_attention() - do we have an equivalent
+# already in Keras?
 def scaled_dot_product_attention(
     query,
     key,
@@ -144,9 +149,11 @@ def scaled_dot_product_attention(
         query: KerasTensor. Query tensor of shape (..., L, D).
         key: KerasTensor. Key tensor of shape (..., S, D).
         value: KerasTensor. Value tensor of shape (..., S, D).
-        attn_mask: KerasTensor, optional. Attention mask tensor. Defaults to None.
+        attn_mask: KerasTensor, optional. Attention mask tensor. Defaults to
+            None.
         dropout_p: float, optional. Dropout probability. Defaults to 0.0.
-        is_causal: bool, optional. If True, applies causal masking. Defaults to False.
+        is_causal: bool, optional. If True, applies causal masking. Defaults to
+            False.
         scale: float, optional. Scale factor for attention. Defaults to None.
 
     Returns:
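To make the TODO above concrete: the operation in question reduces to a few `keras.ops` calls. A sketch matching the documented signature, with dropout omitted and no claim of exact numeric equivalence to the torch version:

```python
import math
from keras import ops

def sdpa(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # scores: (..., L, S), one row of attention logits per query position.
    scale = scale or 1.0 / math.sqrt(query.shape[-1])
    scores = ops.einsum("...ld,...sd->...ls", query, key) * scale
    if is_causal:
        seq_l, seq_s = ops.shape(scores)[-2], ops.shape(scores)[-1]
        keep = ops.tril(ops.ones((seq_l, seq_s), dtype="bool"))
        scores = ops.where(keep, scores, -1e9)
    if attn_mask is not None:
        scores = ops.where(attn_mask, scores, -1e9)
    return ops.einsum(
        "...ls,...sd->...ld", ops.softmax(scores, axis=-1), value
    )
```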
keras_hub/src/models/flux/flux_model.py
@@ -12,41 +12,47 @@ from keras_hub.src.models.flux.flux_maths import TimestepEmbedding
 
 @keras_hub_export("keras_hub.models.FluxBackbone")
 class FluxBackbone(Backbone):
-    """
-    Transformer model for flow matching on sequences.
+    """Transformer model for flow matching on sequences.
+
+    The model processes image and text data with associated positional and
+    timestep embeddings, and optionally applies guidance embedding.
+    Double-stream blocks handle separate image and text streams, while
+    single-stream blocks combine these streams. Ported from:
+    https://github.com/black-forest-labs/flux
 
-    The model processes image and text data with associated positional and timestep
-    embeddings, and optionally applies guidance embedding. Double-stream blocks
-    handle separate image and text streams, while single-stream blocks combine
-    these streams. Ported from: https://github.com/black-forest-labs/flux
 
     Args:
         input_channels: int. The number of input channels.
-        hidden_size: int. The hidden size of the transformer, must be divisible by `num_heads`.
+        hidden_size: int. The hidden size of the transformer, must be divisible
+            by `num_heads`.
         mlp_ratio: float. The ratio of the MLP dimension to the hidden size.
         num_heads: int. The number of attention heads.
         depth: int. The number of double-stream blocks.
         depth_single_blocks: int. The number of single-stream blocks.
-        axes_dim: list[int]. A list of dimensions for the positional embedding axes.
+        axes_dim: list[int]. A list of dimensions for the positional embedding
+            axes.
         theta: int. The base frequency for positional embeddings.
-        use_bias: bool. Whether to apply bias to the query, key, and value projections.
+        use_bias: bool. Whether to apply bias to the query, key, and value
+            projections.
         guidance_embed: bool. If True, applies guidance embedding in the model.
 
     Call arguments:
-        image: KerasTensor. Image input tensor of shape (N, L, D) where N is the batch size,
-            L is the sequence length, and D is the feature dimension.
-        image_ids: KerasTensor. Image ID input tensor of shape (N, L, D) corresponding
-            to the image sequences.
+        image: KerasTensor. Image input tensor of shape (N, L, D) where N is the
+            batch size, L is the sequence length, and D is the feature
+            dimension.
+        image_ids: KerasTensor. Image ID input tensor of shape (N, L, D)
+            corresponding to the image sequences.
         text: KerasTensor. Text input tensor of shape (N, L, D).
-        text_ids: KerasTensor. Text ID input tensor of shape (N, L, D) corresponding
-            to the text sequences.
-        timesteps: KerasTensor. Timestep tensor used to compute positional embeddings.
+        text_ids: KerasTensor. Text ID input tensor of shape (N, L, D)
+            corresponding to the text sequences.
+        timesteps: KerasTensor. Timestep tensor used to compute positional
+            embeddings.
         y: KerasTensor. Additional vector input, such as target values.
         guidance: KerasTensor, optional. Guidance input tensor used
             in guidance-embedded models.
     Raises:
-        ValueError: If `hidden_size` is not divisible by `num_heads`, or if `sum(axes_dim)` is not equal to the
-            positional embedding dimension.
+        ValueError: If `hidden_size` is not divisible by `num_heads`, or if
+            `sum(axes_dim)` is not equal to the positional embedding dimension.
     """
 
     def __init__(
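Given the documented constraint that `sum(axes_dim)` must equal the per-head dimension, a hedged construction sketch (argument values are illustrative, not a released preset):

```python
import keras_hub

# hidden_size / num_heads == 128 == 16 + 56 + 56, satisfying the Raises
# clause documented above.
backbone = keras_hub.models.FluxBackbone(
    input_channels=64,
    hidden_size=1024,
    mlp_ratio=4.0,
    num_heads=8,
    depth=2,
    depth_single_blocks=4,
    axes_dim=[16, 56, 56],
    theta=10_000,
    use_bias=True,
    guidance_embed=False,
)
```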
@@ -69,7 +75,6 @@ class FluxBackbone(Backbone):
         y_shape=(None, 128),
         **kwargs,
     ):
-
         # === Layers ===
         self.positional_embedder = EmbedND(theta=theta, axes_dim=axes_dim)
         self.image_input_embedder = keras.layers.Dense(
keras_hub/src/models/flux/flux_presets.py
@@ -4,7 +4,8 @@ presets = {
     "schnell": {
         "metadata": {
             "description": (
-                "A 12 billion parameter rectified flow transformer capable of generating images from text descriptions."
+                "A 12 billion parameter rectified flow transformer capable of "
+                "generating images from text descriptions."
             ),
             "params": 124439808,
             "path": "flux",
keras_hub/src/models/flux/flux_text_to_image.py
@@ -24,11 +24,15 @@ class FluxTextToImage(TextToImage):
 
     Use `generate()` to do image generation.
     ```python
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
     text_to_image = keras_hub.models.FluxTextToImage.from_preset(
         "TBA", height=512, width=512
     )
     text_to_image.generate(
-        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+        prompt
     )
 
     # Generate with batched prompts.
@@ -38,7 +42,7 @@ class FluxTextToImage(TextToImage):
 
     # Generate with different `num_steps` and `guidance_scale`.
     text_to_image.generate(
-        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        prompt,
         num_steps=50,
         guidance_scale=5.0,
     )
@@ -46,7 +50,7 @@ class FluxTextToImage(TextToImage):
     # Generate with `negative_prompts`.
     text_to_image.generate(
         {
-            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "prompts": prompt,
             "negative_prompts": "green color",
         }
     )
keras_hub/src/models/gemma/gemma_backbone.py
@@ -44,10 +44,10 @@ class GemmaBackbone(Backbone):
             `hidden_dim / num_query_heads`. Defaults to True.
         use_post_ffw_norm: boolean. Whether to normalize after the feedforward
             block. Defaults to False.
-        use_post_attention_norm: boolean. Whether to normalize after the attention
-            block. Defaults to False.
-        attention_logit_soft_cap: None or int. Soft cap for the attention logits.
-            Defaults to None.
+        use_post_attention_norm: boolean. Whether to normalize after the
+            attention block. Defaults to False.
+        attention_logit_soft_cap: None or int. Soft cap for the attention
+            logits. Defaults to None.
         final_logit_soft_cap: None or int. Soft cap for the final logits.
             Defaults to None.
         use_sliding_window_attention boolean. Whether to use sliding local
@@ -205,7 +205,9 @@ class GemmaBackbone(Backbone):
                 "final_logit_soft_cap": self.final_logit_soft_cap,
                 "attention_logit_soft_cap": self.attention_logit_soft_cap,
                 "sliding_window_size": self.sliding_window_size,
-                "use_sliding_window_attention": self.use_sliding_window_attention,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
             }
         )
         return config
@@ -224,7 +226,8 @@ class GemmaBackbone(Backbone):
 
         Example:
         ```
-        # Feel free to change the mesh shape to balance data and model parallelism
+        # Feel free to change the mesh shape to balance data and model
+        # parallelism
        mesh = keras.distribution.DeviceMesh(
            shape=(1, 8), axis_names=('batch', 'model'),
            devices=keras.distribution.list_devices())
@@ -237,12 +240,16 @@ class GemmaBackbone(Backbone):
        gemma_model = keras_hub.models.GemmaCausalLM.from_preset()
        ```
 
-        To see how the layout map was applied, load the model then run (for one decoder block):
+        To see how the layout map was applied, load the model then run (for one
+        decoder block):
        ```
        embedding_layer = gemma_model.backbone.get_layer("token_embedding")
        decoder_block_1 = gemma_model.backbone.get_layer('decoder_block_1')
        for variable in embedding_layer.weights + decoder_block_1.weights:
-            print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
+            print(
+                f'{variable.path:<58} {str(variable.shape):<16} '
+                f'{str(variable.value.sharding.spec)}'
+            )
        ```
 
        Args:
@@ -257,22 +264,22 @@ class GemmaBackbone(Backbone):
                for all the model weights.
        """
        # The weight path and shape of the Gemma backbone is like below (for 2G)
-        # token_embedding/embeddings, (256128, 2048), 524550144
+        # token_embedding/embeddings, (256128, 2048)
        # repeat block for decoder
        # ...
-        # decoder_block_17/pre_attention_norm/scale, (2048,), 2048
-        # decoder_block_17/attention/query/kernel, (8, 2048, 256), 4194304
-        # decoder_block_17/attention/key/kernel, (8, 2048, 256), 4194304
-        # decoder_block_17/attention/value/kernel, (8, 2048, 256), 4194304
-        # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048), 4194304
-        # decoder_block_17/pre_ffw_norm/scale, (2048,), 2048
-        # decoder_block_17/ffw_gating/kernel, (2048, 16384), 33554432
-        # decoder_block_17/ffw_gating_2/kernel, (2048, 16384), 33554432
-        # decoder_block_17/ffw_linear/kernel, (16384, 2048), 33554432
+        # decoder_block_17/pre_attention_norm/scale, (2048,)
+        # decoder_block_17/attention/query/kernel, (8, 2048, 256)
+        # decoder_block_17/attention/key/kernel, (8, 2048, 256)
+        # decoder_block_17/attention/value/kernel, (8, 2048, 256)
+        # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048)
+        # decoder_block_17/pre_ffw_norm/scale, (2048,)
+        # decoder_block_17/ffw_gating/kernel, (2048, 16384)
+        # decoder_block_17/ffw_gating_2/kernel, (2048, 16384)
+        # decoder_block_17/ffw_linear/kernel, (16384, 2048)
        if not isinstance(device_mesh, keras.distribution.DeviceMesh):
            raise ValueError(
-                "Invalid device_mesh type. Expected `keras.distribution.Device`,"
-                f" got {type(device_mesh)}"
+                "Invalid device_mesh type. Expected "
+                f"`keras.distribution.Device`, got {type(device_mesh)}"
            )
        if model_parallel_dim_name not in device_mesh.axis_names:
            raise ValueError(
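To round out the docstring example above, a sketch of wiring `get_layout_map` into the Keras 3 distribution API. The `ModelParallel` signature has varied across Keras releases, so treat this as an assumption-laden outline rather than the canonical recipe:

```python
import keras
import keras_hub

mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)
layout_map = keras_hub.models.GemmaBackbone.get_layout_map(mesh)
# In recent Keras versions ModelParallel accepts the layout map directly.
distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
```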
keras_hub/src/models/gemma/gemma_causal_lm.py
@@ -187,8 +187,8 @@ class GemmaCausalLM(CausalLM):
         Args:
             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
             cache: a dense float Tensor, the cache of key and value.
-            cache_update_index: int, or int Tensor. The index of current inputs in the
-                whole sequence.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.
 
         Returns:
             A (logits, hidden_states, cache) tuple. Where `logits` is the
keras_hub/src/models/gemma/gemma_decoder_block.py
@@ -220,7 +220,9 @@ class GemmaDecoderBlock(keras.layers.Layer):
                 "use_post_ffw_norm": self.use_post_ffw_norm,
                 "use_post_attention_norm": self.use_post_attention_norm,
                 "logit_soft_cap": self.logit_soft_cap,
-                "use_sliding_window_attention": self.use_sliding_window_attention,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
                 "sliding_window_size": self.sliding_window_size,
                 "query_head_dim_normalize": self.query_head_dim_normalize,
             }
keras_hub/src/models/gemma/gemma_presets.py
@@ -130,7 +130,9 @@ backbone_presets = {
     },
     "gemma2_instruct_2b_en": {
         "metadata": {
-            "description": "2 billion parameter, 26-layer, instruction tuned Gemma model.",
+            "description": (
+                "2 billion parameter, 26-layer, instruction tuned Gemma model."
+            ),
             "params": 2614341888,
             "path": "gemma",
         },
@@ -146,7 +148,9 @@ backbone_presets = {
     },
     "gemma2_instruct_9b_en": {
         "metadata": {
-            "description": "9 billion parameter, 42-layer, instruction tuned Gemma model.",
+            "description": (
+                "9 billion parameter, 42-layer, instruction tuned Gemma model."
+            ),
             "params": 9241705984,
             "path": "gemma",
         },
@@ -162,7 +166,9 @@ backbone_presets = {
     },
     "gemma2_instruct_27b_en": {
         "metadata": {
-            "description": "27 billion parameter, 42-layer, instruction tuned Gemma model.",
+            "description": (
+                "27 billion parameter, 42-layer, instruction tuned Gemma model."
+            ),
             "params": 27227128320,
             "path": "gemma",
         },
keras_hub/src/models/gpt2/gpt2_causal_lm.py
@@ -172,8 +172,8 @@ class GPT2CausalLM(CausalLM):
         Args:
             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
             cache: a dense float Tensor, the cache of key and value.
-            cache_update_index: int, or int Tensor. The index of current inputs in the
-                whole sequence.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.
 
         Returns:
             A (logits, hidden_states, cache) tuple. Where `logits` is the
keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py
@@ -202,7 +202,8 @@ class GPTNeoXAttention(keras.layers.Layer):
             training=training,
         )
 
-        # Reshape `attention_output` to `(batch_size, sequence_length, hidden_dim)`.
+        # Reshape `attention_output` to
+        # `(batch_size, sequence_length, hidden_dim)`.
         attention_output = ops.reshape(
             attention_output,
             [
keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py
@@ -27,9 +27,9 @@ class GPTNeoXCausalLM(CausalLM):
 
     Args:
         backbone: A `keras_hub.models.GPTNeoXBackbone` instance.
-        preprocessor: A `keras_hub.models.GPTNeoXCausalLMPreprocessor` or `None`.
-            If `None`, this model will not apply preprocessing, and inputs
-            should be preprocessed before calling the model.
+        preprocessor: A `keras_hub.models.GPTNeoXCausalLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
     """
 
     backbone_cls = GPTNeoXBackbone
keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py
@@ -16,7 +16,8 @@ class GPTNeoXDecoder(keras.layers.Layer):
 
     This class follows the architecture of the GPT-NeoX decoder layer in the
     paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745).
-    Users can instantiate multiple instances of this class to stack up a decoder.
+    Users can instantiate multiple instances of this class to stack up a
+    decoder.
 
     This layer will always apply a causal mask to the decoder attention layer.
 
keras_hub/src/models/image_classifier_preprocessor.py
@@ -46,7 +46,10 @@ class ImageClassifierPreprocessor(Preprocessor):
     x, y = preprocessor(x, y)
 
     # Resize a batch of labeled images.
-    x, y = [np.random.randint(0, 256, (512, 512, 3)), np.zeros((512, 512, 3))], [1, 0]
+    x, y = [
+        np.random.randint(0, 256, (512, 512, 3)),
+        np.zeros((512, 512, 3))
+    ], [1, 0]
     x, y = preprocessor(x, y)
 
     # Use a `tf.data.Dataset`.
keras_hub/src/models/image_object_detector.py
@@ -31,8 +31,8 @@ class ImageObjectDetector(Task):
     ):
         """Configures the `ImageObjectDetector` task for training.
 
-        The `ImageObjectDetector` task extends the default compilation signature of
-        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        The `ImageObjectDetector` task extends the default compilation signature
+        of `keras.Model.compile` with defaults for `optimizer`, `loss`, and
         `metrics`. To override these defaults, pass any value
         to these arguments during compilation.
 
keras_hub/src/models/image_object_detector_preprocessor.py
@@ -21,10 +21,10 @@ class ImageObjectDetectorPreprocessor(Preprocessor):
         be the a dict of `{"boxes": Tensor(batch_size, num_boxes, 4),
         "classes": (batch_size, num_boxes)}.
 
-    The layer will returns either `x`, an `(x, y)` tuple if labels were provided,
-    or an `(x, y, sample_weight)` tuple if labels and sample weight were
-    provided. `x` will be the input images after all model preprocessing has
-    been applied.
+    The layer will returns either `x`, an `(x, y)` tuple if labels were
+    provided, or an `(x, y, sample_weight)` tuple if labels and sample weight
+    were provided. `x` will be the input images after all model preprocessing
+    has been applied.
 
     All `ImageObjectDetectorPreprocessor` tasks include a `from_preset()`
     constructor which can be used to load a pre-trained config and vocabularies.