keras-hub-nightly 0.19.0.dev202503060350__py3-none-any.whl → 0.20.0.dev202503140353__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. keras_hub/api/layers/__init__.py +3 -0
  2. keras_hub/api/models/__init__.py +5 -4
  3. keras_hub/src/models/cspnet/__init__.py +5 -0
  4. keras_hub/src/models/cspnet/cspnet_backbone.py +1279 -0
  5. keras_hub/src/models/cspnet/cspnet_image_classifier.py +12 -0
  6. keras_hub/src/models/cspnet/cspnet_image_classifier_preprocessor.py +14 -0
  7. keras_hub/src/models/cspnet/cspnet_image_converter.py +8 -0
  8. keras_hub/src/models/cspnet/cspnet_presets.py +16 -0
  9. keras_hub/src/models/gemma/gemma_attention.py +23 -12
  10. keras_hub/src/models/mobilenet/mobilenet_backbone.py +18 -1
  11. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +4 -1
  12. keras_hub/src/models/mobilenet/mobilenet_presets.py +38 -2
  13. keras_hub/src/models/siglip/siglip_presets.py +206 -10
  14. keras_hub/src/models/siglip/siglip_text_encoder.py +7 -1
  15. keras_hub/src/utils/keras_utils.py +32 -0
  16. keras_hub/src/utils/preset_utils.py +1 -0
  17. keras_hub/src/utils/timm/convert_cspnet.py +165 -0
  18. keras_hub/src/utils/timm/convert_mobilenet.py +120 -44
  19. keras_hub/src/utils/timm/preset_loader.py +9 -0
  20. keras_hub/src/version_utils.py +1 -1
  21. {keras_hub_nightly-0.19.0.dev202503060350.dist-info → keras_hub_nightly-0.20.0.dev202503140353.dist-info}/METADATA +1 -1
  22. {keras_hub_nightly-0.19.0.dev202503060350.dist-info → keras_hub_nightly-0.20.0.dev202503140353.dist-info}/RECORD +24 -20
  23. {keras_hub_nightly-0.19.0.dev202503060350.dist-info → keras_hub_nightly-0.20.0.dev202503140353.dist-info}/WHEEL +1 -1
  24. keras_hub/src/models/csp_darknet/__init__.py +0 -0
  25. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +0 -427
  26. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -10
  27. {keras_hub_nightly-0.19.0.dev202503060350.dist-info → keras_hub_nightly-0.20.0.dev202503140353.dist-info}/top_level.txt +0 -0

keras_hub/src/models/cspnet/cspnet_image_classifier.py

@@ -0,0 +1,12 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone
+ from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import (
+     CSPNetImageClassifierPreprocessor,
+ )
+ from keras_hub.src.models.image_classifier import ImageClassifier
+
+
+ @keras_hub_export("keras_hub.models.CSPNetImageClassifier")
+ class CSPNetImageClassifier(ImageClassifier):
+     backbone_cls = CSPNetBackbone
+     preprocessor_cls = CSPNetImageClassifierPreprocessor

keras_hub/src/models/cspnet/cspnet_image_classifier_preprocessor.py

@@ -0,0 +1,14 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone
+ from keras_hub.src.models.cspnet.cspnet_image_converter import (
+     CSPNetImageConverter,
+ )
+ from keras_hub.src.models.image_classifier_preprocessor import (
+     ImageClassifierPreprocessor,
+ )
+
+
+ @keras_hub_export("keras_hub.models.CSPNetImageClassifierPreprocessor")
+ class CSPNetImageClassifierPreprocessor(ImageClassifierPreprocessor):
+     backbone_cls = CSPNetBackbone
+     image_converter_cls = CSPNetImageConverter

keras_hub/src/models/cspnet/cspnet_image_converter.py

@@ -0,0 +1,8 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+ from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone
+
+
+ @keras_hub_export("keras_hub.layers.CSPNetImageConverter")
+ class CSPNetImageConverter(ImageConverter):
+     backbone_cls = CSPNetBackbone

keras_hub/src/models/cspnet/cspnet_presets.py

@@ -0,0 +1,16 @@
+ """CSPNet preset configurations."""
+
+ backbone_presets = {
+     "csp_darknet_53_ra_imagenet": {
+         "metadata": {
+             "description": (
+                 "A CSP-DarkNet (Cross-Stage-Partial) image classification model"
+                 " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
+                 "a 224x224 resolution."
+             ),
+             "params": 26652512,
+             "path": "cspnet",
+         },
+         "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet/1",
+     },
+ }
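
The new preset is wired up through the CSPNet task classes added above, so it should be loadable with the usual KerasHub `from_preset` flow. A minimal sketch (hedged: it assumes the preset name registers exactly as listed and that the Kaggle weights download on first use):

import numpy as np
import keras_hub

# Hypothetical end-to-end use of the new CSPNet preset.
classifier = keras_hub.models.CSPNetImageClassifier.from_preset(
    "csp_darknet_53_ra_imagenet"
)
images = np.random.uniform(0, 255, size=(1, 224, 224, 3)).astype("float32")
print(classifier.predict(images).shape)  # expected (1, 1000) for an ImageNet-1k head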

keras_hub/src/models/gemma/gemma_attention.py

@@ -1,3 +1,5 @@
+ import inspect
+
  import keras
  import numpy as np
  from keras import ops
@@ -5,6 +7,7 @@ from keras import ops
  from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
  from keras_hub.src.utils.keras_utils import clone_initializer
  from keras_hub.src.utils.keras_utils import has_flash_attention_support
+ from keras_hub.src.utils.keras_utils import running_on_tpu


  class CachedGemmaAttention(keras.layers.Layer):
@@ -103,6 +106,18 @@ class CachedGemmaAttention(keras.layers.Layer):
          )
          return x

+     def _can_use_flash_attention(self):
+         if not has_flash_attention_support():
+             return False
+         if self.dropout > 0.0:
+             return False
+         if self.logit_soft_cap is None:
+             return True
+         sig = inspect.signature(ops.dot_product_attention)
+         # We can currently only run soft capped attention for keras >= 3.10
+         # and only on TPU.
+         return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+
      def _compute_attention(
          self,
          q,
@@ -118,27 +133,23 @@ class CachedGemmaAttention(keras.layers.Layer):
          query_normalization = 1 / np.sqrt(
              self.hidden_dim // self.num_query_heads
          )
-         use_dot_product_attention = not (
-             self.dropout > 0.0 or (len(q.shape) != 4)
-         )
-         if has_flash_attention_support() and use_dot_product_attention:
-             if self.dropout > 0.0:
-                 raise ValueError(
-                     "Flash attention does not support dropout. "
-                     "Please set `dropout` to 0.0."
-                 )
+         if self._can_use_flash_attention():
              if attention_mask is not None:
                  attention_mask = ops.expand_dims(attention_mask, axis=1)
                  attention_mask = ops.cast(attention_mask, dtype="bool")
-
-             attention_output = ops.dot_product_attention(
+             # Only pass soft cap if needed as not all keras versions support.
+             if self.logit_soft_cap:
+                 kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
+             else:
+                 kwargs = {}
+             return ops.dot_product_attention(
                  query=q,
                  key=k,
                  value=v,
                  mask=attention_mask,
                  scale=query_normalization,
+                 **kwargs,
              )
-             return attention_output

          q *= ops.cast(query_normalization, dtype=q.dtype)
          q_shape = ops.shape(q)
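
The new `_can_use_flash_attention` helper feature-detects the `attn_logits_soft_cap` keyword instead of pinning a Keras version. The same check can be run standalone; a minimal sketch, assuming a Keras 3 build that exposes `ops.dot_product_attention` (the ">= 3.10" note is taken from the comment in the diff above):

import inspect

from keras import ops

# Soft-capped flash attention is only attempted when the installed Keras
# accepts the extra keyword argument.
sig = inspect.signature(ops.dot_product_attention)
print("attn_logits_soft_cap" in sig.parameters)  # True on Keras >= 3.10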

keras_hub/src/models/mobilenet/mobilenet_backbone.py

@@ -142,6 +142,8 @@ class DepthwiseConvBlock(keras.layers.Layer):
              signal into before reexciting back out. If (>1) technically, it's an
              excite & squeeze layer. If this doesn't exist there is no
              SqueezeExcite layer.
+         residual: bool, default False. True if we want a residual connection. If
+             False, there is no residual connection.
          name: str, name of the layer
          dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
              to use for the model's computations and weights.
@@ -161,6 +163,7 @@ class DepthwiseConvBlock(keras.layers.Layer):
          kernel_size=3,
          stride=2,
          squeeze_excite_ratio=None,
+         residual=False,
          name=None,
          dtype=None,
          **kwargs,
@@ -171,6 +174,7 @@ class DepthwiseConvBlock(keras.layers.Layer):
          self.kernel_size = kernel_size
          self.stride = stride
          self.squeeze_excite_ratio = squeeze_excite_ratio
+         self.residual = residual
          self.name = name

          channel_axis = (
@@ -256,11 +260,15 @@ class DepthwiseConvBlock(keras.layers.Layer):
          x = self.batch_normalization1(x)
          x = self.activation1(x)

-         if self.se_layer:
+         if self.squeeze_excite_ratio:
              x = self.se_layer(x)

          x = self.conv2(x)
          x = self.batch_normalization2(x)
+
+         if self.residual:
+             x = x + inputs
+
          return x

      def get_config(self):
@@ -272,6 +280,7 @@ class DepthwiseConvBlock(keras.layers.Layer):
                  "kernel_size": self.kernel_size,
                  "stride": self.stride,
                  "squeeze_excite_ratio": self.squeeze_excite_ratio,
+                 "residual": self.residual,
                  "name": self.name,
              }
          )
@@ -675,6 +684,8 @@ class MobileNetBackbone(Backbone):
          stackwise_padding,
          output_num_filters,
          depthwise_filters,
+         depthwise_stride,
+         depthwise_residual,
          last_layer_filter,
          squeeze_and_excite=None,
          image_shape=(None, None, 3),
@@ -722,7 +733,9 @@ class MobileNetBackbone(Backbone):
          x = DepthwiseConvBlock(
              input_num_filters,
              depthwise_filters,
+             stride=depthwise_stride,
              squeeze_excite_ratio=squeeze_and_excite,
+             residual=depthwise_residual,
              name="block_0",
              dtype=dtype,
          )(x)
@@ -768,6 +781,8 @@ class MobileNetBackbone(Backbone):
          self.input_num_filters = input_num_filters
          self.output_num_filters = output_num_filters
          self.depthwise_filters = depthwise_filters
+         self.depthwise_stride = depthwise_stride
+         self.depthwise_residual = depthwise_residual
          self.last_layer_filter = last_layer_filter
          self.squeeze_and_excite = squeeze_and_excite
          self.input_activation = input_activation
@@ -790,6 +805,8 @@ class MobileNetBackbone(Backbone):
                  "input_num_filters": self.input_num_filters,
                  "output_num_filters": self.output_num_filters,
                  "depthwise_filters": self.depthwise_filters,
+                 "depthwise_stride": self.depthwise_stride,
+                 "depthwise_residual": self.depthwise_residual,
                  "last_layer_filter": self.last_layer_filter,
                  "squeeze_and_excite": self.squeeze_and_excite,
                  "input_activation": self.input_activation,

keras_hub/src/models/mobilenet/mobilenet_image_classifier.py

@@ -18,6 +18,7 @@ class MobileNetImageClassifier(ImageClassifier):
          self,
          backbone,
          num_classes,
+         num_features=1024,
          preprocessor=None,
          head_dtype=None,
          **kwargs,
@@ -33,7 +34,7 @@ class MobileNetImageClassifier(ImageClassifier):
          )

          self.output_conv = keras.layers.Conv2D(
-             filters=1024,
+             filters=num_features,
              kernel_size=(1, 1),
              strides=(1, 1),
              use_bias=True,
@@ -69,6 +70,7 @@ class MobileNetImageClassifier(ImageClassifier):

          # === Config ===
          self.num_classes = num_classes
+         self.num_features = num_features

      def get_config(self):
          # Skip ImageClassifier
@@ -76,6 +78,7 @@ class MobileNetImageClassifier(ImageClassifier):
          config.update(
              {
                  "num_classes": self.num_classes,
+                 "num_features": self.num_features,
              }
          )
          return config
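
The new `num_features` argument controls the width of the 1x1 convolution head that was previously hard-coded to 1024 filters. A hedged sketch of how it could be used when building a classifier on top of a pretrained backbone (assuming the usual KerasHub constructor pattern and one of the preset names listed in the next file):

import keras_hub

backbone = keras_hub.models.MobileNetBackbone.from_preset(
    "mobilenet_v3_small_050_imagenet"
)
classifier = keras_hub.models.MobileNetImageClassifier(
    backbone=backbone,
    num_classes=10,
    num_features=512,  # narrower head; defaults to the old value of 1024
)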

keras_hub/src/models/mobilenet/mobilenet_presets.py

@@ -4,12 +4,48 @@ backbone_presets = {
      "mobilenet_v3_small_050_imagenet": {
          "metadata": {
              "description": (
-                 "Small MobileNet V3 model pre-trained on the ImageNet 1k "
-                 "dataset at a 224x224 resolution."
+                 "Small Mobilenet V3 model pre-trained on the ImageNet 1k "
+                 "dataset at a 224x224 resolution. Has half channel multiplier."
              ),
              "params": 278784,
              "path": "mobilenetv3",
          },
          "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_small_050_imagenet/1",
      },
+     "mobilenet_v3_small_100_imagenet": {
+         "metadata": {
+             "description": (
+                 "Small Mobilenet V3 model pre-trained on the ImageNet 1k "
+                 "dataset at a 224x224 resolution. Has baseline channel "
+                 "multiplier."
+             ),
+             "params": 939120,
+             "path": "mobilenetv3",
+         },
+         "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_small_100_imagenet/1",
+     },
+     "mobilenet_v3_large_100_imagenet": {
+         "metadata": {
+             "description": (
+                 "Large Mobilenet V3 model pre-trained on the ImageNet 1k "
+                 "dataset at a 224x224 resolution. Has baseline channel "
+                 "multiplier."
+             ),
+             "params": 2996352,
+             "path": "mobilenetv3",
+         },
+         "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_large_100_imagenet/1",
+     },
+     "mobilenet_v3_large_100_imagenet_21k": {
+         "metadata": {
+             "description": (
+                 "Large Mobilenet V3 model pre-trained on the ImageNet 21k "
+                 "dataset at a 224x224 resolution. Has baseline channel "
+                 "multiplier."
+             ),
+             "params": 2996352,
+             "path": "mobilenetv3",
+         },
+         "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_large_100_imagenet_21k/1",
+     },
  }
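
These presets can also be consumed directly as backbones, for example the new ImageNet-21k weights as a feature extractor; a minimal sketch, assuming the generic backbone loader resolves the preset name:

import keras_hub

# Hypothetical: load the 21k-pretrained MobileNetV3-Large as a bare backbone.
backbone = keras_hub.models.Backbone.from_preset(
    "mobilenet_v3_large_100_imagenet_21k"
)
backbone.summary()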

keras_hub/src/models/siglip/siglip_presets.py

@@ -10,7 +10,7 @@ backbone_presets = {
              "params": 203156230,
              "official_name": "SigLIP",
              "path": "siglip",
-             "model_card": "https://www.kaggle.com/models/kerashub/siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba",
          },
          "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip_base_patch16_224/2",
      },
@@ -22,7 +22,7 @@ backbone_presets = {
              "params": 203202370,
              "official_name": "SigLIP",
              "path": "siglip",
-             "model_card": "https://www.kaggle.com/models/kerashub/siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba",
          },
          "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip_base_patch16_256/1",
      },
@@ -34,7 +34,7 @@ backbone_presets = {
              "params": 203448450,
              "official_name": "SigLIP",
              "path": "siglip",
-             "model_card": "https://www.kaggle.com/models/kerashub/siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba",
          },
          "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip_base_patch16_384/1",
      },
@@ -46,7 +46,7 @@ backbone_presets = {
              "params": 203792962,
              "official_name": "SigLIP",
              "path": "siglip",
-             "model_card": "https://www.kaggle.com/models/kerashub/siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba",
          },
          "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip_base_patch16_512/1",
      },
@@ -58,7 +58,7 @@ backbone_presets = {
              "params": 652151106,
              "official_name": "SigLIP",
              "path": "siglip",
-             "model_card": "https://www.kaggle.com/models/kerashub/siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba",
          },
          "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip_large_patch16_256/1",
      },
@@ -70,7 +70,7 @@ backbone_presets = {
              "params": 652479106,
              "official_name": "SigLIP",
              "path": "siglip",
-             "model_card": "https://www.kaggle.com/models/kerashub/siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba",
          },
          "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip_large_patch16_384/1",
      },
@@ -83,7 +83,7 @@ backbone_presets = {
              "params": 877360578,
              "official_name": "SigLIP",
              "path": "siglip",
-             "model_card": "https://www.kaggle.com/models/kerashub/siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba",
          },
          "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip_so400m_patch14_224/2",
      },
@@ -96,7 +96,7 @@ backbone_presets = {
              "params": 877961291,
              "official_name": "SigLIP",
              "path": "siglip",
-             "model_card": "https://www.kaggle.com/models/kerashub/siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba",
          },
          "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip_so400m_patch14_384/1",
      },
@@ -109,7 +109,7 @@ backbone_presets = {
              "params": 1128759282,
              "official_name": "SigLIP",
              "path": "siglip",
-             "model_card": "https://www.kaggle.com/models/kerashub/siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba",
          },
          "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip_so400m_patch16_256_i18n/1",
      },
@@ -121,8 +121,204 @@
              "params": 370626370,
              "official_name": "SigLIP",
              "path": "siglip",
-             "model_card": "https://www.kaggle.com/models/kerashub/siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba",
          },
          "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip_base_patch16_256_multilingual/1",
      },
+     # SigLIP2.
+     "siglip2_base_patch16_224": {
+         "metadata": {
+             "description": (
+                 "375 million parameter, patch size 16, image size 224, "
+                 "pre-trained on WebLi."
+             ),
+             "params": 375188230,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_base_patch16_224/1",
+     },
+     "siglip2_base_patch16_256": {
+         "metadata": {
+             "description": (
+                 "375 million parameter, patch size 16, image size 256, "
+                 "pre-trained on WebLi."
+             ),
+             "params": 375234370,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_base_patch16_256/1",
+     },
+     "siglip2_base_patch32_256": {
+         "metadata": {
+             "description": (
+                 "376 million parameter, patch size 32, image size 256, "
+                 "pre-trained on WebLi."
+             ),
+             "params": 376856194,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_base_patch32_256/1",
+     },
+     "siglip2_base_patch16_384": {
+         "metadata": {
+             "description": (
+                 "376 million parameter, patch size 16, image size 384, "
+                 "pre-trained on WebLi."
+             ),
+             "params": 376856194,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_base_patch16_384/1",
+     },
+     "siglip2_base_patch16_512": {
+         "metadata": {
+             "description": (
+                 "375 million parameter, patch size 16, image size 512, "
+                 "pre-trained on WebLi."
+             ),
+             "params": 375824962,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_base_patch16_512/1",
+     },
+     "siglip2_large_patch16_256": {
+         "metadata": {
+             "description": (
+                 "881 million parameter, patch size 16, image size 256, "
+                 "pre-trained on WebLi."
+             ),
+             "params": 881527106,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_large_patch16_256/1",
+     },
+     "siglip2_large_patch16_384": {
+         "metadata": {
+             "description": (
+                 "881 million parameter, patch size 16, image size 384, "
+                 "pre-trained on WebLi."
+             ),
+             "params": 881855106,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_large_patch16_384/1",
+     },
+     "siglip2_large_patch16_512": {
+         "metadata": {
+             "description": (
+                 "882 million parameter, patch size 16, image size 512, "
+                 "pre-trained on WebLi."
+             ),
+             "params": 882314306,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_large_patch16_512/1",
+     },
+     "siglip2_giant_opt_patch16_256": {
+         "metadata": {
+             "description": (
+                 "1.8 billion parameter, patch size 16, image size 256, "
+                 "pre-trained on WebLi."
+             ),
+             "params": 1871394226,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_giant_opt_patch16_256/1",
+     },
+     "siglip2_giant_opt_patch16_384": {
+         "metadata": {
+             "description": (
+                 "1.8 billion parameter, patch size 16, image size 384, "
+                 "pre-trained on WebLi."
+             ),
+             "params": 1871886066,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_giant_opt_patch16_384/1",
+     },
+     "siglip2_so400m_patch14_224": {
+         "metadata": {
+             "description": (
+                 "1.1 billion parameter, patch size 14, image size 224, "
+                 "shape-optimized version, pre-trained on WebLi."
+             ),
+             "params": 1135463922,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_so400m_patch14_224/1",
+     },
+     "siglip2_so400m_patch14_384": {
+         "metadata": {
+             "description": (
+                 "1.1 billion parameter, patch size 14, image size 224, "
+                 "shape-optimized version, pre-trained on WebLi."
+             ),
+             "params": 1136009291,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_so400m_patch14_384/1",
+     },
+     "siglip2_so400m_patch16_256": {
+         "metadata": {
+             "description": (
+                 "1.1 billion parameter, patch size 16, image size 256, "
+                 "shape-optimized version, pre-trained on WebLi."
+             ),
+             "params": 1135671282,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_so400m_patch16_256/1",
+     },
+     "siglip2_so400m_patch16_384": {
+         "metadata": {
+             "description": (
+                 "1.1 billion parameter, patch size 16, image size 384, "
+                 "shape-optimized version, pre-trained on WebLi."
+             ),
+             "params": 1136040242,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_so400m_patch16_384/1",
+     },
+     "siglip2_so400m_patch16_512": {
+         "metadata": {
+             "description": (
+                 "1.1 billion parameter, patch size 16, image size 512, "
+                 "shape-optimized version, pre-trained on WebLi."
+             ),
+             "params": 1136555698,
+             "official_name": "SigLIP2",
+             "path": "siglip",
+             "model_card": "https://huggingface.co/collections/google/siglip2-67b5dcef38c175486e240107",
+         },
+         "kaggle_handle": "kaggle://kerashub/siglip/keras/siglip2_so400m_patch16_512/1",
+     },
  }
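
The SigLIP2 checkpoints are registered alongside the original SigLIP presets, so they should load through the same entry points; a minimal sketch, assuming the backbone class is exported as `keras_hub.models.SigLIPBackbone` and that the new names resolve like the existing ones:

import keras_hub

# Hypothetical: load one of the new SigLIP2 checkpoints by preset name.
backbone = keras_hub.models.SigLIPBackbone.from_preset("siglip2_base_patch16_224")
print(backbone.count_params())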

keras_hub/src/models/siglip/siglip_text_encoder.py

@@ -27,6 +27,8 @@ class SigLIPTextEncoder(Backbone):
              Defaults to `1e-6`.
          max_sequence_length: int. The maximum sequence length that this encoder
              can consume. Defaults to `64`.
+         projection_dim: int. The size of the projection in the head. If not
+             specified, set to `hidden_dim`. Defaults to `None`.
          dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
              for the models computations and weights. Note that some
              computations, such as softmax and layer normalization will always
@@ -44,10 +46,12 @@ class SigLIPTextEncoder(Backbone):
          intermediate_activation="gelu_approximate",
          layer_norm_epsilon=1e-6,
          max_sequence_length=64,
+         projection_dim=None,
          dtype=None,
          name=None,
          **kwargs,
      ):
+         projection_dim = projection_dim or hidden_dim
          # `prefix` is used to prevent duplicate name when utilizing multiple
          # SigLIP encoders within a single model.
          prefix = str(name) + "_" if name is not None else ""
@@ -78,7 +82,7 @@ class SigLIPTextEncoder(Backbone):
              name=f"{prefix}post_layer_norm",
          )
          self.head = layers.Dense(
-             hidden_dim,
+             projection_dim,
              kernel_initializer=initializers.LecunNormal(),
              dtype=dtype,
              name=f"{prefix}head",
@@ -115,6 +119,7 @@ class SigLIPTextEncoder(Backbone):
          self.intermediate_activation = intermediate_activation
          self.layer_norm_epsilon = layer_norm_epsilon
          self.max_sequence_length = max_sequence_length
+         self.projection_dim = projection_dim

      def get_config(self):
          config = super().get_config()
@@ -129,6 +134,7 @@ class SigLIPTextEncoder(Backbone):
                  "intermediate_activation": self.intermediate_activation,
                  "layer_norm_epsilon": self.layer_norm_epsilon,
                  "max_sequence_length": self.max_sequence_length,
+                 "projection_dim": self.projection_dim,
              }
          )
          return config

keras_hub/src/utils/keras_utils.py

@@ -72,3 +72,35 @@ def has_flash_attention_support():
          return True
      else:
          return False
+
+
+ def running_on_tpu():
+     backend = keras.config.backend()
+     if backend == "jax":
+         import jax
+
+         devices = jax.devices()
+         return any(d.platform == "tpu" for d in devices)
+     elif backend == "tensorflow":
+         import tensorflow as tf
+
+         return bool(tf.config.list_logical_devices("TPU"))
+     elif backend == "torch":
+         return False
+
+
+ def running_on_gpu():
+     backend = keras.config.backend()
+     if backend == "jax":
+         import jax
+
+         devices = jax.devices()
+         return any(d.platform == "gpu" for d in devices)
+     elif backend == "tensorflow":
+         import tensorflow as tf
+
+         return bool(tf.config.list_logical_devices("GPU"))
+     elif backend == "torch":
+         import torch
+
+         return torch.cuda.is_available()
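
Both helpers live under `keras_hub.src`, so they are internal rather than public API; note also that each falls through (returning `None`) on backends other than jax, tensorflow, and torch. A hedged usage sketch:

from keras_hub.src.utils.keras_utils import running_on_gpu, running_on_tpu

# Hypothetical dispatch mirroring how Gemma attention gates soft-capped
# flash attention on TPU.
if running_on_tpu():
    print("TPU detected")
elif running_on_gpu():
    print("GPU detected")
else:
    print("CPU or unrecognized backend")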

keras_hub/src/utils/preset_utils.py

@@ -622,6 +622,7 @@ class PresetLoader:
              kwargs["preprocessor"] = self.load_preprocessor(
                  cls.preprocessor_cls,
              )
+
          return cls(**kwargs)

      def load_preprocessor(