keras-hub 0.21.1.dev0-py3-none-any.whl → 0.22.0.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. keras_hub/layers/__init__.py +9 -0
  2. keras_hub/models/__init__.py +47 -0
  3. keras_hub/src/layers/modeling/transformer_encoder.py +6 -3
  4. keras_hub/src/layers/preprocessing/multi_segment_packer.py +17 -3
  5. keras_hub/src/layers/preprocessing/start_end_packer.py +24 -6
  6. keras_hub/src/models/backbone.py +13 -10
  7. keras_hub/src/models/clip/clip_backbone.py +3 -102
  8. keras_hub/src/models/clip/clip_layers.py +295 -0
  9. keras_hub/src/models/clip/clip_preprocessor.py +57 -48
  10. keras_hub/src/models/clip/clip_text_encoder.py +2 -2
  11. keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
  12. keras_hub/src/models/deit/__init__.py +5 -0
  13. keras_hub/src/models/deit/deit_backbone.py +154 -0
  14. keras_hub/src/models/deit/deit_image_classifier.py +171 -0
  15. keras_hub/src/models/deit/deit_image_classifier_preprocessor.py +12 -0
  16. keras_hub/src/models/deit/deit_image_converter.py +8 -0
  17. keras_hub/src/models/deit/deit_layers.py +519 -0
  18. keras_hub/src/models/deit/deit_presets.py +49 -0
  19. keras_hub/src/models/dinov2/__init__.py +5 -0
  20. keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
  21. keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
  22. keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
  23. keras_hub/src/models/dinov2/dinov2_presets.py +89 -0
  24. keras_hub/src/models/esm/__init__.py +5 -0
  25. keras_hub/src/models/esm/esm_attention.py +95 -0
  26. keras_hub/src/models/esm/esm_backbone.py +229 -0
  27. keras_hub/src/models/esm/esm_classifier.py +184 -0
  28. keras_hub/src/models/esm/esm_classifier_preprocessor.py +135 -0
  29. keras_hub/src/models/esm/esm_encoder.py +134 -0
  30. keras_hub/src/models/esm/esm_masked_plm.py +117 -0
  31. keras_hub/src/models/esm/esm_masked_plm_preprocessor.py +143 -0
  32. keras_hub/src/models/esm/esm_presets.py +53 -0
  33. keras_hub/src/models/esm/esm_tokenizer.py +82 -0
  34. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
  35. keras_hub/src/models/gemma/gemma_attention.py +1 -1
  36. keras_hub/src/models/gemma3/gemma3_backbone.py +2 -2
  37. keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +1 -1
  38. keras_hub/src/models/hgnetv2/__init__.py +5 -0
  39. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +193 -0
  40. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +148 -0
  41. keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +216 -0
  42. keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py +14 -0
  43. keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py +8 -0
  44. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +918 -0
  45. keras_hub/src/models/hgnetv2/hgnetv2_presets.py +58 -0
  46. keras_hub/src/models/llama3/llama3_presets.py +3 -3
  47. keras_hub/src/models/mistral/mistral_presets.py +17 -1
  48. keras_hub/src/models/mixtral/mixtral_presets.py +2 -2
  49. keras_hub/src/models/mobilenet/mobilenet_presets.py +4 -4
  50. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +2 -2
  51. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +2 -2
  52. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +17 -17
  53. keras_hub/src/models/qwen3/__init__.py +5 -0
  54. keras_hub/src/models/qwen3/qwen3_attention.py +369 -0
  55. keras_hub/src/models/qwen3/qwen3_backbone.py +191 -0
  56. keras_hub/src/models/qwen3/qwen3_causal_lm.py +390 -0
  57. keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor.py +10 -0
  58. keras_hub/src/models/qwen3/qwen3_decoder.py +309 -0
  59. keras_hub/src/models/qwen3/qwen3_layernorm.py +38 -0
  60. keras_hub/src/models/qwen3/qwen3_presets.py +73 -0
  61. keras_hub/src/models/qwen3/qwen3_tokenizer.py +48 -0
  62. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +1 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +2 -2
  64. keras_hub/src/models/roformer_v2/roformer_v2_attention.py +0 -2
  65. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
  66. keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
  67. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +31 -32
  68. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
  69. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
  71. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
  72. keras_hub/src/models/vit/vit_backbone.py +31 -11
  73. keras_hub/src/models/vit/vit_image_converter.py +0 -70
  74. keras_hub/src/models/vit/vit_layers.py +33 -18
  75. keras_hub/src/models/vit/vit_presets.py +11 -11
  76. keras_hub/src/utils/keras_utils.py +17 -0
  77. keras_hub/src/utils/preset_utils.py +19 -4
  78. keras_hub/src/utils/tensor_utils.py +14 -0
  79. keras_hub/src/utils/transformers/convert_deit.py +155 -0
  80. keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
  81. keras_hub/src/utils/transformers/convert_esm.py +159 -0
  82. keras_hub/src/utils/transformers/convert_llama3.py +6 -0
  83. keras_hub/src/utils/transformers/convert_qwen3.py +145 -0
  84. keras_hub/src/utils/transformers/export/gemma.py +89 -0
  85. keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
  86. keras_hub/src/utils/transformers/preset_loader.py +14 -2
  87. keras_hub/src/version.py +1 -1
  88. keras_hub/tokenizers/__init__.py +1 -0
  89. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dev0.dist-info}/METADATA +4 -4
  90. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dev0.dist-info}/RECORD +92 -48
  91. keras_hub/src/models/clip/clip_encoder_block.py +0 -111
  92. keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
  93. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dev0.dist-info}/WHEEL +0 -0
  94. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/clip/clip_layers.py
@@ -0,0 +1,295 @@
+import math
+
+from keras import layers
+from keras import ops
+
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+def quick_gelu(x):
+    return x * ops.sigmoid(1.702 * x)
+
+
+class CLIPVisionEmbedding(layers.Layer):
+    def __init__(
+        self,
+        hidden_dim,
+        patch_size,
+        image_size,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(dtype=dtype, **kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.patch_size = int(patch_size)
+        self.image_size = int(image_size)
+        data_format = standardize_data_format(data_format)
+        self.data_format = data_format
+        num_patches = (image_size // patch_size) ** 2
+        self.num_positions = num_patches + 1
+
+        self.patch_embedding = layers.Conv2D(
+            hidden_dim,
+            kernel_size=patch_size,
+            strides=patch_size,
+            data_format=data_format,
+            use_bias=False,
+            dtype=dtype,
+            name="patch_embedding",
+        )
+        self.position_embedding = layers.Embedding(
+            num_patches + 1, hidden_dim, dtype=dtype, name="position_embedding"
+        )
+
+    def build(self, input_shape):
+        self.class_embedding = self.add_weight(
+            shape=(self.hidden_dim,),
+            initializer="random_normal",
+            dtype=self.variable_dtype,
+            name="class_embedding",
+        )
+        self.position_ids = self.add_weight(
+            shape=(1, self.num_positions),
+            initializer="zeros",
+            # Let the backend determine the int dtype. For example, tf
+            # requires int64 for correct device placement, whereas jax and torch
+            # don't.
+            dtype=int,
+            trainable=False,
+            name="position_ids",
+        )
+        self.patch_embedding.build(input_shape)
+        self.position_embedding.build(self.position_ids.shape)
+
+    def call(self, inputs, training=None):
+        x = inputs
+        batch_size = ops.shape(x)[0]
+        patch_embeddings = self.patch_embedding(x, training=training)
+        if self.data_format == "channels_last":
+            patch_embeddings = ops.reshape(
+                patch_embeddings, (batch_size, -1, self.hidden_dim)
+            )
+        else:
+            patch_embeddings = ops.reshape(
+                patch_embeddings, (batch_size, self.hidden_dim, -1)
+            )
+            patch_embeddings = ops.transpose(patch_embeddings, (0, 2, 1))
+        class_embeddings = ops.expand_dims(self.class_embedding, axis=(0, 1))
+        class_embeddings = ops.tile(class_embeddings, (batch_size, 1, 1))
+        position_embeddings = self.position_embedding(self.position_ids)
+        embeddings = ops.concatenate(
+            [class_embeddings, patch_embeddings], axis=1
+        )
+        return ops.add(embeddings, position_embeddings)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "patch_size": self.patch_size,
+                "image_size": self.image_size,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        output_shape = [input_shape[0], None, self.hidden_dim]
+        if self.data_format == "channels_last":
+            if input_shape[1] is not None and input_shape[2] is not None:
+                patch_num = input_shape[1] // self.patch_size
+                output_shape[1] = patch_num**2 + 1
+        else:
+            if input_shape[2] is not None and input_shape[3] is not None:
+                patch_num = input_shape[2] // self.patch_size
+                output_shape[1] = patch_num**2 + 1
+        return output_shape
+
+
+class CLIPEncoderLayer(layers.Layer):
+    def __init__(
+        self,
+        hidden_dim,
+        num_heads,
+        intermediate_dim,
+        intermediate_activation="quick_gelu",
+        use_causal_mask=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if hidden_dim % num_heads != 0:
+            raise ValueError(
+                "`hidden_dim` must be divisible by `num_heads`. "
+                f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
+            )
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.intermediate_dim = intermediate_dim
+        self.intermediate_activation = intermediate_activation
+        self.use_causal_mask = use_causal_mask
+
+        if intermediate_activation == "quick_gelu":
+            intermediate_activation = quick_gelu
+
+        self.layer_norm_1 = layers.LayerNormalization(
+            epsilon=1e-5, dtype=self.dtype_policy, name="layer_norm_1"
+        )
+        self.attention = layers.MultiHeadAttention(
+            num_heads,
+            hidden_dim // num_heads,
+            dtype=self.dtype_policy,
+            name="attention",
+        )
+        self.layer_norm_2 = layers.LayerNormalization(
+            epsilon=1e-5, dtype=self.dtype_policy, name="layer_norm_2"
+        )
+        self.dense_1 = layers.Dense(
+            self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
+        )
+        self.activation = layers.Activation(
+            intermediate_activation, dtype=self.dtype_policy, name="activation"
+        )
+        self.dense_2 = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="dense_2"
+        )
+
+    def build(self, input_shape):
+        self.layer_norm_1.build(input_shape)
+        self.attention.build(input_shape, input_shape, input_shape)
+        self.layer_norm_2.build(input_shape)
+        self.dense_1.build(input_shape)
+        input_shape = self.dense_1.compute_output_shape(input_shape)
+        self.dense_2.build(input_shape)
+
+    def compute_output_shape(self, inputs_shape):
+        outputs_shape = list(inputs_shape)
+        outputs_shape[-1] = self.hidden_dim
+        return outputs_shape
+
+    def call(self, x, training=None):
+        residual = x
+        x = self.layer_norm_1(x)
+        x = self.attention(
+            x, x, x, training=training, use_causal_mask=self.use_causal_mask
+        )
+        x = ops.add(residual, x)
+
+        residual = x
+        x = self.dense_1(self.layer_norm_2(residual))
+        x = self.activation(x)
+        x = self.dense_2(x)
+        x = ops.add(residual, x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_heads": self.num_heads,
+                "intermediate_dim": self.intermediate_dim,
+                "intermediate_activation": self.intermediate_activation,
+                "use_causal_mask": self.use_causal_mask,
+            }
+        )
+        return config
+
+
+class CLIPVisionPooler(layers.Layer):
+    """The vision pooler layer of CLIP.
+
+    `CLIPVisionPooler` extracts the first token (index `0`) from the
+    sequence of the vision embeddings as the pooled outputs.
+
+    Call arguments:
+        vision_embeddings: A tensor of shape
+            `(batch_size, sequence_length, hidden_dim)`.
+    """
+
+    def call(self, vision_embeddings):
+        return vision_embeddings[:, 0, :]
+
+    def compute_output_shape(self, input_shape):
+        return (input_shape[0], input_shape[-1])
+
+
+class CLIPTextPooler(layers.Layer):
+    """The text pooler layer of CLIP.
+
+    `CLIPTextPooler` extracts the text embeddings at the positions of EOS tokens
+    as the pooled outputs.
+
+    Call arguments:
+        text_embeddings: A tensor of shape
+            `(batch_size, sequence_length, hidden_dim)`.
+        token_ids: A tensor of shape `(batch_size, max_tokens)`, used to
+            identify the positions of EOS tokens.
+    """
+
+    def call(self, text_embeddings, token_ids):
+        # `keepdims` is not supported in `keras<=3.1`.
+        eos_index = ops.argmax(token_ids, axis=-1)
+        eos_index = ops.expand_dims(eos_index, axis=-1)
+        eos_index = ops.expand_dims(eos_index, axis=-1)
+        pooled_outputs = ops.take_along_axis(text_embeddings, eos_index, axis=1)
+        return ops.squeeze(pooled_outputs, axis=1)
+
+    def compute_output_shape(self, input_shape):
+        return (input_shape[0], input_shape[-1])
+
+
+class CLIPHead(layers.Layer):
+    """The head layer of CLIP.
+
+    `CLIPHead` takes `vision_embedding` and `text_embedding` as inputs to
+    compute the corresponding logits. Both embeddings are L2 normalized and used
+    to compute pairwise cosine similarity. The resulting logits are then scaled
+    by a learnable `logit_scale` parameter.
+
+    Call arguments:
+        vision_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+        text_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+    """
+
+    def build(self, input_shape):
+        self.logit_scale = self.add_weight(
+            shape=(),
+            initializer=lambda *a, **kw: math.log(1 / 0.07),
+            trainable=True,
+            dtype=self.variable_dtype,
+            name="logit_scale",
+        )
+
+    def call(self, vision_embedding, text_embedding):
+        normalized_vision_embedding = ops.sqrt(
+            ops.sum(ops.power(vision_embedding, 2), axis=-1, keepdims=True)
+        )
+        normalized_text_embedding = ops.sqrt(
+            ops.sum(ops.power(text_embedding, 2), axis=-1, keepdims=True)
+        )
+        vision_embedding = vision_embedding / normalized_vision_embedding
+        text_embedding = text_embedding / normalized_text_embedding
+        logit_scale = ops.exp(self.logit_scale)
+        text_logits = (
+            ops.matmul(
+                text_embedding,
+                ops.transpose(vision_embedding),
+            )
+            * logit_scale
+        )
+        vision_logits = ops.transpose(text_logits)
+        return vision_logits, text_logits
+
+    def compute_output_shape(
+        self, vision_embedding_shape, text_embedding_shape
+    ):
+        vision_logits_shape = (
+            vision_embedding_shape[0],
+            text_embedding_shape[0],
+        )
+        text_logits_shape = (
+            text_embedding_shape[0],
+            vision_embedding_shape[0],
+        )
+        return vision_logits_shape, text_logits_shape
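The new `CLIPHead` above turns pooled image and text embeddings into scaled cosine-similarity logits. As a quick illustration (not part of the diff), the layer can be exercised on its own; the private import path mirrors the file added above, and the batch sizes and embedding width below are arbitrary assumptions:

```python
# Minimal sketch: exercising the new CLIPHead outside of CLIPBackbone.
import numpy as np

from keras_hub.src.models.clip.clip_layers import CLIPHead

head = CLIPHead()
vision_embedding = np.random.rand(2, 512).astype("float32")  # 2 images
text_embedding = np.random.rand(3, 512).astype("float32")  # 3 captions
vision_logits, text_logits = head(vision_embedding, text_embedding)
# vision_logits: (2, 3) image-to-text similarities; text_logits: (3, 2).
```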
keras_hub/src/models/clip/clip_preprocessor.py
@@ -2,8 +2,10 @@ import keras
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
+from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter
 from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
-from keras_hub.src.models.preprocessor import Preprocessor
 from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 try:
@@ -13,32 +15,18 @@ except ImportError:
 
 
 @keras_hub_export("keras_hub.models.CLIPPreprocessor")
-class CLIPPreprocessor(Preprocessor):
-    """CLIP preprocessing layer which tokenizes and packs inputs.
+class CLIPPreprocessor(CausalLMPreprocessor):
+    """CLIP preprocessor.
 
     This preprocessing layer will do 2 things:
 
-    - Tokenize the inputs using the `tokenizer`.
-    - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`.
-
-    This layer can be used directly with `tf.data.Dataset.map` to preprocess
-    string data in the `(x, y, sample_weight)` format used by
-    `keras.Model.fit`.
-
-    The call method of this layer accepts three arguments, `x`, `y`, and
-    `sample_weight`. `x` can be a python string or tensor representing a single
-    segment, a list of python strings representing a batch of single segments,
-    or a list of tensors representing multiple segments to be packed together.
-    `y` and `sample_weight` are both optional, can have any format, and will be
-    passed through unaltered.
-
-    `CLIPPreprocessor` forces the input to have only one segment, as CLIP is
-    mainly used for generation tasks. For tasks having multi-segment inputs
-    like "glue/mnli", please use a model designed for classification purposes
-    such as BERT or RoBERTa.
+    This preprocessing layer is meant for use with
+    `keras_hub.models.CLIPBackbone`. By default, it will take in batches of
+    strings and images, and return token ids and resized images.
 
     Args:
         tokenizer: A `keras_hub.models.CLIPTokenizer` instance.
+        image_converter: A `keras_hub.models.CLIPImageConverter` instance.
         sequence_length: The length of the packed inputs.
         add_start_token: If `True`, the preprocessor will prepend the tokenizer
             start token to each input sequence.
@@ -47,32 +35,62 @@ class CLIPPreprocessor(Preprocessor):
         to_lower: bool. Whether to lower the inputs.
 
     Call arguments:
-        x: A string, `tf.Tensor` or list of python strings.
-        y: Any label data. Will be passed through unaltered.
-        sample_weight: Any label weight data. Will be passed through unaltered.
+        x: A dict with `"prompts"` and `"images"` keys, where `"prompts"` is
+            `tf.Tensor` or list of python strings and `"images"` are the image
+            tensors.
+        y: Label data. Should always be `None` since CLIP doesn't need the
+            label to calculate the loss.
+        sample_weight: Label weights.
         sequence_length: Pass to override the configured `sequence_length` of
             the layer.
-    """
 
-    # TODO: Add example once we have a CLIP model.
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_hub.models.CLIPPreprocessor.from_preset(
+        "clip_vit_base_patch16"
+    )
+
+    # Tokenize the sentence and preprocess the image.
+    preprocessor(
+        {
+            "prompts": "The quick brown fox jumped.",
+            "images": np.ones(shape=(123, 123, 3)),
+        }
+    )
+
+    # Tokenize a batch of sentences and preprocess a batch of images.
+    preprocessor(
+        {
+            "prompts": ["The quick brown fox jumped.", "The fox slept."],
+            "images": np.ones(shape=(2, 123, 123, 3)),
+        }
+    )
+    ```
+    """
 
+    backbone_cls = CLIPBackbone
     tokenizer_cls = CLIPTokenizer
+    image_converter_cls = CLIPImageConverter
 
     def __init__(
         self,
         tokenizer,
+        image_converter=None,
         sequence_length=77,
         add_start_token=True,
        add_end_token=True,
         to_lower=True,
         **kwargs,
     ):
-        super().__init__(**kwargs)
-        self.tokenizer = tokenizer
-        self.packer = None
-        self.sequence_length = sequence_length
-        self.add_start_token = add_start_token
-        self.add_end_token = add_end_token
+        super().__init__(
+            tokenizer=tokenizer,
+            sequence_length=sequence_length,
+            add_start_token=add_start_token,
+            add_end_token=add_end_token,
+            **kwargs,
+        )
+        self.image_converter = image_converter
         self.to_lower = to_lower
 
     def build(self, input_shape):
@@ -96,10 +114,14 @@ class CLIPPreprocessor(Preprocessor):
         sequence_length=None,
     ):
         sequence_length = sequence_length or self.sequence_length
+        images, prompts = x["images"], x["prompts"]
         if self.to_lower:
-            x = tf.strings.lower(x)
+            prompts = tf.strings.lower(prompts)
+        prompts = self.tokenizer(prompts)
+        if images is not None and self.image_converter:
+            images = self.image_converter(images)
         token_ids, padding_mask = self.packer(
-            self.tokenizer(x),
+            prompts,
             sequence_length=sequence_length,
             add_start_value=self.add_start_token,
             add_end_value=self.add_end_token,
@@ -107,6 +129,7 @@ class CLIPPreprocessor(Preprocessor):
         x = {
             "token_ids": token_ids,
             "padding_mask": padding_mask,
+            "images": images,
         }
         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
 
@@ -114,21 +137,7 @@ class CLIPPreprocessor(Preprocessor):
         config = super().get_config()
         config.update(
             {
-                "sequence_length": self.sequence_length,
-                "add_start_token": self.add_start_token,
-                "add_end_token": self.add_end_token,
                 "to_lower": self.to_lower,
             }
         )
         return config
-
-    @property
-    def sequence_length(self):
-        """The padded length of model input sequences."""
-        return self._sequence_length
-
-    @sequence_length.setter
-    def sequence_length(self, value):
-        self._sequence_length = value
-        if self.packer is not None:
-            self.packer.sequence_length = value
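With the changes above, `CLIPPreprocessor` now consumes a dict of `"prompts"` and `"images"` and emits `"token_ids"`, `"padding_mask"`, and converted `"images"`. A hedged sketch of wiring it into a `tf.data` pipeline; the preset name is the one used in the docstring above, and the image size is an assumption:

```python
import numpy as np
import tensorflow as tf

import keras_hub

preprocessor = keras_hub.models.CLIPPreprocessor.from_preset(
    "clip_vit_base_patch16"
)
features = {
    "prompts": ["a photo of a cat", "a photo of a dog"],
    "images": np.ones(shape=(2, 224, 224, 3), dtype="float32"),
}
ds = tf.data.Dataset.from_tensor_slices(features).batch(2)
# Each mapped element is a dict with "token_ids", "padding_mask", and "images".
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
```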
keras_hub/src/models/clip/clip_text_encoder.py
@@ -5,7 +5,7 @@ from keras_hub.src.layers.modeling.token_and_position_embedding import (
     TokenAndPositionEmbedding,
 )
 from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.models.clip.clip_encoder_block import CLIPEncoderBlock
+from keras_hub.src.models.clip.clip_layers import CLIPEncoderLayer
 
 
 @keras_hub_export("keras_hub.models.CLIPTextEncoder")
@@ -71,7 +71,7 @@ class CLIPTextEncoder(Backbone):
             name=f"{prefix}embedding",
         )
         self.encoder_layers = [
-            CLIPEncoderBlock(
+            CLIPEncoderLayer(
                 hidden_dim,
                 num_heads,
                 intermediate_dim,
keras_hub/src/models/clip/clip_vision_encoder.py
@@ -2,8 +2,8 @@ from keras import layers
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.models.clip.clip_encoder_block import CLIPEncoderBlock
-from keras_hub.src.models.clip.clip_vision_embedding import CLIPVisionEmbedding
+from keras_hub.src.models.clip.clip_layers import CLIPEncoderLayer
+from keras_hub.src.models.clip.clip_layers import CLIPVisionEmbedding
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
 
@@ -91,7 +91,7 @@ class CLIPVisionEncoder(Backbone):
             epsilon=1e-5, dtype=dtype, name=f"{prefix}pre_layer_norm"
         )
         self.encoder_layers = [
-            CLIPEncoderBlock(
+            CLIPEncoderLayer(
                 hidden_dim,
                 num_heads,
                 intermediate_dim,
keras_hub/src/models/deit/__init__.py
@@ -0,0 +1,5 @@
+from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
+from keras_hub.src.models.deit.deit_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, DeiTBackbone)
keras_hub/src/models/deit/deit_backbone.py
@@ -0,0 +1,154 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.deit.deit_layers import DeiTEmbeddings
+from keras_hub.src.models.deit.deit_layers import DeiTEncoder
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+@keras_hub_export("keras_hub.models.DeiTBackbone")
+class DeiTBackbone(Backbone):
+    """DeiT backbone.
+
+    This backbone implements the Data-efficient Image Transformer (DeiT)
+    architecture as described in [Training data-efficient image
+    transformers & distillation through attention]
+    (https://arxiv.org/abs/2012.12877).
+
+    Args:
+        image_shape: A tuple or list of 3 integers representing the shape of the
+            input image `(height, width, channels)`.
+        patch_size: tuple or int. The size of each image patch. If an int is
+            provided, it will be used for both height and width. The input image
+            will be split into patches of shape `(patch_size_h, patch_size_w)`.
+        num_layers: int. The number of transformer encoder layers.
+        num_heads: int. The number of attention heads in each Transformer
+            encoder layer.
+        hidden_dim: int. The dimensionality of the hidden representations.
+        intermediate_dim: int. The dimensionality of the intermediate MLP layer
+            in each Transformer encoder layer.
+        dropout_rate: float. The dropout rate for the Transformer encoder
+            layers.
+        attention_dropout: float. The dropout rate for the attention mechanism
+            in each Transformer encoder layer.
+        layer_norm_epsilon: float. Value used for numerical stability in layer
+            normalization.
+        use_mha_bias: bool. Whether to use bias in the multi-head attention
+            layers.
+        data_format: str. `"channels_last"` or `"channels_first"`, specifying
+            the data format for the input image. If `None`, defaults to
+            `"channels_last"`.
+        dtype: The dtype of the layer weights. Defaults to None.
+        **kwargs: Additional keyword arguments to be passed to the parent
+            `Backbone` class.
+    """
+
+    def __init__(
+        self,
+        image_shape,
+        patch_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        use_mha_bias=True,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        data_format = standardize_data_format(data_format)
+        if isinstance(patch_size, int):
+            patch_size = (patch_size, patch_size)
+        h_axis, w_axis, channels_axis = (
+            (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
+        )
+        # Check that the input image is well specified.
+        if image_shape[h_axis] is None or image_shape[w_axis] is None:
+            raise ValueError(
+                f"Image shape must have defined height and width. Found `None` "
+                f"at index {h_axis} (height) or {w_axis} (width). "
+                f"Image shape: {image_shape}"
+            )
+        # Check that image dimensions are divisible by patch size.
+        if image_shape[h_axis] % patch_size[0] != 0:
+            raise ValueError(
+                f"Input height {image_shape[h_axis]} should be divisible by "
+                f"patch size {patch_size}."
+            )
+        if image_shape[w_axis] % patch_size[1] != 0:
+            raise ValueError(
+                f"Input width {image_shape[w_axis]} should be divisible by "
+                f"patch size {patch_size}."
+            )
+
+        num_channels = image_shape[channels_axis]
+
+        # === Functional Model ===
+        inputs = keras.layers.Input(shape=image_shape)
+
+        x = DeiTEmbeddings(
+            image_size=(image_shape[h_axis], image_shape[w_axis]),
+            patch_size=patch_size,
+            hidden_dim=hidden_dim,
+            num_channels=num_channels,
+            data_format=data_format,
+            dropout_rate=dropout_rate,
+            dtype=dtype,
+            name="deit_patching_and_embedding",
+        )(inputs)
+
+        output, _, _ = DeiTEncoder(
+            num_layers=num_layers,
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            intermediate_dim=intermediate_dim,
+            use_mha_bias=use_mha_bias,
+            dropout_rate=dropout_rate,
+            attention_dropout=attention_dropout,
+            layer_norm_epsilon=layer_norm_epsilon,
+            dtype=dtype,
+            name="deit_encoder",
+        )(x)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.image_shape = image_shape
+        self.patch_size = patch_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.use_mha_bias = use_mha_bias
+        self.data_format = data_format
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_shape": self.image_shape,
+                "patch_size": self.patch_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout_rate": self.dropout_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "use_mha_bias": self.use_mha_bias,
+            }
+        )
+        return config
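For reference, a minimal sketch of instantiating the new `DeiTBackbone` directly from its constructor arguments; the hyperparameters below are illustrative ViT-Base-style values chosen for the example, not one of the published `deit_presets`:

```python
import numpy as np

import keras_hub

backbone = keras_hub.models.DeiTBackbone(
    image_shape=(224, 224, 3),
    patch_size=16,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    intermediate_dim=3072,
)
images = np.ones((1, 224, 224, 3), dtype="float32")
# Sequence features of shape (1, sequence_length, 768); for DeiT the sequence
# typically includes the class and distillation tokens alongside the patches.
features = backbone(images)
```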