keras-hub-nightly 0.19.0.dev202503010353__py3-none-any.whl → 0.19.0.dev202503030351__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +7 -0
- keras_hub/api/tokenizers/__init__.py +1 -0
- keras_hub/src/layers/preprocessing/image_converter.py +97 -1
- keras_hub/src/models/siglip/__init__.py +5 -0
- keras_hub/src/models/siglip/siglip_backbone.py +230 -0
- keras_hub/src/models/siglip/siglip_image_converter.py +8 -0
- keras_hub/src/models/siglip/siglip_layers.py +555 -0
- keras_hub/src/models/siglip/siglip_loss.py +35 -0
- keras_hub/src/models/siglip/siglip_preprocessor.py +162 -0
- keras_hub/src/models/siglip/siglip_presets.py +128 -0
- keras_hub/src/models/siglip/siglip_text_encoder.py +134 -0
- keras_hub/src/models/siglip/siglip_tokenizer.py +77 -0
- keras_hub/src/models/siglip/siglip_vision_encoder.py +151 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.19.0.dev202503010353.dist-info → keras_hub_nightly-0.19.0.dev202503030351.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.19.0.dev202503010353.dist-info → keras_hub_nightly-0.19.0.dev202503030351.dist-info}/RECORD +19 -9
- {keras_hub_nightly-0.19.0.dev202503010353.dist-info → keras_hub_nightly-0.19.0.dev202503030351.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.19.0.dev202503010353.dist-info → keras_hub_nightly-0.19.0.dev202503030351.dist-info}/top_level.txt +0 -0
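Most of the change adds a new SigLIP model family under `keras_hub/src/models/siglip/` along with its public API exports. As a rough, hypothetical sketch of what these additions enable (the preset name below is illustrative only; the actual names are registered in `siglip_presets.py`):

```python
import keras_hub

# Hypothetical usage of the newly added SigLIP backbone. The preset name is
# an assumption for illustration; consult siglip_presets.py for real names.
backbone = keras_hub.models.SigLIPBackbone.from_preset("siglip_base_patch16_224")
```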
keras_hub/src/models/siglip/siglip_layers.py (new file)
@@ -0,0 +1,555 @@
+import math
+
+from keras import initializers
+from keras import layers
+from keras import ops
+
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import gelu_approximate
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class SigLIPVisionEmbedding(layers.Layer):
+    """A layer that converts images into patches.
+
+    Args:
+        hidden_dim: int. The number of units in the hidden layers.
+        patch_size: int. The size of one side of each patch.
+        image_size: int. The size of the input images.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        patch_size,
+        image_size,
+        data_format=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.patch_size = int(patch_size)
+        self.image_size = int(image_size)
+        self.data_format = standardize_data_format(data_format)
+        self.num_positions = (image_size // patch_size) ** 2
+
+        self.patch_embedding = layers.Conv2D(
+            hidden_dim,
+            kernel_size=patch_size,
+            strides=patch_size,
+            kernel_initializer=initializers.LecunNormal(),
+            data_format=data_format,
+            dtype=self.dtype_policy,
+            name="patch_embedding",
+        )
+        self.position_embedding = layers.Embedding(
+            self.num_positions,
+            hidden_dim,
+            embeddings_initializer=initializers.RandomNormal(
+                stddev=1.0 / math.sqrt(hidden_dim)
+            ),
+            dtype=self.dtype_policy,
+            name="position_embedding",
+        )
+
+    def build(self, input_shape):
+        self.position_ids = self.add_weight(
+            shape=(1, self.num_positions),
+            initializer="zeros",
+            # Let the backend determine the int dtype. For example, tf
+            # requires int64 for correct device placement, whereas jax and torch
+            # don't.
+            dtype=int,
+            trainable=False,
+            name="position_ids",
+        )
+        self.position_ids.assign(
+            ops.expand_dims(ops.arange(0, self.num_positions), axis=0)
+        )
+        self.patch_embedding.build(input_shape)
+        self.position_embedding.build(self.position_ids.shape)
+
+    def call(self, inputs, training=None):
+        x = inputs
+        batch_size = ops.shape(x)[0]
+        patch_embeddings = self.patch_embedding(x, training=training)
+        if self.data_format == "channels_last":
+            patch_embeddings = ops.reshape(
+                patch_embeddings, (batch_size, -1, self.hidden_dim)
+            )
+        else:
+            patch_embeddings = ops.reshape(
+                patch_embeddings, (batch_size, self.hidden_dim, -1)
+            )
+            patch_embeddings = ops.transpose(patch_embeddings, (0, 2, 1))
+        position_embeddings = self.position_embedding(self.position_ids)
+        return ops.add(patch_embeddings, position_embeddings)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "patch_size": self.patch_size,
+                "image_size": self.image_size,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        output_shape = [input_shape[0], None, self.hidden_dim]
+        if self.data_format == "channels_last":
+            if input_shape[1] is not None and input_shape[2] is not None:
+                patch_num = input_shape[1] // self.patch_size
+                output_shape[1] = patch_num**2 + 1
+        else:
+            if input_shape[2] is not None and input_shape[3] is not None:
+                patch_num = input_shape[2] // self.patch_size
+                output_shape[1] = patch_num**2 + 1
+        return output_shape
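As a minimal sketch of how the new `SigLIPVisionEmbedding` behaves (the hyperparameters below are illustrative, ViT-B/16-style values, not taken from this diff):

```python
import numpy as np

from keras_hub.src.models.siglip.siglip_layers import SigLIPVisionEmbedding

# Illustrative ViT-B/16-style settings; the real values come from the presets.
embedding = SigLIPVisionEmbedding(hidden_dim=768, patch_size=16, image_size=224)
images = np.random.uniform(0.0, 1.0, size=(2, 224, 224, 3)).astype("float32")
tokens = embedding(images)
# One token per 16x16 patch, i.e. (224 // 16) ** 2 = 196, each of width 768.
print(tokens.shape)  # (2, 196, 768)
```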
+
+
+class SigLIPTextEmbedding(layers.Layer):
+    """A layer which sums a token and position embedding.
+
+    Args:
+        vocabulary_size: The size of the vocabulary.
+        sequence_length: The maximum length of input sequence.
+        embedding_dim: The output dimension of the embedding layer
+        tie_weights: Boolean, whether or not the matrix for embedding and
+            the matrix for the `reverse` projection should share the same
+            weights. Defaults to `True`.
+        embeddings_initializer: The initializer to use for the Embedding
+            Layers. Defaults to `"normal"`.
+        mask_zero: Boolean, whether or not the input value 0 is a special
+            "padding" value that should be masked out.
+            This is useful when using recurrent layers which may take variable
+            length input. If this is True, then all subsequent layers in the
+            model need to support masking or an exception will be raised.
+            If mask_zero` is set to True, as a consequence, index 0 cannot be
+            used in the vocabulary
+            (input_dim should equal size of vocabulary + 1). Defaults to
+            `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `trainable`, `dtype` etc.
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        sequence_length,
+        embedding_dim,
+        tie_weights=True,
+        embeddings_initializer="normal",
+        mask_zero=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.vocabulary_size = int(vocabulary_size)
+        self.sequence_length = int(sequence_length)
+        self.embedding_dim = int(embedding_dim)
+        self.embeddings_initializer = initializers.get(embeddings_initializer)
+        self.token_embedding = ReversibleEmbedding(
+            vocabulary_size,
+            embedding_dim,
+            tie_weights=tie_weights,
+            embeddings_initializer=clone_initializer(
+                self.embeddings_initializer
+            ),
+            mask_zero=mask_zero,
+            dtype=self.dtype_policy,
+            name="token_embedding",
+        )
+        self.position_embedding = ReversibleEmbedding(
+            sequence_length,
+            embedding_dim,
+            tie_weights=tie_weights,
+            embeddings_initializer=clone_initializer(
+                self.embeddings_initializer
+            ),
+            mask_zero=mask_zero,
+            dtype=self.dtype_policy,
+            name="position_embedding",
+        )
+        self.supports_masking = self.token_embedding.supports_masking
+
+    def build(self, input_shape):
+        input_shape = tuple(input_shape)
+        self.token_embedding.build(input_shape)
+        self.position_embedding.build((1, self.sequence_length))
+        self.position_ids = self.add_weight(
+            shape=(1, self.sequence_length),
+            initializer="zeros",
+            # Let the backend determine the int dtype. For example, tf
+            # requires int64 for correct device placement, whereas jax and torch
+            # don't.
+            dtype=int,
+            trainable=False,
+            name="position_ids",
+        )
+        self.position_ids.assign(
+            ops.expand_dims(ops.arange(0, self.sequence_length), axis=0)
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "sequence_length": self.sequence_length,
+                "embedding_dim": self.embedding_dim,
+                "embeddings_initializer": initializers.serialize(
+                    self.embeddings_initializer
+                ),
+                "tie_weights": self.token_embedding.tie_weights,
+                "mask_zero": self.token_embedding.mask_zero,
+            }
+        )
+        return config
+
+    def call(self, inputs):
+        embedded_tokens = self.token_embedding(inputs)
+        embedded_positions = self.position_embedding(self.position_ids)
+        outputs = ops.add(embedded_tokens, embedded_positions)
+        return outputs
+
+    def compute_mask(self, inputs, mask=None):
+        return self.token_embedding.compute_mask(inputs, mask=mask)
+
+    def compute_output_shape(self, input_shape):
+        return tuple(input_shape) + (self.embedding_dim,)
+
+
+class SigLIPMLP(layers.Layer):
+    """A SigLIP MLP block.
+
+    Args:
+        hidden_dim: int. The number of units in the output layer.
+        intermediate_dim: int. The number of units in the intermediate layer.
+        activation: str of callable. Activation to use in the intermediate
+            layer.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
+    def __init__(self, hidden_dim, intermediate_dim, activation, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.activation = activation
+
+        if activation == "gelu_approximate":
+            activation = gelu_approximate
+
+        self.fc1 = layers.Dense(
+            self.intermediate_dim,
+            activation=activation,
+            bias_initializer=initializers.RandomNormal(stddev=1e-6),
+            dtype=self.dtype_policy,
+            name="fc1",
+        )
+        self.fc2 = layers.Dense(
+            self.hidden_dim,
+            bias_initializer=initializers.RandomNormal(stddev=1e-6),
+            dtype=self.dtype_policy,
+            name="fc2",
+        )
+
+    def build(self, inputs_shape):
+        self.fc1.build(inputs_shape)
+        inputs_shape = self.fc1.compute_output_shape(inputs_shape)
+        self.fc2.build(inputs_shape)
+
+    def call(self, inputs):
+        hidden_states = self.fc1(inputs)
+        return self.fc2(hidden_states)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "activation": self.activation,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        outputs_shape = list(inputs_shape)
+        outputs_shape[-1] = self.hidden_dim
+        return outputs_shape
+
+
+class SigLIPEncoderLayer(layers.Layer):
+    """A SigLIP encoder layer.
+
+    Args:
+        hidden_dim: int. The number of units in the hidden layers.
+        num_heads: int. Number of attention heads.
+        intermediate_dim: int. The number of units in the intermediate layers.
+        intermediate_activation: str or callable. Activation to use in the
+            hidden layers.
+        layer_norm_epsilon: float. The epsilon for the layer normalization.
+            Defaults to `1e-6`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        num_heads,
+        intermediate_dim,
+        intermediate_activation="gelu_approximate",
+        use_causal_mask=False,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if hidden_dim % num_heads != 0:
+            raise ValueError(
+                "`hidden_dim` must be divisible by `num_heads`. "
+                f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
+            )
+        self.hidden_dim = int(hidden_dim)
+        self.num_heads = int(num_heads)
+        self.intermediate_dim = int(intermediate_dim)
+        self.intermediate_activation = intermediate_activation
+        self.use_causal_mask = bool(use_causal_mask)
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+        self.self_attn = layers.MultiHeadAttention(
+            num_heads,
+            hidden_dim // num_heads,
+            dtype=self.dtype_policy,
+            name="self_attn",
+        )
+        self.layer_norm1 = layers.LayerNormalization(
+            epsilon=layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm1",
+        )
+        self.mlp = SigLIPMLP(
+            hidden_dim,
+            intermediate_dim,
+            intermediate_activation,
+            dtype=self.dtype_policy,
+            name="mlp",
+        )
+        self.layer_norm2 = layers.LayerNormalization(
+            epsilon=layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm2",
+        )
+
+    def build(self, inputs_shape):
+        self.layer_norm1.build(inputs_shape)
+        self.self_attn.build(inputs_shape, inputs_shape, inputs_shape)
+        self.layer_norm2.build(inputs_shape)
+        self.mlp.build(inputs_shape)
+
+    def compute_output_shape(self, inputs_shape):
+        outputs_shape = list(inputs_shape)
+        outputs_shape[-1] = self.hidden_dim
+        return outputs_shape
+
+    def call(self, inputs, training=None):
+        residual = inputs
+        x = self.layer_norm1(inputs)
+        x = self.self_attn(
+            x, x, x, training=training, use_causal_mask=self.use_causal_mask
+        )
+        x = ops.add(residual, x)
+
+        residual = x
+        x = self.layer_norm2(x)
+        x = self.mlp(x, training=training)
+        x = ops.add(residual, x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_heads": self.num_heads,
+                "intermediate_dim": self.intermediate_dim,
+                "intermediate_activation": self.intermediate_activation,
+                "use_causal_mask": self.use_causal_mask,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
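`SigLIPEncoderLayer` is a pre-LayerNorm transformer block: attention and the MLP are each preceded by layer normalization and wrapped in a residual connection, so the block preserves the token shape. A small, illustrative shape check with assumed widths:

```python
import numpy as np

from keras_hub.src.models.siglip.siglip_layers import SigLIPEncoderLayer

# Assumed widths for illustration; the backbone config defines the real ones.
block = SigLIPEncoderLayer(hidden_dim=768, num_heads=12, intermediate_dim=3072)
tokens = np.random.normal(size=(2, 196, 768)).astype("float32")
outputs = block(tokens)
print(outputs.shape)  # (2, 196, 768): same shape in, same shape out
```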
+
+
+class SigLIPMultiHeadAttentionPooling(layers.Layer):
+    """A SigLIP multi-headed attention pooling layer.
+
+    Args:
+        hidden_dim: int. The number of units in the hidden layers.
+        intermediate_dim: int. The number of units in the intermediate layers.
+        num_heads: int. Number of attention heads.
+        activation: str or callable. Activation to use in the MLP.
+        layer_norm_epsilon: float. The epsilon for the layer normalization.
+            Defaults to `1e-6`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        intermediate_dim,
+        num_heads,
+        activation,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.intermediate_dim = int(intermediate_dim)
+        self.num_heads = int(num_heads)
+        self.activation = activation
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+        self.attention = layers.MultiHeadAttention(
+            num_heads,
+            hidden_dim // num_heads,
+            dtype=self.dtype_policy,
+            name="attention",
+        )
+        self.layer_norm = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layernorm",
+        )
+        self.mlp = SigLIPMLP(
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            activation=self.activation,
+            dtype=self.dtype_policy,
+            name="mlp",
+        )
+
+    def build(self, inputs_shape):
+        self.probe = self.add_weight(
+            (1, 1, self.hidden_dim),
+            initializer=initializers.GlorotUniform(),
+            dtype=self.dtype_policy.variable_dtype,
+        )
+        self.attention.build(
+            query_shape=(inputs_shape[0], 1, self.hidden_dim),
+            value_shape=inputs_shape,
+        )
+        inputs_shape = self.attention.compute_output_shape(
+            query_shape=(inputs_shape[0], 1, self.hidden_dim),
+            value_shape=inputs_shape,
+        )
+        self.layer_norm.build(inputs_shape)
+        self.mlp.build(inputs_shape)
+
+    def call(self, inputs, training=None):
+        batch_size = ops.shape(inputs)[0]
+        probes = ops.repeat(self.probe, repeats=batch_size, axis=0)
+        hidden_states = self.attention(
+            probes, inputs, inputs, training=training
+        )
+        residuals = hidden_states
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = ops.add(residuals, self.mlp(hidden_states))
+        return hidden_states[:, 0]
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "activation": self.activation,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        return (inputs_shape[0], self.hidden_dim)
+
+
+class SigLIPHead(layers.Layer):
+    """The head layer of SigLIP.
+
+    `SigLIP` takes `vision_embedding` and `text_embedding` as inputs to
+    compute the corresponding logits. Both embeddings are L2 normalized and used
+    to compute pairwise cosine similarity. The resulting logits are then scaled
+    and added by learnable `logit_scale` and `logit_bias` parameters.
+
+    Call arguments:
+        vision_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+        text_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+    """
+
+    def build(self, input_shape):
+        self.logit_scale = self.add_weight(
+            shape=(),
+            initializer=initializers.Constant(math.log(1.0)),
+            trainable=True,
+            dtype=self.variable_dtype,
+            name="logit_scale",
+        )
+        self.logit_bias = self.add_weight(
+            shape=(),
+            initializer=initializers.Zeros(),
+            trainable=True,
+            dtype=self.variable_dtype,
+            name="logit_bias",
+        )
+
+    def call(self, vision_embedding, text_embedding):
+        normalized_vision_embedding = ops.sqrt(
+            ops.sum(ops.power(vision_embedding, 2), axis=-1, keepdims=True)
+        )
+        normalized_text_embedding = ops.sqrt(
+            ops.sum(ops.power(text_embedding, 2), axis=-1, keepdims=True)
+        )
+        vision_embedding = ops.divide(
+            vision_embedding, normalized_vision_embedding
+        )
+        text_embedding = ops.divide(text_embedding, normalized_text_embedding)
+        text_logits = ops.add(
+            ops.multiply(
+                ops.matmul(text_embedding, ops.transpose(vision_embedding)),
+                ops.exp(self.logit_scale),
+            ),
+            self.logit_bias,
+        )
+        vision_logits = ops.transpose(text_logits)
+        return vision_logits, text_logits
+
+    def compute_output_shape(
+        self, vision_embedding_shape, text_embedding_shape
+    ):
+        vision_logits_shape = (
+            vision_embedding_shape[0],
+            text_embedding_shape[0],
+        )
+        text_logits_shape = (
+            text_embedding_shape[0],
+            vision_embedding_shape[0],
+        )
+        return vision_logits_shape, text_logits_shape
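As its docstring says, `SigLIPHead` L2-normalizes both embeddings, takes pairwise dot products, then scales by `exp(logit_scale)` and shifts by `logit_bias`. A minimal sketch with random embeddings (the batch size and width below are arbitrary assumptions):

```python
import numpy as np

from keras_hub.src.models.siglip.siglip_layers import SigLIPHead

head = SigLIPHead()
# Arbitrary batch of 4 paired embeddings with an assumed width of 8.
vision_embedding = np.random.normal(size=(4, 8)).astype("float32")
text_embedding = np.random.normal(size=(4, 8)).astype("float32")
vision_logits, text_logits = head(vision_embedding, text_embedding)
# text_logits[i, j] is the scaled, shifted cosine similarity between
# text i and image j; vision_logits is its transpose.
print(vision_logits.shape, text_logits.shape)  # (4, 4) (4, 4)
```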
keras_hub/src/models/siglip/siglip_loss.py (new file)
@@ -0,0 +1,35 @@
+import keras
+from keras import ops
+
+
+class SigLIPLoss(keras.losses.Loss):
+    """SigLIP Loss.
+
+    SigLIP loss replaces the loss function used in CLIP by a simple pairwise
+    sigmoid loss. Unlike standard contrastive learning with softmax
+    normalization, the sigmoid loss operates solely on image-text pairs and does
+    not require a global view of the pairwise similarities for normalization.
+    The sigmoid loss simultaneously allows further scaling up the batch size,
+    while also performing better at smaller batch sizes.
+
+    References:
+        - [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343)
+    """
+
+    def call(self, y_true, y_pred):
+        y_pred = ops.convert_to_tensor(y_pred)
+        y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
+
+        # Ref: https://github.com/google-research/big_vision/blob/main/big_vision/trainers/proj/image_text/siglip.py
+        logits = y_pred
+        m1_diag1 = y_true
+
+        # Standard sigmoid computes everything twice, once assuming positive
+        # labels and once assuming negative ones. But here we know exactly where
+        # to find positives (on "me" diagonal) and negatives (everywhere else),
+        # so compute each one's loss only once:
+        loglike = ops.nn.log_sigmoid(m1_diag1 * logits)
+
+        # Normalize by npos per column, but that's one, so just sum.
+        nll = -ops.sum(loglike, axis=-1)
+        return nll
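`SigLIPLoss` expects `y_pred` to be the pairwise logit matrix produced by `SigLIPHead` and `y_true` to be the `m1_diag1` target from the big_vision reference: `+1` on the diagonal (matching image-text pairs) and `-1` everywhere else. A minimal sketch with the target built by hand:

```python
import numpy as np

from keras_hub.src.models.siglip.siglip_loss import SigLIPLoss

batch_size = 4
# +1 for matching pairs (the diagonal), -1 for all mismatched pairs.
y_true = 2.0 * np.eye(batch_size, dtype="float32") - 1.0
# Stand-in for the (batch_size, batch_size) logits produced by SigLIPHead.
y_pred = np.random.normal(size=(batch_size, batch_size)).astype("float32")

loss = SigLIPLoss()
# Per-row negative log-sigmoid sums, averaged by the default loss reduction.
print(loss(y_true, y_pred))
```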