keras-hub-nightly 0.19.0.dev202412120352__py3-none-any.whl → 0.19.0.dev202412140350__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. keras_hub/api/layers/__init__.py +1 -0
  2. keras_hub/api/models/__init__.py +11 -6
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/converters.py +2 -2
  5. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  6. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  7. keras_hub/src/layers/modeling/rms_normalization.py +8 -6
  8. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  9. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  10. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  11. keras_hub/src/layers/modeling/transformer_encoder.py +3 -1
  12. keras_hub/src/metrics/bleu.py +1 -1
  13. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  14. keras_hub/src/models/bart/bart_backbone.py +4 -4
  15. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  16. keras_hub/src/models/bert/bert_presets.py +4 -2
  17. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  18. keras_hub/src/models/causal_lm.py +19 -15
  19. keras_hub/src/models/clip/clip_vision_embedding.py +1 -1
  20. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  21. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  22. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  23. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  24. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  25. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  26. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +17 -13
  27. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -3
  28. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +1 -1
  29. keras_hub/src/models/densenet/densenet_backbone.py +3 -1
  30. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -1
  31. keras_hub/src/models/densenet/densenet_presets.py +6 -6
  32. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  33. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  34. keras_hub/src/models/distil_bert/distil_bert_presets.py +2 -1
  35. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  36. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  37. keras_hub/src/models/efficientnet/cba.py +1 -1
  38. keras_hub/src/models/efficientnet/efficientnet_backbone.py +20 -8
  39. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +1 -1
  40. keras_hub/src/models/efficientnet/efficientnet_presets.py +12 -11
  41. keras_hub/src/models/efficientnet/fusedmbconv.py +3 -5
  42. keras_hub/src/models/efficientnet/mbconv.py +1 -1
  43. keras_hub/src/models/electra/electra_backbone.py +2 -2
  44. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  45. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  46. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  47. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  48. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  49. keras_hub/src/models/flux/flux_layers.py +46 -44
  50. keras_hub/src/models/flux/flux_maths.py +24 -17
  51. keras_hub/src/models/flux/flux_model.py +24 -19
  52. keras_hub/src/models/flux/flux_presets.py +2 -1
  53. keras_hub/src/models/flux/flux_text_to_image.py +7 -3
  54. keras_hub/src/models/gemma/gemma_backbone.py +27 -20
  55. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  56. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  57. keras_hub/src/models/gemma/gemma_presets.py +9 -3
  58. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  59. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  60. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  61. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  62. keras_hub/src/models/image_classifier_preprocessor.py +4 -1
  63. keras_hub/src/models/image_object_detector.py +2 -2
  64. keras_hub/src/models/image_object_detector_preprocessor.py +4 -4
  65. keras_hub/src/models/image_segmenter_preprocessor.py +2 -2
  66. keras_hub/src/models/llama/llama_backbone.py +34 -26
  67. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  68. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  69. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  70. keras_hub/src/models/mistral/mistral_causal_lm.py +3 -3
  71. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  72. keras_hub/src/models/mit/mit_backbone.py +4 -3
  73. keras_hub/src/models/mit/mit_layers.py +2 -1
  74. keras_hub/src/models/mobilenet/mobilenet_backbone.py +7 -7
  75. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  76. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +5 -3
  77. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +2 -2
  78. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  79. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  80. keras_hub/src/models/preprocessor.py +2 -2
  81. keras_hub/src/models/retinanet/feature_pyramid.py +3 -2
  82. keras_hub/src/models/retinanet/prediction_head.py +2 -2
  83. keras_hub/src/models/retinanet/retinanet_backbone.py +2 -2
  84. keras_hub/src/models/retinanet/retinanet_image_converter.py +1 -1
  85. keras_hub/src/models/retinanet/retinanet_object_detector.py +5 -6
  86. keras_hub/src/models/retinanet/retinanet_presets.py +2 -1
  87. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  88. keras_hub/src/models/roberta/roberta_presets.py +4 -2
  89. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  90. keras_hub/src/models/sam/sam_backbone.py +2 -2
  91. keras_hub/src/models/sam/sam_image_segmenter.py +6 -5
  92. keras_hub/src/models/sam/sam_layers.py +5 -3
  93. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  94. keras_hub/src/models/sam/sam_transformer.py +5 -4
  95. keras_hub/src/models/segformer/segformer_backbone.py +18 -14
  96. keras_hub/src/models/segformer/segformer_image_segmenter.py +51 -38
  97. keras_hub/src/models/segformer/segformer_presets.py +24 -12
  98. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  99. keras_hub/src/models/stable_diffusion_3/mmdit.py +20 -1
  100. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +1 -1
  101. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +13 -6
  102. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +2 -2
  103. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +7 -3
  104. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  105. keras_hub/src/models/task.py +4 -2
  106. keras_hub/src/models/text_classifier.py +2 -2
  107. keras_hub/src/models/text_to_image.py +5 -1
  108. keras_hub/src/models/vae/vae_layers.py +0 -1
  109. keras_hub/src/models/vit/__init__.py +5 -0
  110. keras_hub/src/models/vit/vit_backbone.py +152 -0
  111. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  112. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  113. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  114. keras_hub/src/models/vit/vit_layers.py +391 -0
  115. keras_hub/src/models/vit/vit_presets.py +49 -0
  116. keras_hub/src/models/vit_det/vit_det_backbone.py +4 -2
  117. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  118. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -3
  119. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  120. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  121. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  122. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  123. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  124. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  125. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  126. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  127. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  128. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  129. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  130. keras_hub/src/samplers/sampler.py +2 -1
  131. keras_hub/src/tests/test_case.py +2 -2
  132. keras_hub/src/tokenizers/byte_pair_tokenizer.py +2 -2
  133. keras_hub/src/tokenizers/byte_tokenizer.py +2 -8
  134. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  135. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +7 -12
  136. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +8 -5
  137. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +7 -3
  138. keras_hub/src/utils/preset_utils.py +25 -18
  139. keras_hub/src/utils/tensor_utils.py +4 -4
  140. keras_hub/src/utils/timm/convert_efficientnet.py +2 -4
  141. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  142. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  143. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  144. keras_hub/src/version_utils.py +1 -1
  145. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/METADATA +1 -1
  146. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/RECORD +148 -140
  147. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/WHEEL +0 -0
  148. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/top_level.txt +0 -0

keras_hub/src/models/vit/vit_layers.py
@@ -0,0 +1,391 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ class MLP(keras.layers.Layer):
+     """Multi-Layer Perceptron (MLP) block.
+
+     Args:
+         hidden_dim: int. Dimensionality of the hidden representations.
+         mlp_dim: int. Dimensionality of the intermediate MLP layer.
+         use_bias: bool. Whether to use bias in the dense layers. Defaults to
+             `True`.
+         dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`.
+         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+     """
+
+     def __init__(
+         self,
+         hidden_dim,
+         mlp_dim,
+         use_bias=True,
+         dropout_rate=0.0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # === Config ===
+         self.hidden_dim = hidden_dim
+         self.mlp_dim = mlp_dim
+         self.use_bias = use_bias
+         self.dropout_rate = dropout_rate
+
+     def build(self, input_shape):
+         self.dense_1 = keras.layers.Dense(
+             units=self.mlp_dim,
+             use_bias=self.use_bias,
+             activation="gelu",
+             bias_initializer=(
+                 keras.initializers.RandomNormal(stddev=1e-6)
+                 if self.use_bias
+                 else None
+             ),
+             dtype=self.dtype_policy,
+             name="dense_1",
+         )
+         self.dense_1.build(input_shape)
+         self.dense_2 = keras.layers.Dense(
+             units=self.hidden_dim,
+             use_bias=self.use_bias,
+             bias_initializer=(
+                 keras.initializers.RandomNormal(stddev=1e-6)
+                 if self.use_bias
+                 else None
+             ),
+             dtype=self.dtype_policy,
+             name="dense_2",
+         )
+         self.dense_2.build((None, None, self.mlp_dim))
+         self.dropout = keras.layers.Dropout(
+             self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+         )
+         self.built = True
+
+     def call(self, inputs):
+         x = self.dense_1(inputs)
+         x = self.dense_2(x)
+         out = self.dropout(x)
+         return out
+
+
+ class ViTPatchingAndEmbedding(keras.layers.Layer):
+     """Patches the image and embeds the patches.
+
+     Args:
+         image_size: int. Size of the input image (height or width).
+             Assumed to be square.
+         patch_size: int. Size of each image patch.
+         hidden_dim: int. Dimensionality of the patch embeddings.
+         num_channels: int. Number of channels in the input image. Defaults to
+             `3`.
+         data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
+             `None` (which uses `"channels_last"`).
+         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+     """
+
+     def __init__(
+         self,
+         image_size,
+         patch_size,
+         hidden_dim,
+         num_channels=3,
+         data_format=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         num_patches = (image_size // patch_size) ** 2
+         num_positions = num_patches + 1
+
+         # === Config ===
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.hidden_dim = hidden_dim
+         self.num_channels = num_channels
+         self.num_patches = num_patches
+         self.num_positions = num_positions
+         self.data_format = standardize_data_format(data_format)
+
+     def build(self, input_shape):
+         self.class_token = self.add_weight(
+             shape=(
+                 1,
+                 1,
+                 self.hidden_dim,
+             ),
+             initializer="random_normal",
+             dtype=self.variable_dtype,
+             name="class_token",
+         )
+         self.patch_embedding = keras.layers.Conv2D(
+             filters=self.hidden_dim,
+             kernel_size=self.patch_size,
+             strides=self.patch_size,
+             padding="valid",
+             activation=None,
+             dtype=self.dtype_policy,
+             data_format=self.data_format,
+             name="patch_embedding",
+         )
+         self.patch_embedding.build(input_shape)
+         self.position_embedding = keras.layers.Embedding(
+             self.num_positions,
+             self.hidden_dim,
+             dtype=self.dtype_policy,
+             embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02),
+             name="position_embedding",
+         )
+         self.position_embedding.build((1, self.num_positions))
+         self.position_ids = keras.ops.expand_dims(
+             keras.ops.arange(self.num_positions), axis=0
+         )
+         self.built = True
+
+     def call(self, inputs):
+         patch_embeddings = self.patch_embedding(inputs)
+         if self.data_format == "channels_first":
+             patch_embeddings = ops.transpose(
+                 patch_embeddings, axes=(0, 2, 3, 1)
+             )
+         embeddings_shape = ops.shape(patch_embeddings)
+         patch_embeddings = ops.reshape(
+             patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
+         )
+         class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
+         position_embeddings = self.position_embedding(self.position_ids)
+         embeddings = ops.concatenate([class_token, patch_embeddings], axis=1)
+         return ops.add(embeddings, position_embeddings)
+
+     def compute_output_shape(self, input_shape):
+         return (
+             input_shape[0],
+             self.num_positions,
+             self.hidden_dim,
+         )
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "image_size": self.image_size,
+                 "patch_size": self.patch_size,
+                 "hidden_dim": self.hidden_dim,
+                 "num_channels": self.num_channels,
+                 "num_patches": self.num_patches,
+                 "num_positions": self.num_positions,
+             }
+         )
+         return config
+
+
+ class ViTEncoderBlock(keras.layers.Layer):
+     """Transformer encoder block.
+
+     Args:
+         num_heads: int. Number of attention heads.
+         hidden_dim: int. Dimensionality of the hidden representations.
+         mlp_dim: int. Dimensionality of the intermediate MLP layer.
+         use_mha_bias: bool. Whether to use bias in the multi-head attention
+             layer. Defaults to `True`.
+         use_mlp_bias: bool. Whether to use bias in the MLP layer. Defaults to
+             `True`.
+         dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`.
+         attention_dropout: float. Dropout rate for the attention mechanism.
+             Between 0 and 1. Defaults to `0.0`.
+         layer_norm_epsilon: float. Small float value for layer normalization
+             stability. Defaults to `1e-6`.
+         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+     """
+
+     def __init__(
+         self,
+         num_heads,
+         hidden_dim,
+         mlp_dim,
+         use_mha_bias=True,
+         use_mlp_bias=True,
+         dropout_rate=0.0,
+         attention_dropout=0.0,
+         layer_norm_epsilon=1e-6,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         key_dim = hidden_dim // num_heads
+
+         # === Config ===
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.key_dim = key_dim
+         self.mlp_dim = mlp_dim
+         self.use_mha_bias = use_mha_bias
+         self.use_mlp_bias = use_mlp_bias
+         self.dropout_rate = dropout_rate
+         self.attention_dropout = attention_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+
+     def build(self, input_shape):
+         # Attention block
+         self.layer_norm_1 = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             name="ln_1",
+             dtype=self.dtype_policy,
+         )
+         self.layer_norm_1.build(input_shape)
+         self.mha = keras.layers.MultiHeadAttention(
+             num_heads=self.num_heads,
+             key_dim=self.key_dim,
+             use_bias=self.use_mha_bias,
+             dropout=self.attention_dropout,
+             name="mha",
+             dtype=self.dtype_policy,
+         )
+         self.mha.build(input_shape, input_shape)
+         self.dropout = keras.layers.Dropout(
+             self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+         )
+
+         # MLP block
+         self.layer_norm_2 = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             name="ln_2",
+             dtype=self.dtype_policy,
+         )
+         self.layer_norm_2.build((None, None, self.hidden_dim))
+         self.mlp = MLP(
+             hidden_dim=self.hidden_dim,
+             mlp_dim=self.mlp_dim,
+             use_bias=self.use_mlp_bias,
+             name="mlp",
+             dtype=self.dtype_policy,
+         )
+         self.mlp.build((None, None, self.hidden_dim))
+         self.built = True
+
+     def call(self, inputs):
+         x = self.layer_norm_1(inputs)
+         x = self.mha(x, x)
+         x = self.dropout(x)
+         x = x + inputs
+
+         y = self.layer_norm_2(x)
+         y = self.mlp(y)
+
+         return x + y
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "key_dim": self.key_dim,
+                 "mlp_dim": self.mlp_dim,
+                 "use_mha_bias": self.use_mha_bias,
+                 "use_mlp_bias": self.use_mlp_bias,
+                 "dropout_rate": self.dropout_rate,
+                 "attention_dropout": self.attention_dropout,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
+
+
+ class ViTEncoder(keras.layers.Layer):
+     """Vision Transformer (ViT) encoder.
+
+     Args:
+         num_layers: int. Number of Transformer encoder blocks.
+         num_heads: int. Number of attention heads.
+         hidden_dim: int. Dimensionality of the hidden representations.
+         mlp_dim: int. Dimensionality of the intermediate MLP layer.
+         use_mha_bias: bool. Whether to use bias in the multi-head attention
+             layers. Defaults to `True`.
+         use_mlp_bias: bool. Whether to use bias in the MLP layers. Defaults to
+             `True`.
+         dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`.
+         attention_dropout: float. Dropout rate for the attention mechanism.
+             Between 0 and 1. Defaults to `0.0`.
+         layer_norm_epsilon: float. Small float value for layer normalization
+             stability. Defaults to `1e-6`.
+         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+     """
+
+     def __init__(
+         self,
+         num_layers,
+         num_heads,
+         hidden_dim,
+         mlp_dim,
+         use_mha_bias=True,
+         use_mlp_bias=True,
+         dropout_rate=0.0,
+         attention_dropout=0.0,
+         layer_norm_epsilon=1e-6,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # === config ===
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.mlp_dim = mlp_dim
+         self.use_mha_bias = use_mha_bias
+         self.use_mlp_bias = use_mlp_bias
+         self.dropout_rate = dropout_rate
+         self.attention_dropout = attention_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+
+     def build(self, input_shape):
+         self.encoder_layers = []
+         for i in range(self.num_layers):
+             encoder_block = ViTEncoderBlock(
+                 num_heads=self.num_heads,
+                 hidden_dim=self.hidden_dim,
+                 mlp_dim=self.mlp_dim,
+                 dropout_rate=self.dropout_rate,
+                 use_mha_bias=self.use_mha_bias,
+                 use_mlp_bias=self.use_mlp_bias,
+                 attention_dropout=self.attention_dropout,
+                 layer_norm_epsilon=self.layer_norm_epsilon,
+                 dtype=self.dtype_policy,
+                 name=f"tranformer_block_{i+1}",
+             )
+             encoder_block.build((None, None, self.hidden_dim))
+             self.encoder_layers.append(encoder_block)
+         self.dropout = keras.layers.Dropout(
+             self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+         )
+         self.layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="ln",
+         )
+         self.layer_norm.build((None, None, self.hidden_dim))
+         self.built = True
+
+     def call(self, inputs):
+         x = self.dropout(inputs)
+         for i in range(self.num_layers):
+             x = self.encoder_layers[i](x)
+         x = self.layer_norm(x)
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_layers": self.num_layers,
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "mlp_dim": self.mlp_dim,
+                 "use_mha_bias": self.use_mha_bias,
+                 "use_mlp_bias": self.use_mlp_bias,
+                 "dropout_rate": self.dropout_rate,
+                 "attention_dropout": self.attention_dropout,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
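
For orientation, here is a minimal sketch (not part of the diff) of how the layers added above compose; the imports assume the new `keras_hub/src/models/vit/vit_layers.py` module shown in the file list, and the ViT-B/16 sizes are illustrative assumptions.

```python
# Minimal sketch, not from the diff: compose the new ViT layers directly.
# Shapes assume ViT-B/16 at 224x224 input resolution.
import keras

from keras_hub.src.models.vit.vit_layers import (
    ViTEncoder,
    ViTPatchingAndEmbedding,
)

images = keras.random.uniform((1, 224, 224, 3))
patch_embed = ViTPatchingAndEmbedding(
    image_size=224, patch_size=16, hidden_dim=768
)
encoder = ViTEncoder(num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072)
x = patch_embed(images)  # (1, 197, 768): 196 patches plus one class token
x = encoder(x)  # (1, 197, 768)
```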

keras_hub/src/models/vit/vit_presets.py
@@ -0,0 +1,49 @@
+ """ViT model preset configurations."""
+
+ # Metadata for loading pretrained model weights.
+ backbone_presets = {
+     "vit_base_patch16_224_imagenet": {
+         "metadata": {
+             "description": (
+                 "ViT-B16 model pre-trained on the ImageNet 1k dataset with "
+                 "image resolution of 224x224 "
+             ),
+             "params": 85798656,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/1",
+     },
+     "vit_base_patch16_384_imagenet": {
+         "metadata": {
+             "description": (
+                 "ViT-B16 model pre-trained on the ImageNet 1k dataset with "
+                 "image resolution of 384x384 "
+             ),
+             "params": 86090496,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/1",
+     },
+     "vit_large_patch16_224_imagenet": {
+         "metadata": {
+             "description": (
+                 "ViT-L16 model pre-trained on the ImageNet 1k dataset with "
+                 "image resolution of 224x224 "
+             ),
+             "params": 303301632,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/1",
+     },
+     "vit_large_patch16_384_imagenet": {
+         "metadata": {
+             "description": (
+                 "ViT-L16 model pre-trained on the ImageNet 1k dataset with "
+                 "image resolution of 384x384 "
+             ),
+             "params": 303690752,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/1",
+     },
+ }
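
These presets map Kaggle handles onto the ViT task and backbone classes added in this release. As a hedged illustration (not part of the diff, and dependent on preset registration in this nightly), loading one of them would typically go through the public `from_preset()` API:

```python
# Hedged sketch, not from the diff: load one of the new ViT presets through
# the standard keras_hub from_preset() entry point.
import keras_hub

image_classifier = keras_hub.models.ImageClassifier.from_preset(
    "vit_base_patch16_224_imagenet"
)
```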

keras_hub/src/models/vit_det/vit_det_backbone.py
@@ -87,7 +87,7 @@ class ViTDetBackbone(Backbone):
          use_rel_pos=True,
          window_size=14,
          layer_norm_epsilon=1e-6,
-         **kwargs
+         **kwargs,
      ):
          # === Functional model ===
          img_input = keras.layers.Input(shape=image_shape, name="images")
@@ -179,7 +179,9 @@ class ViTDetBackbone(Backbone):
                  "use_abs_pos": self.use_abs_pos,
                  "use_rel_pos": self.use_rel_pos,
                  "window_size": self.window_size,
-                 "global_attention_layer_indices": self.global_attention_layer_indices,
+                 "global_attention_layer_indices": (
+                     self.global_attention_layer_indices
+                 ),
                  "layer_norm_epsilon": self.layer_norm_epsilon,
              }
          )

keras_hub/src/models/vit_det/vit_layers.py
@@ -117,7 +117,7 @@ class AddRelativePositionalEmbedding(keras.layers.Layer):
          """Calculate decomposed Relative Positional Embeddings

          The code has been adapted based on
-         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa: E501
+         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

          Args:
              attention_map (tensor): Attention map.
@@ -193,7 +193,7 @@ class MultiHeadAttentionWithRelativePE(keras.layers.Layer):
          use_bias=True,
          use_rel_pos=False,
          input_size=None,
-         **kwargs
+         **kwargs,
      ):
          super().__init__(**kwargs)
          self.num_heads = num_heads
@@ -378,7 +378,7 @@ class WindowedTransformerEncoder(keras.layers.Layer):
          input_size=None,
          activation="gelu",
          layer_norm_epsilon=1e-6,
-         **kwargs
+         **kwargs,
      ):
          super().__init__(**kwargs)
          self.project_dim = project_dim

keras_hub/src/models/whisper/whisper_audio_converter.py
@@ -172,9 +172,7 @@ class WhisperAudioConverter(AudioConverter):
          )

          def tf_log10(x):
-             """
-             Computes log base 10 of input tensor using TensorFlow's natural log operator.
-             """
+             """Computes log base 10 of input tensor using TensorFlow."""
              numerator = tf.math.log(x)
              denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
              return numerator / denominator

keras_hub/src/models/whisper/whisper_backbone.py
@@ -30,9 +30,10 @@ class WhisperBackbone(Backbone):
      It includes the embedding lookups and transformer layers, but not the head
      for predicting the next token.

-     The default constructor gives a fully customizable, randomly initialized Whisper
-     model with any number of layers, heads, and embedding dimensions. To load
-     preset architectures and weights, use the `from_preset()` constructor.
+     The default constructor gives a fully customizable, randomly initialized
+     Whisper model with any number of layers, heads, and embedding dimensions.
+     To load preset architectures and weights, use the `from_preset()`
+     constructor.

      Disclaimer: Pre-trained models are provided on an "as is" basis, without
      warranties or conditions of any kind. The underlying model is provided by a
@@ -53,8 +54,8 @@
          max_encoder_sequence_length: int. The maximum sequence length that the
              audio encoder can consume. Since the second convolutional layer in
              the encoder reduces the sequence length by half (stride of 2), we
-             use `max_encoder_sequence_length // 2` as the sequence length for the
-             positional embedding layer.
+             use `max_encoder_sequence_length // 2` as the sequence length for
+             the positional embedding layer.
          max_decoder_sequence_length: int. The maximum sequence length that the
              text decoder can consume.
          dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use

keras_hub/src/models/whisper/whisper_decoder.py
@@ -14,11 +14,9 @@ class WhisperDecoder(TransformerDecoder):
      """Whisper decoder.

      Inherits from `keras_hub.layers.TransformerDecoder`, and overrides the
-     `build` method to use the
-     `keras_hub.models.whisper.whisper_multi_head_attention.WhisperMultiHeadAttention`
-     layer instead of `keras.layers.MultiHeadAttention` and
-     `keras_hub.models.whisper.whisper_cached_multi_head_attention.WhisperCachedMultiHeadAttention`
-     instead of `keras_hub.layers.cached_multi_head_attention.CachedMultiHeadAttention`.
+     `build` method to use the `WhisperMultiHeadAttention`
+     layer instead of `MultiHeadAttention` and `WhisperCachedMultiHeadAttention`
+     instead of `CachedMultiHeadAttention`.
      """

      def build(

keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py
@@ -9,7 +9,7 @@ from keras_hub.src.models.roberta.roberta_backbone import (
  from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
      XLMRobertaBackbone,
  )
- from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import (
+ from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import (  # noqa: E501
      XLMRobertaMaskedLMPreprocessor,
  )


keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py
@@ -20,8 +20,8 @@ class XLMRobertaMaskedLMPreprocessor(MaskedLMPreprocessor):

      This preprocessing layer will prepare inputs for a masked language modeling
      task. It is primarily intended for use with the
-     `keras_hub.models.XLMRobertaMaskedLM` task model. Preprocessing will occur in
-     multiple steps.
+     `keras_hub.models.XLMRobertaMaskedLM` task model. Preprocessing will occur
+     in multiple steps.

      1. Tokenize any number of input segments using the `tokenizer`.
      2. Pack the inputs together with the appropriate `"<s>"`, `"</s>"` and

keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py
@@ -8,7 +8,7 @@ from keras_hub.src.models.text_classifier import TextClassifier
  from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
      XLMRobertaBackbone,
  )
- from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import (
+ from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import (  # noqa: E501
      XLMRobertaTextClassifierPreprocessor,
  )

@@ -40,9 +40,9 @@ class XLMRobertaTextClassifier(TextClassifier):
      Args:
          backbone: A `keras_hub.models.XLMRobertaBackbone` instance.
          num_classes: int. Number of classes to predict.
-         preprocessor: A `keras_hub.models.XLMRobertaTextClassifierPreprocessor` or `None`. If
-             `None`, this model will not apply preprocessing, and inputs should
-             be preprocessed before calling the model.
+         preprocessor: A `keras_hub.models.XLMRobertaTextClassifierPreprocessor`
+             or `None`. If `None`, this model will not apply preprocessing, and
+             inputs should be preprocessed before calling the model.
          activation: Optional `str` or callable. The activation function to use
              on the model outputs. Set `activation="softmax"` to return output
              probabilities. Defaults to `None`.

keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py
@@ -177,7 +177,8 @@ class XLMRobertaTokenizer(SentencePieceTokenizer):
          # Shift the tokens IDs left by one.
          tokens = tf.subtract(tokens, 1)

-         # Correct `unk_token_id`, `end_token_id`, `start_token_id`, respectively.
+         # Correct `unk_token_id`, `end_token_id`, `start_token_id`,
+         # respectively.
          # Note: The `pad_token_id` is taken as 0 (`unk_token_id`) since the
          # proto does not contain `pad_token_id`. This mapping of the pad token
          # is done automatically by the above subtraction.

keras_hub/src/models/xlnet/relative_attention.py
@@ -64,27 +64,28 @@ def _rel_shift(x, klen=-1):
  class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
      """Two-stream relative self-attention for XLNet.

-     In XLNet, each token has two associated vectors at each self-attention layer,
-     the content stream (h) and the query stream (g). The content stream is the
-     self-attention stream as in Transformer XL and represents the context and
-     content (the token itself). The query stream only has access to contextual
-     information and the position, but not the content.
+     In XLNet, each token has two associated vectors at each self-attention
+     layer, the content stream (h) and the query stream (g). The content stream
+     is the self-attention stream as in Transformer XL and represents the context
+     and content (the token itself). The query stream only has access to
+     contextual information and the position, but not the content.

-     This layer shares the same build signature as `keras.layers.MultiHeadAttention`
-     but has different input/output projections.
+     This layer shares the same build signature as
+     `keras.layers.MultiHeadAttention` but has different input/output
+     projections.

      We use the notations `B`, `T`, `S`, `M`, `L`, `E`, `P`, `dim`, `num_heads`
-     below, where
-     `B` is the batch dimension, `T` is the target sequence length,
+     below, where `B` is the batch dimension, `T` is the target sequence length,
      `S` in the source sequence length, `M` is the length of the state or memory,
      `L` is the length of relative positional encoding, `E` is the last dimension
-     of query input, `P` is the number of predictions, `dim` is the dimensionality
-     of the encoder layers. and `num_heads` is the number of attention heads.
+     of query input, `P` is the number of predictions, `dim` is the
+     dimensionality of the encoder layers. and `num_heads` is the number of
+     attention heads.

      Args:
          content_stream: `Tensor` of shape `[B, T, dim]`.
-         content_attention_bias: Bias `Tensor` for content based attention of shape
-             `[num_heads, dim]`.
+         content_attention_bias: Bias `Tensor` for content based attention of
+             shape `[num_heads, dim]`.
          positional_attention_bias: Bias `Tensor` for position based attention of
              shape `[num_heads, dim]`.
          query_stream: `Tensor` of shape `[B, P, dim]`.
@@ -96,8 +97,8 @@ class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
          segment_encoding: Optional `Tensor` representing the segmentation
              encoding as used in XLNet of shape `[2, num_heads, dim]`.
          segment_attention_bias: Optional trainable bias parameter added to the
-             query had when calculating the segment-based attention score used
-             in XLNet of shape `[num_heads, dim]`.
+             query had when calculating the segment-based attention score used in
+             XLNet of shape `[num_heads, dim]`.
          state: Optional `Tensor` of shape `[B, M, E]`.
              If passed, this is also attended over as in Transformer XL.
          content_attention_mask: a boolean mask of shape `[B, T, S]` that
@@ -336,11 +337,11 @@ class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
          dimension of query input.

          Args:
-             content_stream: The content representation, commonly referred to as h.
-                 This serves a similar role to the standard hidden states in
+             content_stream: The content representation, commonly referred to as
+                 h. This serves a similar role to the standard hidden states in
                  Transformer-XL.
-             content_attention_bias: A trainable bias parameter added to the query
-                 head when calculating the content-based attention score.
+             content_attention_bias: A trainable bias parameter added to the
+                 query head when calculating the content-based attention score.
              positional_attention_bias: A trainable bias parameter added to the
                  query head when calculating the position-based attention score.
              query_stream: The query representation, commonly referred to as g.

keras_hub/src/models/xlnet/xlnet_backbone.py
@@ -49,8 +49,8 @@ class XLNetBackbone(Backbone):
              `[batch_size, sequence_length]`.
          segment_ids: Segment token indices to indicate first and second portions
              of the inputs of shape `[batch_size, sequence_length]`.
-         padding_mask: Mask to avoid performing attention on padding token indices
-             of shape `[batch_size, sequence_length]`.
+         padding_mask: Mask to avoid performing attention on padding token
+             indices of shape `[batch_size, sequence_length]`.

      Example:
      ```python