keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +12 -0
- keras_hub/api/models/__init__.py +32 -0
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/rms_normalization.py +34 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
- keras_hub/src/layers/preprocessing/image_converter.py +5 -0
- keras_hub/src/models/albert/albert_presets.py +0 -8
- keras_hub/src/models/bart/bart_presets.py +0 -6
- keras_hub/src/models/bert/bert_presets.py +0 -20
- keras_hub/src/models/bloom/bloom_presets.py +0 -16
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
- keras_hub/src/models/densenet/densenet_backbone.py +1 -1
- keras_hub/src/models/densenet/densenet_presets.py +0 -6
- keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
- keras_hub/src/models/efficientnet/mbconv.py +52 -21
- keras_hub/src/models/electra/electra_presets.py +0 -12
- keras_hub/src/models/f_net/f_net_presets.py +0 -4
- keras_hub/src/models/falcon/falcon_presets.py +0 -2
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +494 -0
- keras_hub/src/models/flux/flux_maths.py +218 -0
- keras_hub/src/models/flux/flux_model.py +231 -0
- keras_hub/src/models/flux/flux_presets.py +14 -0
- keras_hub/src/models/flux/flux_text_to_image.py +142 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_presets.py +0 -40
- keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_to_image.py +16 -10
- keras_hub/src/models/inpaint.py +20 -13
- keras_hub/src/models/llama/llama_backbone.py +1 -1
- keras_hub/src/models/llama/llama_presets.py +5 -15
- keras_hub/src/models/llama3/llama3_presets.py +0 -8
- keras_hub/src/models/mistral/mistral_presets.py +0 -6
- keras_hub/src/models/mit/mit_backbone.py +41 -27
- keras_hub/src/models/mit/mit_layers.py +9 -7
- keras_hub/src/models/mit/mit_presets.py +12 -24
- keras_hub/src/models/opt/opt_presets.py +0 -8
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
- keras_hub/src/models/phi3/phi3_presets.py +0 -4
- keras_hub/src/models/resnet/resnet_presets.py +10 -42
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
- keras_hub/src/models/roberta/roberta_presets.py +0 -4
- keras_hub/src/models/sam/sam_backbone.py +0 -1
- keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
- keras_hub/src/models/sam/sam_presets.py +0 -6
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +163 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +124 -0
- keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +41 -13
- keras_hub/src/models/text_to_image.py +13 -5
- keras_hub/src/models/vgg/vgg_backbone.py +1 -1
- keras_hub/src/models/vgg/vgg_presets.py +0 -8
- keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
- keras_hub/src/models/whisper/whisper_presets.py +0 -20
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
- keras_hub/src/tests/test_case.py +25 -0
- keras_hub/src/utils/preset_utils.py +17 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/src/models/mistral/mistral_presets.py

@@ -6,9 +6,7 @@ backbone_presets = {
         "metadata": {
             "description": "Mistral 7B base model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/6",
     },
@@ -16,9 +14,7 @@ backbone_presets = {
         "metadata": {
             "description": "Mistral 7B instruct model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/6",
     },
@@ -26,9 +22,7 @@ backbone_presets = {
         "metadata": {
             "description": "Mistral 7B instruct Version 0.2 model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
         "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/1",
     },
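Net effect of these preset hunks: the `official_name` and `model_card` metadata keys are removed, leaving only `description`, `params`, and `path`. A minimal sketch of the trimmed entry shape, with values copied from the hunk above:

```python
# Shape of a preset entry after this change; "official_name" and
# "model_card" are no longer present in the metadata dict.
backbone_presets = {
    "mistral_7b_en": {
        "metadata": {
            "description": "Mistral 7B base model",
            "params": 7241732096,
            "path": "mistral",
        },
        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/6",
    },
}
```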
keras_hub/src/models/mit/mit_backbone.py

@@ -1,3 +1,14 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import keras
 import numpy as np
 from keras import ops
@@ -12,13 +23,13 @@ from keras_hub.src.models.mit.mit_layers import OverlappingPatchingAndEmbedding
 class MiTBackbone(FeaturePyramidBackbone):
     def __init__(
         self,
-
+        layerwise_depths,
         num_layers,
-
-
+        layerwise_num_heads,
+        layerwise_sr_ratios,
         max_drop_path_rate,
-
-
+        layerwise_patch_sizes,
+        layerwise_strides,
         image_shape=(None, None, 3),
         hidden_dims=None,
         **kwargs,
@@ -32,12 +43,12 @@ class MiTBackbone(FeaturePyramidBackbone):
         https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)

         Args:
-
+            layerwise_depths: The number of transformer encoders to be used per layer in the
                 network.
             num_layers: int. The number of Transformer layers.
-
+            layerwise_num_heads: list of integers, the number of heads to use
                 in the attention computation for each layer.
-
+            layerwise_sr_ratios: list of integers, the sequence reduction
                 ratio to perform for each layer on the sequence before key and
                 value projections. If set to > 1, a `Conv2D` layer is used to
                 reduce the length of the sequence.
@@ -71,7 +82,10 @@ class MiTBackbone(FeaturePyramidBackbone):
        model.fit(images, labels, epochs=3)
        ```
        """
-        dpr = [
+        dpr = [
+            x
+            for x in np.linspace(0.0, max_drop_path_rate, sum(layerwise_depths))
+        ]

         # === Layers ===
         cur = 0
@@ -82,8 +96,8 @@ class MiTBackbone(FeaturePyramidBackbone):
         for i in range(num_layers):
             patch_embed_layer = OverlappingPatchingAndEmbedding(
                 project_dim=hidden_dims[i],
-                patch_size=
-                stride=
+                patch_size=layerwise_patch_sizes[i],
+                stride=layerwise_strides[i],
                 name=f"patch_and_embed_{i}",
             )
             patch_embedding_layers.append(patch_embed_layer)
@@ -91,16 +105,16 @@ class MiTBackbone(FeaturePyramidBackbone):
             transformer_block = [
                 HierarchicalTransformerEncoder(
                     project_dim=hidden_dims[i],
-                    num_heads=
-                    sr_ratio=
+                    num_heads=layerwise_num_heads[i],
+                    sr_ratio=layerwise_sr_ratios[i],
                     drop_prob=dpr[cur + k],
                     name=f"hierarchical_encoder_{i}_{k}",
                 )
-                for k in range(
+                for k in range(layerwise_depths[i])
             ]
             transformer_blocks.append(transformer_block)
-            cur +=
-            layer_norms.append(keras.layers.LayerNormalization())
+            cur += layerwise_depths[i]
+            layer_norms.append(keras.layers.LayerNormalization(epsilon=1e-5))

         # === Functional Model ===
         image_input = keras.layers.Input(shape=image_shape)
@@ -109,7 +123,7 @@ class MiTBackbone(FeaturePyramidBackbone):
         for i in range(num_layers):
             # Compute new height/width after the `proj`
             # call in `OverlappingPatchingAndEmbedding`
-            stride =
+            stride = layerwise_strides[i]
             new_height, new_width = (
                 int(ops.shape(x)[1] / stride),
                 int(ops.shape(x)[2] / stride),
@@ -127,30 +141,30 @@ class MiTBackbone(FeaturePyramidBackbone):
         super().__init__(inputs=image_input, outputs=x, **kwargs)

         # === Config ===
-        self.
+        self.layerwise_depths = layerwise_depths
         self.image_shape = image_shape
         self.hidden_dims = hidden_dims
         self.pyramid_outputs = pyramid_outputs
         self.num_layers = num_layers
-        self.
-        self.
+        self.layerwise_num_heads = layerwise_num_heads
+        self.layerwise_sr_ratios = layerwise_sr_ratios
         self.max_drop_path_rate = max_drop_path_rate
-        self.
-        self.
+        self.layerwise_patch_sizes = layerwise_patch_sizes
+        self.layerwise_strides = layerwise_strides

     def get_config(self):
         config = super().get_config()
         config.update(
             {
-                "
+                "layerwise_depths": self.layerwise_depths,
                 "hidden_dims": self.hidden_dims,
                 "image_shape": self.image_shape,
                 "num_layers": self.num_layers,
-                "
-                "
+                "layerwise_num_heads": self.layerwise_num_heads,
+                "layerwise_sr_ratios": self.layerwise_sr_ratios,
                 "max_drop_path_rate": self.max_drop_path_rate,
-                "
-                "
+                "layerwise_patch_sizes": self.layerwise_patch_sizes,
+                "layerwise_strides": self.layerwise_strides,
             }
         )
         return config
keras_hub/src/models/mit/mit_layers.py

@@ -183,20 +183,21 @@ class SegFormerMultiheadAttention(keras.layers.Layer):
         self.k = keras.layers.Dense(project_dim)
         self.v = keras.layers.Dense(project_dim)
         self.proj = keras.layers.Dense(project_dim)
+        self.dropout = keras.layers.Dropout(0.1)
+        self.proj_drop = keras.layers.Dropout(0.1)

         if sr_ratio > 1:
             self.sr = keras.layers.Conv2D(
                 filters=project_dim,
                 kernel_size=sr_ratio,
                 strides=sr_ratio,
-                padding="same",
             )
-            self.norm = keras.layers.LayerNormalization()
+            self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

     def call(self, x):
         input_shape = ops.shape(x)
         H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1]))
-        B, C = input_shape[0], input_shape[2]
+        B, N, C = input_shape[0], input_shape[1], input_shape[2]

         q = self.q(x)
         q = ops.reshape(
@@ -212,12 +213,11 @@ class SegFormerMultiheadAttention(keras.layers.Layer):

         if self.sr_ratio > 1:
             x = ops.reshape(
-
+                x,
                 (B, H, W, C),
             )
             x = self.sr(x)
-            x = ops.reshape(x, [
-            x = ops.transpose(x, [0, 2, 1])
+            x = ops.reshape(x, [B, -1, C])
             x = self.norm(x)

         k = self.k(x)
@@ -241,14 +241,16 @@ class SegFormerMultiheadAttention(keras.layers.Layer):

         attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale
         attn = ops.nn.softmax(attn, axis=-1)
+        attn = self.dropout(attn)

         attn = attn @ v
         attn = ops.reshape(
             ops.transpose(attn, [0, 2, 1, 3]),
-            [
+            [B, N, C],
         )

         x = self.proj(attn)
+        x = self.proj_drop(x)
         return x

keras_hub/src/models/mit/mit_presets.py

@@ -18,10 +18,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 8 transformer blocks."
             ),
             "params": 3321962,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_ade20k_512/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_ade20k_512/2",
     },
     "mit_b1_ade20k_512": {
         "metadata": {
@@ -29,10 +28,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 8 transformer blocks."
             ),
             "params": 13156554,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_ade20k_512/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_ade20k_512/2",
     },
     "mit_b2_ade20k_512": {
         "metadata": {
@@ -40,10 +38,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 16 transformer blocks."
             ),
             "params": 24201418,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_ade20k_512/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_ade20k_512/2",
     },
     "mit_b3_ade20k_512": {
         "metadata": {
@@ -51,10 +48,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 28 transformer blocks."
             ),
             "params": 44077258,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_ade20k_512/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_ade20k_512/2",
     },
     "mit_b4_ade20k_512": {
         "metadata": {
@@ -62,10 +58,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 41 transformer blocks."
             ),
             "params": 60847818,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_ade20k_512/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_ade20k_512/2",
     },
     "mit_b5_ade20k_640": {
         "metadata": {
@@ -73,10 +68,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 52 transformer blocks."
             ),
             "params": 81448138,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_640/2",
     },
     "mit_b0_cityscapes_1024": {
         "metadata": {
@@ -84,10 +78,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 8 transformer blocks."
             ),
             "params": 3321962,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_cityscapes_1024/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_cityscapes_1024/2",
     },
     "mit_b1_cityscapes_1024": {
         "metadata": {
@@ -95,10 +88,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 8 transformer blocks."
             ),
             "params": 13156554,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_cityscapes_1024/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_cityscapes_1024/2",
     },
     "mit_b2_cityscapes_1024": {
         "metadata": {
@@ -106,10 +98,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 16 transformer blocks."
             ),
             "params": 24201418,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_cityscapes_1024/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_cityscapes_1024/2",
     },
     "mit_b3_cityscapes_1024": {
         "metadata": {
@@ -117,10 +108,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 28 transformer blocks."
             ),
             "params": 44077258,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_cityscapes_1024/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_cityscapes_1024/2",
     },
     "mit_b4_cityscapes_1024": {
         "metadata": {
@@ -128,10 +118,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 41 transformer blocks."
             ),
             "params": 60847818,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_cityscapes_1024/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_cityscapes_1024/2",
     },
     "mit_b5_cityscapes_1024": {
         "metadata": {
@@ -139,10 +128,9 @@ backbone_presets_with_weights = {
                 "MiT (MixTransformer) model with 52 transformer blocks."
             ),
             "params": 81448138,
-            "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_cityscapes_1024/
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_cityscapes_1024/2",
     },
 }

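All MiT presets above now point at version `/2` of their Kaggle handles and drop `official_name` from the metadata. Loading by name is unchanged; a small sketch, assuming the backbone class consumes these `backbone_presets_with_weights` entries:

```python
from keras_hub.src.models.mit.mit_backbone import MiTBackbone

# Resolves "mit_b0_ade20k_512" to the "/2" Kaggle handle listed above and
# downloads/caches the converted weights on first use.
backbone = MiTBackbone.from_preset("mit_b0_ade20k_512")
```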
keras_hub/src/models/opt/opt_presets.py

@@ -9,9 +9,7 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 125237760,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/opt/keras/opt_125m_en/2",
     },
@@ -24,9 +22,7 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 1315753984,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/opt/keras/opt_1.3b_en/2",
     },
@@ -37,9 +33,7 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 2700000000,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/opt/keras/opt_2.7b_en/2",
     },
@@ -50,9 +44,7 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 6700000000,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/opt/keras/opt_6.7b_en/2",
     },
keras_hub/src/models/pali_gemma/pali_gemma_backbone.py

@@ -48,22 +48,40 @@ class PaliGemmaBackbone(Backbone):
             a two-layer feedforward network for each transformer decoder block.
         head_dim: int. The size of each attention head in the mixed decoder.
         vit_patch_size: int. The size of each square patch in the input image.
-        vit_num_heads: int. The number of attention heads for the vision(image)
+        vit_num_heads: int. The number of attention heads for the vision (image)
             transformer encoder.
         vit_hidden_dim: int. The size of the transformer hidden state at the end
             of each vision transformer layer.
         vit_num_layers: int. The number of vision transformer layers.
         vit_intermediate_dim: int. The output dimension of the first Dense layer
-            in a two-layer feedforward network for vision transformer.
-
-
-
+            in a two-layer feedforward network for vision transformer. Defaults
+            to `4304`.
+        vit_pooling: `None` or string. The encoded vision embeddings are pooled
+            using the specified polling setting. The accepted values are
+            `"map"`, `"gap"`, `"0"` or `None`. Defaults to `None`.
         vit_classifier_activation: activation function. The activation that
             is used for final output classification in the vision transformer.
+            Defaults to `None`.
         vit_name: string. The name used for vision transformer layers.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the attention
+            block. Defaults to `False`.
+        attention_logit_soft_cap: `None` or int. Soft cap for the attention
+            logits. Defaults to `None`.
+        final_logit_soft_cap: `None` or int. Soft cap for the final logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
         layer_norm_epsilon: float. The epsilon value user for every layer norm
-            in all transformer blocks.
+            in all transformer blocks. Defaults to `1e-6`.
         dropout: float. Dropout probability for the Transformer decoder blocks.
+            Defaults to `0`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for the models computations and weights. Note that some
             computations, such as softmax and layer normalization will always
@@ -119,6 +137,13 @@ class PaliGemmaBackbone(Backbone):
         vit_pooling=None,
         vit_classifier_activation=None,
         vit_name=None,
+        query_head_dim_normalize=True,
+        use_post_ffw_norm=False,
+        use_post_attention_norm=False,
+        attention_logit_soft_cap=None,
+        final_logit_soft_cap=None,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
@@ -136,6 +161,7 @@ class PaliGemmaBackbone(Backbone):
                 seed=None,
             ),
             dtype=dtype,
+            logit_soft_cap=final_logit_soft_cap,
             name="token_embedding",
         )
         # TODO Remove this. Work around for previous serialization bug.
@@ -155,12 +181,19 @@ class PaliGemmaBackbone(Backbone):
         )
         self.transformer_layers = []
         for i in range(num_layers):
+            sliding_window = use_sliding_window_attention and (i % 2 == 0)
             layer = PaliGemmaDecoderBlock(
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
-                num_query_heads=num_query_heads,
                 head_dim=head_dim,
+                num_query_heads=num_query_heads,
                 num_key_value_heads=num_key_value_heads,
+                query_head_dim_normalize=query_head_dim_normalize,
+                use_post_ffw_norm=use_post_ffw_norm,
+                use_post_attention_norm=use_post_attention_norm,
+                logit_soft_cap=attention_logit_soft_cap,
+                use_sliding_window_attention=sliding_window,
+                sliding_window_size=sliding_window_size,
                 dropout=dropout,
                 dtype=dtype,
                 name=f"decoder_block_{i}",
@@ -173,7 +206,9 @@ class PaliGemmaBackbone(Backbone):
         )

         # === Functional Model ===
-        image_input =
+        image_input = keras.Input(
+            shape=(image_size, image_size, 3), name="images"
+        )
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
         )
@@ -219,7 +254,15 @@ class PaliGemmaBackbone(Backbone):
         self.head_dim = head_dim
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
-        #
+        # Gemma2 params
+        self.query_head_dim_normalize = query_head_dim_normalize
+        self.use_post_ffw_norm = use_post_ffw_norm
+        self.use_post_attention_norm = use_post_attention_norm
+        self.attention_logit_soft_cap = attention_logit_soft_cap
+        self.final_logit_soft_cap = final_logit_soft_cap
+        self.sliding_window_size = sliding_window_size
+        self.use_sliding_window_attention = use_sliding_window_attention
+        # ViT params
         self.vit_patch_size = vit_patch_size
         self.vit_num_heads = vit_num_heads
         self.vit_hidden_dim = vit_hidden_dim
@@ -243,8 +286,6 @@ class PaliGemmaBackbone(Backbone):
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "head_dim": self.head_dim,
-                "layer_norm_epsilon": self.layer_norm_epsilon,
-                "dropout": self.dropout,
                 "vit_patch_size": self.vit_patch_size,
                 "vit_num_heads": self.vit_num_heads,
                 "vit_hidden_dim": self.vit_hidden_dim,
@@ -253,6 +294,15 @@ class PaliGemmaBackbone(Backbone):
                 "vit_pooling": self.vit_pooling,
                 "vit_classifier_activation": self.vit_classifier_activation,
                 "vit_name": self.vit_name,
+                "query_head_dim_normalize": self.query_head_dim_normalize,
+                "use_post_ffw_norm": self.use_post_ffw_norm,
+                "use_post_attention_norm": self.use_post_attention_norm,
+                "final_logit_soft_cap": self.final_logit_soft_cap,
+                "attention_logit_soft_cap": self.attention_logit_soft_cap,
+                "sliding_window_size": self.sliding_window_size,
+                "use_sliding_window_attention": self.use_sliding_window_attention,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
             }
         )
         return config
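The backbone now accepts Gemma2-style options: post-attention and post-feedforward norms, attention and final logit soft-capping, and sliding-window attention interleaved on every other layer. A construction sketch; the sizes below are made-up toy values and the set of required arguments is inferred from the docstring above, so treat this as illustrative rather than a canonical configuration:

```python
from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
    PaliGemmaBackbone,
)

# Toy-sized, assumed configuration; only the Gemma2-related keyword names
# are taken from this diff.
backbone = PaliGemmaBackbone(
    vocabulary_size=1000,
    image_size=224,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=1,
    hidden_dim=128,
    intermediate_dim=256,
    head_dim=32,
    vit_patch_size=14,
    vit_num_heads=2,
    vit_hidden_dim=64,
    vit_num_layers=2,
    # New Gemma2-style options introduced in this version:
    query_head_dim_normalize=True,
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    attention_logit_soft_cap=50,
    final_logit_soft_cap=30,
    use_sliding_window_attention=True,
    sliding_window_size=4096,
)
```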
keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py

@@ -31,33 +31,25 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
             the attention layer.
         num_key_value_heads: int. The number of heads for the key and value
             projections in the attention layer.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the
+            attention block. Defaults to `False`.
+        logit_soft_cap: `None` or int. Soft cap for the attention logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
         layer_norm_epsilon: float. The epsilon hyperparameter used for layer
-            normalization.
+            normalization. Defaults to `1e-6`.
         dropout: float. The dropout rate for the transformer attention layer.
+            Defaults to `0`.
     """

-    def __init__(
-        self,
-        hidden_dim,
-        intermediate_dim,
-        head_dim,
-        num_query_heads,
-        num_key_value_heads,
-        layer_norm_epsilon=1e-6,
-        dropout=0,
-        **kwargs,
-    ):
-        super().__init__(
-            hidden_dim=hidden_dim,
-            intermediate_dim=intermediate_dim,
-            head_dim=head_dim,
-            num_query_heads=num_query_heads,
-            num_key_value_heads=num_key_value_heads,
-            layer_norm_epsilon=layer_norm_epsilon,
-            dropout=dropout,
-            **kwargs,
-        )
-
     def call(
         self,
         x,
@@ -83,6 +75,9 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
             attention_mask=attention_mask,
         )

+        if self.use_post_attention_norm:
+            attention = self.post_attention_norm(attention)
+
         if self.dropout:
             attention = self.attention_dropout(attention)

@@ -94,6 +89,9 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
         x = keras.activations.gelu(x1, approximate=True) * x2
         x = self.ffw_linear(x)

+        if self.use_post_ffw_norm:
+            x = self.post_ffw_norm(x)
+
         x = x + attention_x

         if cache is not None: