keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/src/models/mistral/mistral_presets.py

@@ -6,9 +6,7 @@ backbone_presets = {
         "metadata": {
             "description": "Mistral 7B base model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/6",
     },
@@ -16,9 +14,7 @@ backbone_presets = {
         "metadata": {
             "description": "Mistral 7B instruct model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/6",
     },
@@ -26,9 +22,7 @@ backbone_presets = {
         "metadata": {
             "description": "Mistral 7B instruct Version 0.2 model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/1",
     },
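Across the preset files touched in this release, the `official_name` and `model_card` metadata keys are dropped, leaving only `description`, `params`, and `path`. As an illustration (the surrounding `backbone_presets` structure and the `mistral_7b_en` key are reconstructed from the context lines and the Kaggle handle above), the resulting entry would read:

```
backbone_presets = {
    "mistral_7b_en": {
        "metadata": {
            "description": "Mistral 7B base model",
            "params": 7241732096,
            "path": "mistral",
        },
        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/6",
    },
}
```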
keras_hub/src/models/mit/mit_backbone.py

@@ -1,3 +1,14 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import keras
 import numpy as np
 from keras import ops
@@ -12,13 +23,13 @@ from keras_hub.src.models.mit.mit_layers import OverlappingPatchingAndEmbedding
 class MiTBackbone(FeaturePyramidBackbone):
     def __init__(
         self,
-        depths,
+        layerwise_depths,
         num_layers,
-        blockwise_num_heads,
-        blockwise_sr_ratios,
+        layerwise_num_heads,
+        layerwise_sr_ratios,
         max_drop_path_rate,
-        patch_sizes,
-        strides,
+        layerwise_patch_sizes,
+        layerwise_strides,
         image_shape=(None, None, 3),
         hidden_dims=None,
         **kwargs,
@@ -32,12 +43,12 @@ class MiTBackbone(FeaturePyramidBackbone):
         https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)

         Args:
-            depths: The number of transformer encoders to be used per layer in the
+            layerwise_depths: The number of transformer encoders to be used per layer in the
                 network.
            num_layers: int. The number of Transformer layers.
-            blockwise_num_heads: list of integers, the number of heads to use
+            layerwise_num_heads: list of integers, the number of heads to use
                in the attention computation for each layer.
-            blockwise_sr_ratios: list of integers, the sequence reduction
+            layerwise_sr_ratios: list of integers, the sequence reduction
                ratio to perform for each layer on the sequence before key and
                value projections. If set to > 1, a `Conv2D` layer is used to
                reduce the length of the sequence.
@@ -71,7 +82,10 @@ class MiTBackbone(FeaturePyramidBackbone):
         model.fit(images, labels, epochs=3)
         ```
         """
-        dpr = [x for x in np.linspace(0.0, max_drop_path_rate, sum(depths))]
+        dpr = [
+            x
+            for x in np.linspace(0.0, max_drop_path_rate, sum(layerwise_depths))
+        ]

         # === Layers ===
         cur = 0
@@ -82,8 +96,8 @@ class MiTBackbone(FeaturePyramidBackbone):
         for i in range(num_layers):
             patch_embed_layer = OverlappingPatchingAndEmbedding(
                 project_dim=hidden_dims[i],
-                patch_size=patch_sizes[i],
-                stride=strides[i],
+                patch_size=layerwise_patch_sizes[i],
+                stride=layerwise_strides[i],
                 name=f"patch_and_embed_{i}",
             )
             patch_embedding_layers.append(patch_embed_layer)
@@ -91,16 +105,16 @@ class MiTBackbone(FeaturePyramidBackbone):
             transformer_block = [
                 HierarchicalTransformerEncoder(
                     project_dim=hidden_dims[i],
-                    num_heads=blockwise_num_heads[i],
-                    sr_ratio=blockwise_sr_ratios[i],
+                    num_heads=layerwise_num_heads[i],
+                    sr_ratio=layerwise_sr_ratios[i],
                     drop_prob=dpr[cur + k],
                     name=f"hierarchical_encoder_{i}_{k}",
                 )
-                for k in range(depths[i])
+                for k in range(layerwise_depths[i])
             ]
             transformer_blocks.append(transformer_block)
-            cur += depths[i]
-            layer_norms.append(keras.layers.LayerNormalization())
+            cur += layerwise_depths[i]
+            layer_norms.append(keras.layers.LayerNormalization(epsilon=1e-5))

         # === Functional Model ===
         image_input = keras.layers.Input(shape=image_shape)
@@ -109,7 +123,7 @@ class MiTBackbone(FeaturePyramidBackbone):
         for i in range(num_layers):
             # Compute new height/width after the `proj`
             # call in `OverlappingPatchingAndEmbedding`
-            stride = strides[i]
+            stride = layerwise_strides[i]
             new_height, new_width = (
                 int(ops.shape(x)[1] / stride),
                 int(ops.shape(x)[2] / stride),
@@ -127,30 +141,30 @@ class MiTBackbone(FeaturePyramidBackbone):
         super().__init__(inputs=image_input, outputs=x, **kwargs)

         # === Config ===
-        self.depths = depths
+        self.layerwise_depths = layerwise_depths
         self.image_shape = image_shape
         self.hidden_dims = hidden_dims
         self.pyramid_outputs = pyramid_outputs
         self.num_layers = num_layers
-        self.blockwise_num_heads = blockwise_num_heads
-        self.blockwise_sr_ratios = blockwise_sr_ratios
+        self.layerwise_num_heads = layerwise_num_heads
+        self.layerwise_sr_ratios = layerwise_sr_ratios
         self.max_drop_path_rate = max_drop_path_rate
-        self.patch_sizes = patch_sizes
-        self.strides = strides
+        self.layerwise_patch_sizes = layerwise_patch_sizes
+        self.layerwise_strides = layerwise_strides

     def get_config(self):
         config = super().get_config()
         config.update(
             {
-                "depths": self.depths,
+                "layerwise_depths": self.layerwise_depths,
                 "hidden_dims": self.hidden_dims,
                 "image_shape": self.image_shape,
                 "num_layers": self.num_layers,
-                "blockwise_num_heads": self.blockwise_num_heads,
-                "blockwise_sr_ratios": self.blockwise_sr_ratios,
+                "layerwise_num_heads": self.layerwise_num_heads,
+                "layerwise_sr_ratios": self.layerwise_sr_ratios,
                 "max_drop_path_rate": self.max_drop_path_rate,
-                "patch_sizes": self.patch_sizes,
-                "strides": self.strides,
+                "layerwise_patch_sizes": self.layerwise_patch_sizes,
+                "layerwise_strides": self.layerwise_strides,
             }
         )
         return config
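The `MiTBackbone` constructor arguments are renamed from `depths`, `blockwise_num_heads`, `blockwise_sr_ratios`, `patch_sizes`, and `strides` to the `layerwise_*` names above, and the per-stage layer norms now use `epsilon=1e-5`. Code passing the old keyword names will need to switch to the new spellings. A minimal construction sketch with the renamed arguments, assuming this nightly is installed and the `MiTBackbone` export is unchanged; the MiT-B0-style sizes below are illustrative, not an official preset:

```
import numpy as np
from keras_hub.models import MiTBackbone

backbone = MiTBackbone(
    layerwise_depths=[2, 2, 2, 2],          # was `depths`
    num_layers=4,
    layerwise_num_heads=[1, 2, 5, 8],       # was `blockwise_num_heads`
    layerwise_sr_ratios=[8, 4, 2, 1],       # was `blockwise_sr_ratios`
    max_drop_path_rate=0.1,
    layerwise_patch_sizes=[7, 3, 3, 3],     # was `patch_sizes`
    layerwise_strides=[4, 2, 2, 2],         # was `strides`
    hidden_dims=[32, 64, 160, 256],
    image_shape=(224, 224, 3),
)
features = backbone(np.random.rand(1, 224, 224, 3).astype("float32"))
```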
keras_hub/src/models/mit/mit_layers.py

@@ -183,20 +183,21 @@ class SegFormerMultiheadAttention(keras.layers.Layer):
         self.k = keras.layers.Dense(project_dim)
         self.v = keras.layers.Dense(project_dim)
         self.proj = keras.layers.Dense(project_dim)
+        self.dropout = keras.layers.Dropout(0.1)
+        self.proj_drop = keras.layers.Dropout(0.1)

         if sr_ratio > 1:
             self.sr = keras.layers.Conv2D(
                 filters=project_dim,
                 kernel_size=sr_ratio,
                 strides=sr_ratio,
-                padding="same",
             )
-            self.norm = keras.layers.LayerNormalization()
+            self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

     def call(self, x):
         input_shape = ops.shape(x)
         H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1]))
-        B, C = input_shape[0], input_shape[2]
+        B, N, C = input_shape[0], input_shape[1], input_shape[2]

         q = self.q(x)
         q = ops.reshape(
@@ -212,12 +213,11 @@ class SegFormerMultiheadAttention(keras.layers.Layer):

         if self.sr_ratio > 1:
             x = ops.reshape(
-                ops.transpose(x, [0, 2, 1]),
+                x,
                 (B, H, W, C),
             )
             x = self.sr(x)
-            x = ops.reshape(x, [input_shape[0], input_shape[2], -1])
-            x = ops.transpose(x, [0, 2, 1])
+            x = ops.reshape(x, [B, -1, C])
             x = self.norm(x)

         k = self.k(x)
@@ -241,14 +241,16 @@ class SegFormerMultiheadAttention(keras.layers.Layer):

         attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale
         attn = ops.nn.softmax(attn, axis=-1)
+        attn = self.dropout(attn)

         attn = attn @ v
         attn = ops.reshape(
             ops.transpose(attn, [0, 2, 1, 3]),
-            [input_shape[0], input_shape[1], input_shape[2]],
+            [B, N, C],
         )

         x = self.proj(attn)
+        x = self.proj_drop(x)
         return x
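The spatial-reduction path now treats the input as channels-last throughout: the `(B, N, C)` sequence is reshaped directly to a `(B, H, W, C)` feature map, reduced by the strided (and no longer "same"-padded) `Conv2D`, and reshaped back to a `(B, N / sr**2, C)` sequence. A small standalone sketch of that shape flow (illustrative only, not the library code), using the same Keras ops:

```
import numpy as np
from keras import layers, ops

B, H, W, C, sr = 2, 16, 16, 64, 4
x = np.random.rand(B, H * W, C).astype("float32")  # (batch, sequence, channels)

feature_map = ops.reshape(x, (B, H, W, C))          # sequence back to a spatial map
reduced = layers.Conv2D(filters=C, kernel_size=sr, strides=sr)(feature_map)
reduced_seq = ops.reshape(reduced, (B, -1, C))      # (2, 16, 64): N shrinks by sr**2
```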
keras_hub/src/models/mit/mit_presets.py

@@ -18,10 +18,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 8 transformer blocks."
             ),
             "params": 3321962,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_ade20k_512/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_ade20k_512/2",
     },
     "mit_b1_ade20k_512": {
         "metadata": {
@@ -29,10 +28,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 8 transformer blocks."
             ),
             "params": 13156554,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_ade20k_512/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_ade20k_512/2",
     },
     "mit_b2_ade20k_512": {
         "metadata": {
@@ -40,10 +38,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 16 transformer blocks."
             ),
             "params": 24201418,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_ade20k_512/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_ade20k_512/2",
     },
     "mit_b3_ade20k_512": {
         "metadata": {
@@ -51,10 +48,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 28 transformer blocks."
             ),
             "params": 44077258,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_ade20k_512/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_ade20k_512/2",
     },
     "mit_b4_ade20k_512": {
         "metadata": {
@@ -62,10 +58,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 41 transformer blocks."
             ),
             "params": 60847818,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_ade20k_512/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_ade20k_512/2",
     },
     "mit_b5_ade20k_640": {
         "metadata": {
@@ -73,10 +68,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 52 transformer blocks."
             ),
             "params": 81448138,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_512/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_640/2",
     },
     "mit_b0_cityscapes_1024": {
         "metadata": {
@@ -84,10 +78,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 8 transformer blocks."
             ),
             "params": 3321962,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_cityscapes_1024/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_cityscapes_1024/2",
     },
     "mit_b1_cityscapes_1024": {
         "metadata": {
@@ -95,10 +88,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 8 transformer blocks."
             ),
             "params": 13156554,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_cityscapes_1024/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_cityscapes_1024/2",
     },
     "mit_b2_cityscapes_1024": {
         "metadata": {
@@ -106,10 +98,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 16 transformer blocks."
             ),
             "params": 24201418,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_cityscapes_1024/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_cityscapes_1024/2",
     },
     "mit_b3_cityscapes_1024": {
         "metadata": {
@@ -117,10 +108,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 28 transformer blocks."
             ),
             "params": 44077258,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_cityscapes_1024/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_cityscapes_1024/2",
     },
     "mit_b4_cityscapes_1024": {
         "metadata": {
@@ -128,10 +118,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 41 transformer blocks."
             ),
             "params": 60847818,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_cityscapes_1024/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_cityscapes_1024/2",
     },
     "mit_b5_cityscapes_1024": {
         "metadata": {
@@ -139,10 +128,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 52 transformer blocks."
             ),
             "params": 81448138,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_cityscapes_1024/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_cityscapes_1024/2",
     },
 }
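All MiT checkpoint handles are bumped to version 2 on Kaggle, and the `mit_b5_ade20k_640` entry now points at a handle matching its 640-resolution name. The preset names themselves are unchanged, so loading code like the sketch below should keep working, assuming these presets remain registered on `MiTBackbone` as in previous releases:

```
import keras_hub

# The bumped Kaggle handle is resolved internally from the preset name.
backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
```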
keras_hub/src/models/opt/opt_presets.py

@@ -9,9 +9,7 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 125237760,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/opt/keras/opt_125m_en/2",
     },
@@ -24,9 +22,7 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 1315753984,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/opt/keras/opt_1.3b_en/2",
     },
@@ -37,9 +33,7 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 2700000000,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/opt/keras/opt_2.7b_en/2",
     },
@@ -50,9 +44,7 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 6700000000,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/opt/keras/opt_6.7b_en/2",
     },
keras_hub/src/models/pali_gemma/pali_gemma_backbone.py

@@ -48,22 +48,40 @@ class PaliGemmaBackbone(Backbone):
            a two-layer feedforward network for each transformer decoder block.
        head_dim: int. The size of each attention head in the mixed decoder.
        vit_patch_size: int. The size of each square patch in the input image.
-        vit_num_heads: int. The number of attention heads for the vision(image)
+        vit_num_heads: int. The number of attention heads for the vision (image)
            transformer encoder.
        vit_hidden_dim: int. The size of the transformer hidden state at the end
            of each vision transformer layer.
        vit_num_layers: int. The number of vision transformer layers.
        vit_intermediate_dim: int. The output dimension of the first Dense layer
-            in a two-layer feedforward network for vision transformer.
-        vit_pooling: string. The encoded vision embeddings are pooled using the
-            specified polling setting. The accepted values are `"map"`, `"gap"`,
-            `"0"` or `"none"`. Defaults to `"none"`.
+            in a two-layer feedforward network for vision transformer. Defaults
+            to `4304`.
+        vit_pooling: `None` or string. The encoded vision embeddings are pooled
+            using the specified polling setting. The accepted values are
+            `"map"`, `"gap"`, `"0"` or `None`. Defaults to `None`.
        vit_classifier_activation: activation function. The activation that
            is used for final output classification in the vision transformer.
+            Defaults to `None`.
        vit_name: string. The name used for vision transformer layers.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the attention
+            block. Defaults to `False`.
+        attention_logit_soft_cap: `None` or int. Soft cap for the attention
+            logits. Defaults to `None`.
+        final_logit_soft_cap: `None` or int. Soft cap for the final logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
        layer_norm_epsilon: float. The epsilon value user for every layer norm
-            in all transformer blocks.
+            in all transformer blocks. Defaults to `1e-6`.
        dropout: float. Dropout probability for the Transformer decoder blocks.
+            Defaults to `0`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for the models computations and weights. Note that some
            computations, such as softmax and layer normalization will always
@@ -119,6 +137,13 @@ class PaliGemmaBackbone(Backbone):
         vit_pooling=None,
         vit_classifier_activation=None,
         vit_name=None,
+        query_head_dim_normalize=True,
+        use_post_ffw_norm=False,
+        use_post_attention_norm=False,
+        attention_logit_soft_cap=None,
+        final_logit_soft_cap=None,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
@@ -136,6 +161,7 @@ class PaliGemmaBackbone(Backbone):
                 seed=None,
             ),
             dtype=dtype,
+            logit_soft_cap=final_logit_soft_cap,
             name="token_embedding",
         )
         # TODO Remove this. Work around for previous serialization bug.
@@ -155,12 +181,19 @@ class PaliGemmaBackbone(Backbone):
         )
         self.transformer_layers = []
         for i in range(num_layers):
+            sliding_window = use_sliding_window_attention and (i % 2 == 0)
             layer = PaliGemmaDecoderBlock(
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
-                num_query_heads=num_query_heads,
                 head_dim=head_dim,
+                num_query_heads=num_query_heads,
                 num_key_value_heads=num_key_value_heads,
+                query_head_dim_normalize=query_head_dim_normalize,
+                use_post_ffw_norm=use_post_ffw_norm,
+                use_post_attention_norm=use_post_attention_norm,
+                logit_soft_cap=attention_logit_soft_cap,
+                use_sliding_window_attention=sliding_window,
+                sliding_window_size=sliding_window_size,
                 dropout=dropout,
                 dtype=dtype,
                 name=f"decoder_block_{i}",
@@ -173,7 +206,9 @@ class PaliGemmaBackbone(Backbone):
         )

         # === Functional Model ===
-        image_input = self.vit_encoder.inputs[0]
+        image_input = keras.Input(
+            shape=(image_size, image_size, 3), name="images"
+        )
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
         )
@@ -219,7 +254,15 @@ class PaliGemmaBackbone(Backbone):
         self.head_dim = head_dim
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
-        # VIT Params
+        # Gemma2 params
+        self.query_head_dim_normalize = query_head_dim_normalize
+        self.use_post_ffw_norm = use_post_ffw_norm
+        self.use_post_attention_norm = use_post_attention_norm
+        self.attention_logit_soft_cap = attention_logit_soft_cap
+        self.final_logit_soft_cap = final_logit_soft_cap
+        self.sliding_window_size = sliding_window_size
+        self.use_sliding_window_attention = use_sliding_window_attention
+        # ViT params
         self.vit_patch_size = vit_patch_size
         self.vit_num_heads = vit_num_heads
         self.vit_hidden_dim = vit_hidden_dim
@@ -243,8 +286,6 @@ class PaliGemmaBackbone(Backbone):
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "head_dim": self.head_dim,
-                "layer_norm_epsilon": self.layer_norm_epsilon,
-                "dropout": self.dropout,
                 "vit_patch_size": self.vit_patch_size,
                 "vit_num_heads": self.vit_num_heads,
                 "vit_hidden_dim": self.vit_hidden_dim,
@@ -253,6 +294,15 @@ class PaliGemmaBackbone(Backbone):
                 "vit_pooling": self.vit_pooling,
                 "vit_classifier_activation": self.vit_classifier_activation,
                 "vit_name": self.vit_name,
+                "query_head_dim_normalize": self.query_head_dim_normalize,
+                "use_post_ffw_norm": self.use_post_ffw_norm,
+                "use_post_attention_norm": self.use_post_attention_norm,
+                "final_logit_soft_cap": self.final_logit_soft_cap,
+                "attention_logit_soft_cap": self.attention_logit_soft_cap,
+                "sliding_window_size": self.sliding_window_size,
+                "use_sliding_window_attention": self.use_sliding_window_attention,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
             }
         )
         return config
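`PaliGemmaBackbone` now exposes the Gemma2-style options it forwards to each decoder block: query normalization choice, optional post-attention and post-feedforward norms, attention and final logit soft caps, and sliding-window attention (enabled on even-indexed decoder blocks when requested, per the `i % 2 == 0` check above). A construction sketch with toy sizes; the new keyword names follow the diff above, while the remaining required arguments (for example `vocabulary_size`) are assumed from the backbone's existing signature:

```
from keras_hub.models import PaliGemmaBackbone

backbone = PaliGemmaBackbone(
    vocabulary_size=257152,
    image_size=224,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=1,
    hidden_dim=128,
    intermediate_dim=256,
    head_dim=32,
    vit_patch_size=14,
    vit_num_heads=4,
    vit_hidden_dim=128,
    vit_num_layers=2,
    # New Gemma2-style options in this release:
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    attention_logit_soft_cap=50,
    final_logit_soft_cap=30,
    use_sliding_window_attention=True,
    sliding_window_size=4096,
)
```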
keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py

@@ -31,33 +31,25 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
            the attention layer.
        num_key_value_heads: int. The number of heads for the key and value
            projections in the attention layer.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the
+            attention block. Defaults to `False`.
+        logit_soft_cap: `None` or int. Soft cap for the attention logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
        layer_norm_epsilon: float. The epsilon hyperparameter used for layer
-            normalization.
+            normalization. Defaults to `1e-6`.
        dropout: float. The dropout rate for the transformer attention layer.
+            Defaults to `0`.
    """

-    def __init__(
-        self,
-        hidden_dim,
-        intermediate_dim,
-        head_dim,
-        num_query_heads,
-        num_key_value_heads,
-        layer_norm_epsilon=1e-6,
-        dropout=0,
-        **kwargs,
-    ):
-        super().__init__(
-            hidden_dim=hidden_dim,
-            intermediate_dim=intermediate_dim,
-            head_dim=head_dim,
-            num_query_heads=num_query_heads,
-            num_key_value_heads=num_key_value_heads,
-            layer_norm_epsilon=layer_norm_epsilon,
-            dropout=dropout,
-            **kwargs,
-        )
-
    def call(
        self,
        x,
@@ -83,6 +75,9 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
            attention_mask=attention_mask,
        )

+        if self.use_post_attention_norm:
+            attention = self.post_attention_norm(attention)
+
        if self.dropout:
            attention = self.attention_dropout(attention)

@@ -94,6 +89,9 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
        x = keras.activations.gelu(x1, approximate=True) * x2
        x = self.ffw_linear(x)

+        if self.use_post_ffw_norm:
+            x = self.post_ffw_norm(x)
+
        x = x + attention_x

        if cache is not None:
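With its `__init__` override removed, `PaliGemmaDecoderBlock` now inherits the constructor from `GemmaDecoderBlock`, so the new Gemma2-style options pass through unchanged, and `call()` only adds the two optional post-norm branches shown above: the sub-layer output is normalized before it rejoins the residual stream. A minimal standalone sketch of that pattern (illustrative only, not the keras_hub layer):

```
import keras

hidden_dim = 8
x = keras.random.uniform((2, 4, hidden_dim))

ffw = keras.layers.Dense(hidden_dim)
post_ffw_norm = keras.layers.LayerNormalization(epsilon=1e-6)
use_post_ffw_norm = True  # stands in for the new `use_post_ffw_norm` flag

residual = x
y = ffw(x)
if use_post_ffw_norm:
    # Normalize the feedforward output before adding it back to the residual.
    y = post_ffw_norm(y)
y = residual + y
```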