keras-hub-nightly 0.19.0.dev202503010353__py3-none-any.whl → 0.19.0.dev202503030351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,555 @@
+ import math
+
+ from keras import initializers
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.reversible_embedding import (
+     ReversibleEmbedding,
+ )
+ from keras_hub.src.utils.keras_utils import clone_initializer
+ from keras_hub.src.utils.keras_utils import gelu_approximate
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ class SigLIPVisionEmbedding(layers.Layer):
+     """A layer that converts images into patches.
+
+     Args:
+         hidden_dim: int. The number of units in the hidden layers.
+         patch_size: int. The size of one side of each patch.
+         image_size: int. The size of the input images.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(
+         self,
+         hidden_dim,
+         patch_size,
+         image_size,
+         data_format=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_dim = int(hidden_dim)
+         self.patch_size = int(patch_size)
+         self.image_size = int(image_size)
+         self.data_format = standardize_data_format(data_format)
+         self.num_positions = (image_size // patch_size) ** 2
+
+         self.patch_embedding = layers.Conv2D(
+             hidden_dim,
+             kernel_size=patch_size,
+             strides=patch_size,
+             kernel_initializer=initializers.LecunNormal(),
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="patch_embedding",
+         )
+         self.position_embedding = layers.Embedding(
+             self.num_positions,
+             hidden_dim,
+             embeddings_initializer=initializers.RandomNormal(
+                 stddev=1.0 / math.sqrt(hidden_dim)
+             ),
+             dtype=self.dtype_policy,
+             name="position_embedding",
+         )
+
+     def build(self, input_shape):
+         self.position_ids = self.add_weight(
+             shape=(1, self.num_positions),
+             initializer="zeros",
+             # Let the backend determine the int dtype. For example, tf
+             # requires int64 for correct device placement, whereas jax and torch
+             # don't.
+             dtype=int,
+             trainable=False,
+             name="position_ids",
+         )
+         self.position_ids.assign(
+             ops.expand_dims(ops.arange(0, self.num_positions), axis=0)
+         )
+         self.patch_embedding.build(input_shape)
+         self.position_embedding.build(self.position_ids.shape)
+
+     def call(self, inputs, training=None):
+         x = inputs
+         batch_size = ops.shape(x)[0]
+         patch_embeddings = self.patch_embedding(x, training=training)
+         if self.data_format == "channels_last":
+             patch_embeddings = ops.reshape(
+                 patch_embeddings, (batch_size, -1, self.hidden_dim)
+             )
+         else:
+             patch_embeddings = ops.reshape(
+                 patch_embeddings, (batch_size, self.hidden_dim, -1)
+             )
+             patch_embeddings = ops.transpose(patch_embeddings, (0, 2, 1))
+         position_embeddings = self.position_embedding(self.position_ids)
+         return ops.add(patch_embeddings, position_embeddings)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "patch_size": self.patch_size,
+                 "image_size": self.image_size,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         output_shape = [input_shape[0], None, self.hidden_dim]
+         if self.data_format == "channels_last":
+             if input_shape[1] is not None and input_shape[2] is not None:
+                 patch_num = input_shape[1] // self.patch_size
+                 output_shape[1] = patch_num**2
+         else:
+             if input_shape[2] is not None and input_shape[3] is not None:
+                 patch_num = input_shape[2] // self.patch_size
+                 output_shape[1] = patch_num**2
+         return output_shape
+
+
+ class SigLIPTextEmbedding(layers.Layer):
+     """A layer which sums a token and position embedding.
+
+     Args:
+         vocabulary_size: The size of the vocabulary.
+         sequence_length: The maximum length of the input sequence.
+         embedding_dim: The output dimension of the embedding layer.
+         tie_weights: Boolean, whether or not the matrix for embedding and
+             the matrix for the `reverse` projection should share the same
+             weights. Defaults to `True`.
+         embeddings_initializer: The initializer to use for the embedding
+             layers. Defaults to `"normal"`.
+         mask_zero: Boolean, whether or not the input value 0 is a special
+             "padding" value that should be masked out.
+             This is useful when using recurrent layers which may take variable
+             length input. If this is `True`, then all subsequent layers in the
+             model need to support masking or an exception will be raised.
+             If `mask_zero` is set to `True`, as a consequence, index 0 cannot be
+             used in the vocabulary
+             (`input_dim` should equal size of vocabulary + 1). Defaults to
+             `False`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `trainable`, `dtype` etc.
+     """
+
+     def __init__(
+         self,
+         vocabulary_size,
+         sequence_length,
+         embedding_dim,
+         tie_weights=True,
+         embeddings_initializer="normal",
+         mask_zero=False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.vocabulary_size = int(vocabulary_size)
+         self.sequence_length = int(sequence_length)
+         self.embedding_dim = int(embedding_dim)
+         self.embeddings_initializer = initializers.get(embeddings_initializer)
+         self.token_embedding = ReversibleEmbedding(
+             vocabulary_size,
+             embedding_dim,
+             tie_weights=tie_weights,
+             embeddings_initializer=clone_initializer(
+                 self.embeddings_initializer
+             ),
+             mask_zero=mask_zero,
+             dtype=self.dtype_policy,
+             name="token_embedding",
+         )
+         self.position_embedding = ReversibleEmbedding(
+             sequence_length,
+             embedding_dim,
+             tie_weights=tie_weights,
+             embeddings_initializer=clone_initializer(
+                 self.embeddings_initializer
+             ),
+             mask_zero=mask_zero,
+             dtype=self.dtype_policy,
+             name="position_embedding",
+         )
+         self.supports_masking = self.token_embedding.supports_masking
+
+     def build(self, input_shape):
+         input_shape = tuple(input_shape)
+         self.token_embedding.build(input_shape)
+         self.position_embedding.build((1, self.sequence_length))
+         self.position_ids = self.add_weight(
+             shape=(1, self.sequence_length),
+             initializer="zeros",
+             # Let the backend determine the int dtype. For example, tf
+             # requires int64 for correct device placement, whereas jax and torch
+             # don't.
+             dtype=int,
+             trainable=False,
+             name="position_ids",
+         )
+         self.position_ids.assign(
+             ops.expand_dims(ops.arange(0, self.sequence_length), axis=0)
+         )
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary_size": self.vocabulary_size,
+                 "sequence_length": self.sequence_length,
+                 "embedding_dim": self.embedding_dim,
+                 "embeddings_initializer": initializers.serialize(
+                     self.embeddings_initializer
+                 ),
+                 "tie_weights": self.token_embedding.tie_weights,
+                 "mask_zero": self.token_embedding.mask_zero,
+             }
+         )
+         return config
+
+     def call(self, inputs):
+         embedded_tokens = self.token_embedding(inputs)
+         embedded_positions = self.position_embedding(self.position_ids)
+         outputs = ops.add(embedded_tokens, embedded_positions)
+         return outputs
+
+     def compute_mask(self, inputs, mask=None):
+         return self.token_embedding.compute_mask(inputs, mask=mask)
+
+     def compute_output_shape(self, input_shape):
+         return tuple(input_shape) + (self.embedding_dim,)
+
+
+ class SigLIPMLP(layers.Layer):
+     """A SigLIP MLP block.
+
+     Args:
+         hidden_dim: int. The number of units in the output layer.
+         intermediate_dim: int. The number of units in the intermediate layer.
+         activation: str or callable. Activation to use in the intermediate
+             layer.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(self, hidden_dim, intermediate_dim, activation, **kwargs):
+         super().__init__(**kwargs)
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.activation = activation
+
+         if activation == "gelu_approximate":
+             activation = gelu_approximate
+
+         self.fc1 = layers.Dense(
+             self.intermediate_dim,
+             activation=activation,
+             bias_initializer=initializers.RandomNormal(stddev=1e-6),
+             dtype=self.dtype_policy,
+             name="fc1",
+         )
+         self.fc2 = layers.Dense(
+             self.hidden_dim,
+             bias_initializer=initializers.RandomNormal(stddev=1e-6),
+             dtype=self.dtype_policy,
+             name="fc2",
+         )
+
+     def build(self, inputs_shape):
+         self.fc1.build(inputs_shape)
+         inputs_shape = self.fc1.compute_output_shape(inputs_shape)
+         self.fc2.build(inputs_shape)
+
+     def call(self, inputs):
+         hidden_states = self.fc1(inputs)
+         return self.fc2(hidden_states)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "activation": self.activation,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, inputs_shape):
+         outputs_shape = list(inputs_shape)
+         outputs_shape[-1] = self.hidden_dim
+         return outputs_shape
+
+
+ class SigLIPEncoderLayer(layers.Layer):
+     """A SigLIP encoder layer.
+
+     Args:
+         hidden_dim: int. The number of units in the hidden layers.
+         num_heads: int. Number of attention heads.
+         intermediate_dim: int. The number of units in the intermediate layers.
+         intermediate_activation: str or callable. Activation to use in the
+             hidden layers.
+         layer_norm_epsilon: float. The epsilon for the layer normalization.
+             Defaults to `1e-6`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(
+         self,
+         hidden_dim,
+         num_heads,
+         intermediate_dim,
+         intermediate_activation="gelu_approximate",
+         use_causal_mask=False,
+         layer_norm_epsilon=1e-6,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         if hidden_dim % num_heads != 0:
+             raise ValueError(
+                 "`hidden_dim` must be divisible by `num_heads`. "
+                 f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
+             )
+         self.hidden_dim = int(hidden_dim)
+         self.num_heads = int(num_heads)
+         self.intermediate_dim = int(intermediate_dim)
+         self.intermediate_activation = intermediate_activation
+         self.use_causal_mask = bool(use_causal_mask)
+         self.layer_norm_epsilon = layer_norm_epsilon
+
+         self.self_attn = layers.MultiHeadAttention(
+             num_heads,
+             hidden_dim // num_heads,
+             dtype=self.dtype_policy,
+             name="self_attn",
+         )
+         self.layer_norm1 = layers.LayerNormalization(
+             epsilon=layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="layer_norm1",
+         )
+         self.mlp = SigLIPMLP(
+             hidden_dim,
+             intermediate_dim,
+             intermediate_activation,
+             dtype=self.dtype_policy,
+             name="mlp",
+         )
+         self.layer_norm2 = layers.LayerNormalization(
+             epsilon=layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="layer_norm2",
+         )
+
+     def build(self, inputs_shape):
+         self.layer_norm1.build(inputs_shape)
+         self.self_attn.build(inputs_shape, inputs_shape, inputs_shape)
+         self.layer_norm2.build(inputs_shape)
+         self.mlp.build(inputs_shape)
+
+     def compute_output_shape(self, inputs_shape):
+         outputs_shape = list(inputs_shape)
+         outputs_shape[-1] = self.hidden_dim
+         return outputs_shape
+
+     def call(self, inputs, training=None):
+         residual = inputs
+         x = self.layer_norm1(inputs)
+         x = self.self_attn(
+             x, x, x, training=training, use_causal_mask=self.use_causal_mask
+         )
+         x = ops.add(residual, x)
+
+         residual = x
+         x = self.layer_norm2(x)
+         x = self.mlp(x, training=training)
+         x = ops.add(residual, x)
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "num_heads": self.num_heads,
+                 "intermediate_dim": self.intermediate_dim,
+                 "intermediate_activation": self.intermediate_activation,
+                 "use_causal_mask": self.use_causal_mask,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
+
+
+ class SigLIPMultiHeadAttentionPooling(layers.Layer):
+     """A SigLIP multi-headed attention pooling layer.
+
+     Args:
+         hidden_dim: int. The number of units in the hidden layers.
+         intermediate_dim: int. The number of units in the intermediate layers.
+         num_heads: int. Number of attention heads.
+         activation: str or callable. Activation to use in the MLP.
+         layer_norm_epsilon: float. The epsilon for the layer normalization.
+             Defaults to `1e-6`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(
+         self,
+         hidden_dim,
+         intermediate_dim,
+         num_heads,
+         activation,
+         layer_norm_epsilon=1e-6,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_dim = int(hidden_dim)
+         self.intermediate_dim = int(intermediate_dim)
+         self.num_heads = int(num_heads)
+         self.activation = activation
+         self.layer_norm_epsilon = layer_norm_epsilon
+
+         self.attention = layers.MultiHeadAttention(
+             num_heads,
+             hidden_dim // num_heads,
+             dtype=self.dtype_policy,
+             name="attention",
+         )
+         self.layer_norm = layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="layernorm",
+         )
+         self.mlp = SigLIPMLP(
+             hidden_dim=self.hidden_dim,
+             intermediate_dim=self.intermediate_dim,
+             activation=self.activation,
+             dtype=self.dtype_policy,
+             name="mlp",
+         )
+
+     def build(self, inputs_shape):
+         self.probe = self.add_weight(
+             (1, 1, self.hidden_dim),
+             initializer=initializers.GlorotUniform(),
+             dtype=self.dtype_policy.variable_dtype,
+         )
+         self.attention.build(
+             query_shape=(inputs_shape[0], 1, self.hidden_dim),
+             value_shape=inputs_shape,
+         )
+         inputs_shape = self.attention.compute_output_shape(
+             query_shape=(inputs_shape[0], 1, self.hidden_dim),
+             value_shape=inputs_shape,
+         )
+         self.layer_norm.build(inputs_shape)
+         self.mlp.build(inputs_shape)
+
+     def call(self, inputs, training=None):
+         batch_size = ops.shape(inputs)[0]
+         probes = ops.repeat(self.probe, repeats=batch_size, axis=0)
+         hidden_states = self.attention(
+             probes, inputs, inputs, training=training
+         )
+         residuals = hidden_states
+         hidden_states = self.layer_norm(hidden_states)
+         hidden_states = ops.add(residuals, self.mlp(hidden_states))
+         return hidden_states[:, 0]
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "activation": self.activation,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, inputs_shape):
+         return (inputs_shape[0], self.hidden_dim)
+
+
+ class SigLIPHead(layers.Layer):
+     """The head layer of SigLIP.
+
+     `SigLIP` takes `vision_embedding` and `text_embedding` as inputs to
+     compute the corresponding logits. Both embeddings are L2 normalized and used
+     to compute pairwise cosine similarities. The resulting logits are then
+     scaled by a learnable `logit_scale` and shifted by a learnable `logit_bias`.
+
+     Call arguments:
+         vision_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+         text_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+     """
+
+     def build(self, input_shape):
+         self.logit_scale = self.add_weight(
+             shape=(),
+             initializer=initializers.Constant(math.log(1.0)),
+             trainable=True,
+             dtype=self.variable_dtype,
+             name="logit_scale",
+         )
+         self.logit_bias = self.add_weight(
+             shape=(),
+             initializer=initializers.Zeros(),
+             trainable=True,
+             dtype=self.variable_dtype,
+             name="logit_bias",
+         )
+
+     def call(self, vision_embedding, text_embedding):
+         normalized_vision_embedding = ops.sqrt(
+             ops.sum(ops.power(vision_embedding, 2), axis=-1, keepdims=True)
+         )
+         normalized_text_embedding = ops.sqrt(
+             ops.sum(ops.power(text_embedding, 2), axis=-1, keepdims=True)
+         )
+         vision_embedding = ops.divide(
+             vision_embedding, normalized_vision_embedding
+         )
+         text_embedding = ops.divide(text_embedding, normalized_text_embedding)
+         text_logits = ops.add(
+             ops.multiply(
+                 ops.matmul(text_embedding, ops.transpose(vision_embedding)),
+                 ops.exp(self.logit_scale),
+             ),
+             self.logit_bias,
+         )
+         vision_logits = ops.transpose(text_logits)
+         return vision_logits, text_logits
+
+     def compute_output_shape(
+         self, vision_embedding_shape, text_embedding_shape
+     ):
+         vision_logits_shape = (
+             vision_embedding_shape[0],
+             text_embedding_shape[0],
+         )
+         text_logits_shape = (
+             text_embedding_shape[0],
+             vision_embedding_shape[0],
+         )
+         return vision_logits_shape, text_logits_shape
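As a quick orientation, here is a minimal sketch of how `SigLIPHead` is driven, assuming pooled vision and text embeddings of shape `(batch_size, hidden_dim)`; the batch size of 4 and `hidden_dim` of 768 below are illustrative assumptions, not values from the package.

import numpy as np

# Hypothetical pooled embeddings, e.g. from the vision attention-pooling layer
# and the text encoder, both of shape (batch_size, hidden_dim).
vision_embedding = np.random.rand(4, 768).astype("float32")
text_embedding = np.random.rand(4, 768).astype("float32")

head = SigLIPHead(name="siglip_head")
vision_logits, text_logits = head(vision_embedding, text_embedding)
# vision_logits has shape (4, 4): pairwise cosine similarities scaled by
# exp(logit_scale) and shifted by logit_bias; text_logits is its transpose.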
@@ -0,0 +1,35 @@
+ import keras
+ from keras import ops
+
+
+ class SigLIPLoss(keras.losses.Loss):
+     """SigLIP Loss.
+
+     SigLIP loss replaces the loss function used in CLIP with a simple pairwise
+     sigmoid loss. Unlike standard contrastive learning with softmax
+     normalization, the sigmoid loss operates solely on image-text pairs and does
+     not require a global view of the pairwise similarities for normalization.
+     The sigmoid loss simultaneously allows further scaling up the batch size,
+     while also performing better at smaller batch sizes.
+
+     References:
+         - [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343)
+     """
+
19
+ def call(self, y_true, y_pred):
20
+ y_pred = ops.convert_to_tensor(y_pred)
21
+ y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
22
+
23
+ # Ref: https://github.com/google-research/big_vision/blob/main/big_vision/trainers/proj/image_text/siglip.py
24
+ logits = y_pred
25
+ m1_diag1 = y_true
26
+
27
+ # Standard sigmoid computes everything twice, once assuming positive
28
+ # labels and once assuming negative ones. But here we know exactly where
29
+ # to find positives (on "me" diagonal) and negatives (everywhere else),
30
+ # so compute each one's loss only once:
31
+ loglike = ops.nn.log_sigmoid(m1_diag1 * logits)
32
+
33
+ # Normalize by npos per column, but that's one, so just sum.
34
+ nll = -ops.sum(loglike, axis=-1)
35
+ return nll
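Similarly, a minimal sketch of feeding `SigLIPLoss`, assuming labels encoded as +1 on the matching-pair diagonal and -1 everywhere else (the encoding described by the `m1_diag1` comment above); the batch size of 4 is illustrative.

import numpy as np

n = 4  # batch of paired image-text examples (illustrative)
y_true = 2.0 * np.eye(n, dtype="float32") - 1.0  # +1 on diagonal, -1 elsewhere
y_pred = np.random.randn(n, n).astype("float32")  # stand-in for text_logits

loss_fn = SigLIPLoss()
loss = loss_fn(y_true, y_pred)
# `loss` is a scalar: the per-row negative log-likelihood from `call`, reduced
# by the default "sum_over_batch_size" reduction of `keras.losses.Loss`.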