keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,286 @@ keras_hub/src/models/clip/clip_backbone.py
+ import math
+
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.backbone import Backbone
+
+
+ class CLIPVisionPooler(layers.Layer):
+     """The vision pooler layer of CLIP.
+
+     `CLIPVisionPooler` extracts the first token (index `0`) from the
+     sequence of the vision embeddings as the pooled outputs.
+
+     Call arguments:
+         vision_embeddings: A tensor of shape
+             `(batch_size, sequence_length, hidden_dim)`.
+     """
+
+     def call(self, vision_embeddings):
+         return vision_embeddings[:, 0, :]
+
+     def compute_output_shape(self, input_shape):
+         return (input_shape[0], input_shape[-1])
+
+
+ class CLIPTextPooler(layers.Layer):
+     """The text pooler layer of CLIP.
+
+     `CLIPTextPooler` extracts the text embeddings at the positions of EOS tokens
+     as the pooled outputs.
+
+     Call arguments:
+         text_embeddings: A tensor of shape
+             `(batch_size, sequence_length, hidden_dim)`.
+         token_ids: A tensor of shape `(batch_size, max_tokens)`, used to
+             identify the positions of EOS tokens.
+     """
+
+     def call(self, text_embeddings, token_ids):
+         # `keepdims` is not supported in `keras<=3.1`.
+         eos_index = ops.argmax(token_ids, axis=-1)
+         eos_index = ops.expand_dims(eos_index, axis=-1)
+         eos_index = ops.expand_dims(eos_index, axis=-1)
+         pooled_outputs = ops.take_along_axis(text_embeddings, eos_index, axis=1)
+         return ops.squeeze(pooled_outputs, axis=1)
+
+     def compute_output_shape(self, input_shape):
+         return (input_shape[0], input_shape[-1])
+
+
+ class CLIPHead(layers.Layer):
+     """The head layer of CLIP.
+
+     `CLIPHead` takes `vision_embedding` and `text_embedding` as inputs to
+     compute the corresponding logits. Both embeddings are L2 normalized and used
+     to compute pairwise cosine similarity. The resulting logits are then scaled
+     by a learnable `logit_scale` parameter.
+
+     Call arguments:
+         vision_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+         text_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+     """
+
+     def build(self, input_shape):
+         self.logit_scale = self.add_weight(
+             shape=(),
+             initializer=lambda *a, **kw: math.log(1 / 0.07),
+             trainable=True,
+             dtype=self.variable_dtype,
+             name="logit_scale",
+         )
+
+     def call(self, vision_embedding, text_embedding):
+         normalized_vision_embedding = ops.sqrt(
+             ops.sum(ops.power(vision_embedding, 2), axis=-1, keepdims=True)
+         )
+         normalized_text_embedding = ops.sqrt(
+             ops.sum(ops.power(text_embedding, 2), axis=-1, keepdims=True)
+         )
+         vision_embedding = vision_embedding / normalized_vision_embedding
+         text_embedding = text_embedding / normalized_text_embedding
+         logit_scale = ops.exp(self.logit_scale)
+         text_logits = (
+             ops.matmul(
+                 text_embedding,
+                 ops.transpose(vision_embedding),
+             )
+             * logit_scale
+         )
+         vision_logits = ops.transpose(text_logits)
+         return vision_logits, text_logits
+
+     def compute_output_shape(
+         self, vision_embedding_shape, text_embedding_shape
+     ):
+         vision_logits_shape = (
+             vision_embedding_shape[0],
+             text_embedding_shape[0],
+         )
+         text_logits_shape = (
+             text_embedding_shape[0],
+             vision_embedding_shape[0],
+         )
+         return vision_logits_shape, text_logits_shape
+
+
+ @keras_hub_export("keras_hub.models.CLIPBackbone")
+ class CLIPBackbone(Backbone):
+     """CLIP core network with hyperparameters.
+
+     This backbone implements the base architecture for the Contrastive
+     Language-Image Pretraining (CLIP) model. It includes vision and text
+     encoders and the corresponding projection layers. This backbone outputs
+     the final logit scores corresponding to each image and token input. These
+     values are cosine similarities between the corresponding image and text
+     features.
+
+     The default constructor gives a fully customizable, randomly initialized
+     CLIP model with any number of layers, heads, and embedding dimensions. To
+     load preset architectures and weights, use the `from_preset` constructor.
+
+     Args:
+         vision_encoder: The CLIP vision encoder for encoding the input images.
+         text_encoder: The CLIP text encoder for encoding the input tokens.
+         projection_dim: int. The size of the projection layer.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for the model's computations and weights. Note that some
+             computations, such as softmax and layer normalization, will always
+             be done in float32 precision regardless of dtype.
+
+     Example:
+     ```python
+     input_data = {
+         "images": np.ones(shape=(1, 224, 224, 3), dtype="float32"),
+         "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+     }
+
+     # Pretrained CLIP model.
+     model = keras_hub.models.CLIPBackbone.from_preset("clip_vit_base_patch32")
+     model(input_data)
+
+     # Randomly initialized CLIP model with custom config.
+     vision_encoder = keras_hub.models.CLIPVisionEncoder(
+         patch_size=32,
+         hidden_dim=768,
+         num_layers=8,
+         num_heads=8,
+         intermediate_dim=2048,
+         image_shape=(384, 384, 3),
+     )
+     text_encoder = keras_hub.models.CLIPTextEncoder(
+         vocabulary_size=49408,
+         embedding_dim=768,
+         hidden_dim=768,
+         num_layers=8,
+         num_heads=8,
+         intermediate_dim=2048,
+     )
+     model = keras_hub.models.CLIPBackbone(
+         vision_encoder=vision_encoder,
+         text_encoder=text_encoder,
+         projection_dim=256,
+     )
+     model(input_data)
+     ```
+     """
+
+     def __init__(
+         self,
+         vision_encoder,
+         text_encoder,
+         projection_dim,
+         dtype=None,
+         name=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.vision_encoder = vision_encoder
+         self.text_encoder = text_encoder
+         self.vision_pooler = CLIPVisionPooler(dtype=dtype, name="vision_pooler")
+         self.text_pooler = CLIPTextPooler(dtype=dtype, name="text_pooler")
+         self.vision_projection = layers.Dense(
+             projection_dim,
+             use_bias=False,
+             dtype=dtype,
+             name="vision_projection",
+         )
+         self.text_projection = layers.Dense(
+             projection_dim,
+             use_bias=False,
+             dtype=dtype,
+             name="text_projection",
+         )
+         self.clip_head = CLIPHead(dtype=dtype, name="clip_head")
+
+         # === Functional Model ===
+         image_input = layers.Input(
+             shape=self.vision_encoder.image_shape, name="images"
+         )
+         token_id_input = layers.Input(
+             shape=(None,), dtype="int32", name="token_ids"
+         )
+         vision_embeddings = self.get_vision_embeddings(image_input)
+         text_embeddings = self.get_text_embeddings(token_id_input)
+         vision_logits, text_logits = self.clip_head(
+             vision_embeddings, text_embeddings
+         )
+
+         super().__init__(
+             inputs={
+                 "images": image_input,
+                 "token_ids": token_id_input,
+             },
+             outputs={
+                 "vision_logits": vision_logits,
+                 "text_logits": text_logits,
+             },
+             dtype=dtype,
+             name=name,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.projection_dim = projection_dim
+
+     def get_vision_embeddings(self, images):
+         """Get the embeddings from the vision encoder.
+
+         Args:
+             images: The input tensor for the vision encoder.
+
+         Returns:
+             The output embeddings obtained by applying the projection layer to
+             the pooled output of the vision encoder.
+         """
+         vision_outputs = self.vision_encoder({"images": images})
+         vision_outputs = self.vision_pooler(vision_outputs)
+         return self.vision_projection(vision_outputs)
+
+     def get_text_embeddings(self, token_ids):
+         """Get the embeddings from the text encoder.
+
+         Args:
+             token_ids: The input int tensor for the text encoder.
+
+         Returns:
+             The output embeddings obtained by applying the projection layer to
+             the pooled output of the text encoder.
+         """
+         text_outputs = self.text_encoder({"token_ids": token_ids})
+         text_outputs = self.text_pooler(text_outputs, token_ids)
+         return self.text_projection(text_outputs)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vision_encoder": layers.serialize(self.vision_encoder),
+                 "text_encoder": layers.serialize(self.text_encoder),
+                 "projection_dim": self.projection_dim,
+             }
+         )
+         return config
+
+     @classmethod
+     def from_config(cls, config, custom_objects=None):
+         config = config.copy()
+
+         # Propagate `dtype` to submodels if needed.
+         if "dtype" in config and config["dtype"] is not None:
+             dtype_config = config["dtype"]
+             if "dtype" not in config["vision_encoder"]["config"]:
+                 config["vision_encoder"]["config"]["dtype"] = dtype_config
+             if "dtype" not in config["text_encoder"]["config"]:
+                 config["text_encoder"]["config"]["dtype"] = dtype_config
+
+         # We expect submodels to be instantiated.
+         config["vision_encoder"] = layers.deserialize(
+             config["vision_encoder"], custom_objects=custom_objects
+         )
+         config["text_encoder"] = layers.deserialize(
+             config["text_encoder"], custom_objects=custom_objects
+         )
+         return cls(**config)
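
The sketch below shows one way the new backbone and its embedding helpers could be exercised. It is a minimal illustration, assuming the `clip_vit_base_patch32` preset named in the docstring above is available; the output shapes follow the `compute_output_shape` logic of `CLIPHead`.

```python
import numpy as np
import keras_hub

# A minimal sketch, assuming the "clip_vit_base_patch32" preset can be downloaded.
backbone = keras_hub.models.CLIPBackbone.from_preset("clip_vit_base_patch32")

images = np.ones(shape=(2, 224, 224, 3), dtype="float32")
token_ids = np.ones(shape=(2, 12), dtype="int32")

# The new helper methods return the projected (pre-similarity) embeddings.
image_embeddings = backbone.get_vision_embeddings(images)  # (2, projection_dim)
text_embeddings = backbone.get_text_embeddings(token_ids)  # (2, projection_dim)

# Calling the backbone returns the scaled pairwise cosine-similarity logits.
outputs = backbone({"images": images, "token_ids": token_ids})
print(outputs["vision_logits"].shape, outputs["text_logits"].shape)  # (2, 2) (2, 2)
```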
@@ -7,6 +7,16 @@ def quick_gelu(x):
      return x * ops.sigmoid(1.702 * x)


+ # TODO: Deprecate this in favor of `keras.layers.MultiHeadAttention` once the
+ # dtype compatibility issue is resolved.
+ class CLIPMultiHeadAttention(layers.MultiHeadAttention):
+     def _masked_softmax(self, attention_scores, attention_mask=None):
+         attention_scores = super()._masked_softmax(
+             attention_scores, attention_mask
+         )
+         return ops.cast(attention_scores, self._value_dense.compute_dtype)
+
+
  class CLIPEncoderBlock(layers.Layer):
      def __init__(
          self,
@@ -14,6 +24,7 @@ class CLIPEncoderBlock(layers.Layer):
          num_heads,
          intermediate_dim,
          intermediate_activation="quick_gelu",
+         use_causal_mask=True,
          **kwargs,
      ):
          super().__init__(**kwargs)
@@ -26,21 +37,22 @@ class CLIPEncoderBlock(layers.Layer):
          self.num_heads = num_heads
          self.intermediate_dim = intermediate_dim
          self.intermediate_activation = intermediate_activation
+         self.use_causal_mask = use_causal_mask

          if intermediate_activation == "quick_gelu":
              intermediate_activation = quick_gelu

          self.layer_norm_1 = layers.LayerNormalization(
-             epsilon=1e-5, dtype="float32", name="layer_norm_1"
+             epsilon=1e-5, dtype=self.dtype_policy, name="layer_norm_1"
          )
-         self.attention = layers.MultiHeadAttention(
+         self.attention = CLIPMultiHeadAttention(
              num_heads,
              hidden_dim // num_heads,
              dtype=self.dtype_policy,
              name="attention",
          )
          self.layer_norm_2 = layers.LayerNormalization(
-             epsilon=1e-5, dtype="float32", name="layer_norm_2"
+             epsilon=1e-5, dtype=self.dtype_policy, name="layer_norm_2"
          )
          self.dense_1 = layers.Dense(
              self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
@@ -73,7 +85,9 @@ class CLIPEncoderBlock(layers.Layer):
      def call(self, x, training=None):
          residual = x
          x = self.layer_norm_1(x)
-         x = self.attention(x, x, x, training=training, use_causal_mask=True)
+         x = self.attention(
+             x, x, x, training=training, use_causal_mask=self.use_causal_mask
+         )
          x = ops.add(residual, x)

          residual = x
@@ -91,6 +105,7 @@ class CLIPEncoderBlock(layers.Layer):
                  "num_heads": self.num_heads,
                  "intermediate_dim": self.intermediate_dim,
                  "intermediate_activation": self.intermediate_activation,
+                 "use_causal_mask": self.use_causal_mask,
              }
          )
          return config
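
The new `use_causal_mask` argument lets the same block serve both the causal text tower and a bidirectional vision tower. A minimal sketch, using the internal (non-exported) module path and arbitrary dimensions:

```python
import numpy as np
from keras_hub.src.models.clip.clip_encoder_block import CLIPEncoderBlock

# Arbitrary small dimensions, chosen only to exercise the new flag.
block = CLIPEncoderBlock(
    hidden_dim=64,
    num_heads=4,
    intermediate_dim=256,
    use_causal_mask=False,  # bidirectional attention, as a vision tower would use
)
x = np.random.rand(2, 16, 64).astype("float32")
print(block(x).shape)  # (2, 16, 64); the block preserves the hidden dimension
```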
@@ -0,0 +1,8 @@ keras_hub/src/models/clip/clip_image_converter.py
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+ from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
+
+
+ @keras_hub_export("keras_hub.layers.CLIPImageConverter")
+ class CLIPImageConverter(ImageConverter):
+     backbone_cls = CLIPBackbone
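
`CLIPImageConverter` simply binds the generic `ImageConverter` preprocessing to `CLIPBackbone`. A sketch of constructing one directly; the resize and scale values here are illustrative, not the constants shipped with the presets:

```python
import numpy as np
import keras_hub

# Illustrative settings only; the real presets ship their own converter config.
converter = keras_hub.layers.CLIPImageConverter(
    image_size=(224, 224),
    scale=1.0 / 255.0,
)
images = np.random.randint(0, 256, size=(1, 448, 448, 3)).astype("float32")
print(converter(images).shape)  # (1, 224, 224, 3), rescaled to roughly [0, 1]
```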
@@ -0,0 +1,93 @@ keras_hub/src/models/clip/clip_presets.py
+ """CLIP model preset configurations."""
+
+ # Metadata for loading pretrained model weights.
+ backbone_presets = {
+     "clip_vit_base_patch16": {
+         "metadata": {
+             "description": (
+                 "150 million parameter, 12-layer for vision and 12-layer for "
+                 "text, patch size of 16, CLIP model."
+             ),
+             "params": 149620934,
+             "path": "clip",
+         },
+         "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_base_patch16/1",
+     },
+     "clip_vit_base_patch32": {
+         "metadata": {
+             "description": (
+                 "151 million parameter, 12-layer for vision and 12-layer for "
+                 "text, patch size of 32, CLIP model."
+             ),
+             "params": 151277363,
+             "path": "clip",
+         },
+         "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_base_patch32/1",
+     },
+     "clip_vit_large_patch14": {
+         "metadata": {
+             "description": (
+                 "428 million parameter, 24-layer for vision and 12-layer for "
+                 "text, patch size of 14, CLIP model."
+             ),
+             "params": 427616770,
+             "path": "clip",
+         },
+         "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_large_patch14/1",
+     },
+     "clip_vit_large_patch14_336": {
+         "metadata": {
+             "description": (
+                 "428 million parameter, 24-layer for vision and 12-layer for "
+                 "text, patch size of 14, image size of 336, CLIP model."
+             ),
+             "params": 427944770,
+             "path": "clip",
+         },
+         "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_large_patch14_336/1",
+     },
+     "clip_vit_b_32_laion2b_s34b_b79k": {
+         "metadata": {
+             "description": (
+                 "151 million parameter, 12-layer for vision and 12-layer for "
+                 "text, patch size of 32, Open CLIP model."
+             ),
+             "params": 151277363,
+             "path": "clip",
+         },
+         "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_b_32_laion2b_s34b_b79k/1",
+     },
+     "clip_vit_h_14_laion2b_s32b_b79k": {
+         "metadata": {
+             "description": (
+                 "986 million parameter, 32-layer for vision and 24-layer for "
+                 "text, patch size of 14, Open CLIP model."
+             ),
+             "params": 986109698,
+             "path": "clip",
+         },
+         "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_h_14_laion2b_s32b_b79k/1",
+     },
+     "clip_vit_g_14_laion2b_s12b_b42k": {
+         "metadata": {
+             "description": (
+                 "1.4 billion parameter, 40-layer for vision and 24-layer for "
+                 "text, patch size of 14, Open CLIP model."
+             ),
+             "params": 1366678530,
+             "path": "clip",
+         },
+         "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_g_14_laion2b_s12b_b42k/1",
+     },
+     "clip_vit_bigg_14_laion2b_39b_b160k": {
+         "metadata": {
+             "description": (
+                 "2.5 billion parameter, 48-layer for vision and 32-layer for "
+                 "text, patch size of 14, Open CLIP model."
+             ),
+             "params": 2539567362,
+             "path": "clip",
+         },
+         "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_bigg_14_laion2b_39b_b160k/1",
+     },
+ }
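
These entries are what `from_preset` resolves against. A short sketch of consuming the registry; it assumes the Kaggle handles above are reachable, and the printed parameter count should mirror the `params` metadata:

```python
import keras_hub

# Load one of the presets registered above.
backbone = keras_hub.models.CLIPBackbone.from_preset("clip_vit_base_patch16")
print(backbone.count_params())  # ~149,620,934 per the metadata above
```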
@@ -1,5 +1,6 @@ keras_hub/src/models/clip/clip_text_encoder.py
  from keras import layers

+ from keras_hub.src.api_export import keras_hub_export
  from keras_hub.src.layers.modeling.token_and_position_embedding import (
      TokenAndPositionEmbedding,
  )
@@ -7,6 +8,7 @@ from keras_hub.src.models.backbone import Backbone
  from keras_hub.src.models.clip.clip_encoder_block import CLIPEncoderBlock


+ @keras_hub_export("keras_hub.models.CLIPTextEncoder")
  class CLIPTextEncoder(Backbone):
      """CLIP text core network with hyperparameters.

@@ -80,7 +82,7 @@ class CLIPTextEncoder(Backbone):
              for i in range(num_layers)
          ]
          self.layer_norm = layers.LayerNormalization(
-             epsilon=1e-6, dtype="float32", name=f"{prefix}layer_norm"
+             epsilon=1e-6, dtype=dtype, name=f"{prefix}layer_norm"
          )

          # === Functional Model ===
@@ -106,6 +108,7 @@ class CLIPTextEncoder(Backbone):
          super().__init__(
              inputs={"token_ids": token_id_input},
              outputs=outputs,
+             dtype=dtype,
              name=name,
              **kwargs,
          )
@@ -1,4 +1,5 @@ keras_hub/src/models/clip/clip_tokenizer.py
  from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
  from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
  from keras_hub.src.tokenizers.byte_pair_tokenizer import convert_to_ragged_batch
  from keras_hub.src.tokenizers.byte_pair_tokenizer import split_strings_for_bpe
@@ -39,11 +40,25 @@ class CLIPTokenizer(BytePairTokenizer):
              should have one merge rule per line. Every merge rule contains
              merge entities separated by a space.
          pad_with_end_token: bool. Whether to pad the output with `end_token`.
-     """

-     # TODO: Add example and `backbone_cls` once we have a CLIP model.
+     Examples:
+
+     ```python
+     # Unbatched input.
+     tokenizer = keras_hub.models.CLIPTokenizer.from_preset(
+         "clip_vit_base_patch32"
+     )
+     tokenizer("The quick brown fox jumped.")
+
+     # Batched input.
+     tokenizer(["The quick brown fox jumped.", "The fox slept."])
+
+     # Detokenization.
+     tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+     ```
+     """

-     backbone_cls = None
+     backbone_cls = CLIPBackbone

      def __init__(
          self,
@@ -0,0 +1,101 @@ keras_hub/src/models/clip/clip_vision_embedding.py
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ class CLIPVisionEmbedding(layers.Layer):
+     def __init__(
+         self,
+         hidden_dim,
+         patch_size,
+         image_size,
+         data_format=None,
+         dtype=None,
+         **kwargs
+     ):
+         super().__init__(dtype=dtype, **kwargs)
+         self.hidden_dim = int(hidden_dim)
+         self.patch_size = int(patch_size)
+         self.image_size = int(image_size)
+         data_format = standardize_data_format(data_format)
+         self.data_format = data_format
+         num_patches = (image_size // patch_size) ** 2
+         self.num_positions = num_patches + 1
+
+         self.patch_embedding = layers.Conv2D(
+             hidden_dim,
+             kernel_size=patch_size,
+             strides=patch_size,
+             data_format=data_format,
+             use_bias=False,
+             dtype=dtype,
+             name="patch_embedding",
+         )
+         self.position_embedding = layers.Embedding(
+             num_patches + 1, hidden_dim, dtype=dtype, name="position_embedding"
+         )
+
+     def build(self, input_shape):
+         self.class_embedding = self.add_weight(
+             shape=(self.hidden_dim,),
+             initializer="random_normal",
+             dtype=self.variable_dtype,
+             name="class_embedding",
+         )
+         self.position_ids = self.add_weight(
+             shape=(1, self.num_positions),
+             initializer="zeros",
+             # Let the backend determine the int dtype. For example, tf
+             # requires int64 for correct device placement, whereas jax and torch
+             # don't.
+             dtype=int,
+             trainable=False,
+             name="position_ids",
+         )
+         self.patch_embedding.build(input_shape)
+         self.position_embedding.build(self.position_ids.shape)
+
+     def call(self, inputs, training=None):
+         x = inputs
+         batch_size = ops.shape(x)[0]
+         patch_embeddings = self.patch_embedding(x, training=training)
+         if self.data_format == "channels_last":
+             patch_embeddings = ops.reshape(
+                 patch_embeddings, (batch_size, -1, self.hidden_dim)
+             )
+         else:
+             patch_embeddings = ops.reshape(
+                 patch_embeddings, (batch_size, self.hidden_dim, -1)
+             )
+             patch_embeddings = ops.transpose(patch_embeddings, (0, 2, 1))
+         class_embeddings = ops.expand_dims(self.class_embedding, axis=(0, 1))
+         class_embeddings = ops.tile(class_embeddings, (batch_size, 1, 1))
+         position_embeddings = self.position_embedding(self.position_ids)
+         embeddings = ops.concatenate(
+             [class_embeddings, patch_embeddings], axis=1
+         )
+         return ops.add(embeddings, position_embeddings)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "patch_size": self.patch_size,
+                 "image_size": self.image_size,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         output_shape = [input_shape[0], None, self.hidden_dim]
+         if self.data_format == "channels_last":
+             if input_shape[1] is not None and input_shape[2] is not None:
+                 patch_num = input_shape[1] // self.patch_size
+                 output_shape[1] = patch_num**2 + 1
+         else:
+             if input_shape[2] is not None and input_shape[3] is not None:
+                 patch_num = input_shape[2] // self.patch_size
+                 output_shape[1] = patch_num**2 + 1
+         return output_shape
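
The sequence length produced by `CLIPVisionEmbedding` is `(image_size // patch_size) ** 2 + 1`: one position per patch plus the prepended class token. A quick sketch with example sizes, imported from the internal module path:

```python
import numpy as np
from keras_hub.src.models.clip.clip_vision_embedding import CLIPVisionEmbedding

# 224x224 images with 32x32 patches -> (224 // 32) ** 2 + 1 = 50 positions.
embedding = CLIPVisionEmbedding(hidden_dim=768, patch_size=32, image_size=224)
images = np.random.rand(1, 224, 224, 3).astype("float32")
print(embedding(images).shape)  # (1, 50, 768)
```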