keras-hub 0.21.1__py3-none-any.whl → 0.22.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +9 -0
- keras_hub/models/__init__.py +47 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +6 -3
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +17 -3
- keras_hub/src/layers/preprocessing/start_end_packer.py +24 -6
- keras_hub/src/models/backbone.py +13 -10
- keras_hub/src/models/clip/clip_backbone.py +3 -102
- keras_hub/src/models/clip/clip_layers.py +295 -0
- keras_hub/src/models/clip/clip_preprocessor.py +57 -48
- keras_hub/src/models/clip/clip_text_encoder.py +2 -2
- keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
- keras_hub/src/models/deit/__init__.py +5 -0
- keras_hub/src/models/deit/deit_backbone.py +154 -0
- keras_hub/src/models/deit/deit_image_classifier.py +171 -0
- keras_hub/src/models/deit/deit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/deit/deit_image_converter.py +8 -0
- keras_hub/src/models/deit/deit_layers.py +519 -0
- keras_hub/src/models/deit/deit_presets.py +49 -0
- keras_hub/src/models/dinov2/__init__.py +5 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
- keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
- keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
- keras_hub/src/models/dinov2/dinov2_presets.py +89 -0
- keras_hub/src/models/esm/__init__.py +5 -0
- keras_hub/src/models/esm/esm_attention.py +95 -0
- keras_hub/src/models/esm/esm_backbone.py +229 -0
- keras_hub/src/models/esm/esm_classifier.py +184 -0
- keras_hub/src/models/esm/esm_classifier_preprocessor.py +135 -0
- keras_hub/src/models/esm/esm_encoder.py +134 -0
- keras_hub/src/models/esm/esm_masked_plm.py +117 -0
- keras_hub/src/models/esm/esm_masked_plm_preprocessor.py +143 -0
- keras_hub/src/models/esm/esm_presets.py +53 -0
- keras_hub/src/models/esm/esm_tokenizer.py +82 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
- keras_hub/src/models/gemma/gemma_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +2 -2
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +1 -1
- keras_hub/src/models/hgnetv2/__init__.py +5 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +193 -0
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +148 -0
- keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +216 -0
- keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py +14 -0
- keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py +8 -0
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +918 -0
- keras_hub/src/models/hgnetv2/hgnetv2_presets.py +58 -0
- keras_hub/src/models/llama3/llama3_presets.py +3 -3
- keras_hub/src/models/mistral/mistral_presets.py +17 -1
- keras_hub/src/models/mixtral/mixtral_presets.py +2 -2
- keras_hub/src/models/mobilenet/mobilenet_presets.py +4 -4
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +2 -2
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +2 -2
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +17 -17
- keras_hub/src/models/qwen3/__init__.py +5 -0
- keras_hub/src/models/qwen3/qwen3_attention.py +369 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +191 -0
- keras_hub/src/models/qwen3/qwen3_causal_lm.py +390 -0
- keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor.py +10 -0
- keras_hub/src/models/qwen3/qwen3_decoder.py +309 -0
- keras_hub/src/models/qwen3/qwen3_layernorm.py +38 -0
- keras_hub/src/models/qwen3/qwen3_presets.py +73 -0
- keras_hub/src/models/qwen3/qwen3_tokenizer.py +48 -0
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +1 -0
- keras_hub/src/models/qwen_moe/qwen_moe_presets.py +2 -2
- keras_hub/src/models/roformer_v2/roformer_v2_attention.py +0 -2
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
- keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +31 -32
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
- keras_hub/src/models/vit/vit_backbone.py +31 -11
- keras_hub/src/models/vit/vit_image_converter.py +0 -70
- keras_hub/src/models/vit/vit_layers.py +33 -18
- keras_hub/src/models/vit/vit_presets.py +11 -11
- keras_hub/src/utils/keras_utils.py +17 -0
- keras_hub/src/utils/preset_utils.py +19 -4
- keras_hub/src/utils/tensor_utils.py +14 -0
- keras_hub/src/utils/transformers/convert_deit.py +155 -0
- keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
- keras_hub/src/utils/transformers/convert_esm.py +159 -0
- keras_hub/src/utils/transformers/convert_llama3.py +6 -0
- keras_hub/src/utils/transformers/convert_qwen3.py +145 -0
- keras_hub/src/utils/transformers/export/gemma.py +89 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
- keras_hub/src/utils/transformers/preset_loader.py +14 -2
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +1 -0
- {keras_hub-0.21.1.dist-info → keras_hub-0.22.0.dev0.dist-info}/METADATA +4 -4
- {keras_hub-0.21.1.dist-info → keras_hub-0.22.0.dev0.dist-info}/RECORD +92 -48
- keras_hub/src/models/clip/clip_encoder_block.py +0 -111
- keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
- {keras_hub-0.21.1.dist-info → keras_hub-0.22.0.dev0.dist-info}/WHEEL +0 -0
- {keras_hub-0.21.1.dist-info → keras_hub-0.22.0.dev0.dist-info}/top_level.txt +0 -0
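
This release adds five new model families to the public API (DeiT, DINOV2, ESM, HGNetV2, and Qwen3), a Hugging Face exporter under `keras_hub/src/utils/transformers/export/`, and refactors of the CLIP and ViT layer stacks. As a quick orientation, here is a minimal sketch of the new Qwen3 exports, assuming the usual KerasHub `from_preset` workflow; the preset name below is illustrative only and is not taken from this diff (the registered names live in `qwen3_presets.py`).

```python
import keras_hub

# Illustrative preset name only; the actual identifiers are defined in
# keras_hub/src/models/qwen3/qwen3_presets.py.
causal_lm = keras_hub.models.Qwen3CausalLM.from_preset("qwen3_0.6b_en")
causal_lm.generate("What is Keras?", max_length=64)
```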
keras_hub/layers/__init__.py
CHANGED
```diff
@@ -78,15 +78,24 @@ from keras_hub.src.models.cspnet.cspnet_image_converter import (
 from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
     DeepLabV3ImageConverter as DeepLabV3ImageConverter,
 )
+from keras_hub.src.models.deit.deit_image_converter import (
+    DeiTImageConverter as DeiTImageConverter,
+)
 from keras_hub.src.models.densenet.densenet_image_converter import (
     DenseNetImageConverter as DenseNetImageConverter,
 )
+from keras_hub.src.models.dinov2.dinov2_image_converter import (
+    DINOV2ImageConverter as DINOV2ImageConverter,
+)
 from keras_hub.src.models.efficientnet.efficientnet_image_converter import (
     EfficientNetImageConverter as EfficientNetImageConverter,
 )
 from keras_hub.src.models.gemma3.gemma3_image_converter import (
     Gemma3ImageConverter as Gemma3ImageConverter,
 )
+from keras_hub.src.models.hgnetv2.hgnetv2_image_converter import (
+    HGNetV2ImageConverter as HGNetV2ImageConverter,
+)
 from keras_hub.src.models.mit.mit_image_converter import (
     MiTImageConverter as MiTImageConverter,
 )
```
keras_hub/models/__init__.py
CHANGED
```diff
@@ -141,6 +141,13 @@ from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor impo
 from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
     DeepLabV3ImageSegmenter as DeepLabV3ImageSegmenter,
 )
+from keras_hub.src.models.deit.deit_backbone import DeiTBackbone as DeiTBackbone
+from keras_hub.src.models.deit.deit_image_classifier import (
+    DeiTImageClassifier as DeiTImageClassifier,
+)
+from keras_hub.src.models.deit.deit_image_classifier_preprocessor import (
+    DeiTImageClassifierPreprocessor as DeiTImageClassifierPreprocessor,
+)
 from keras_hub.src.models.densenet.densenet_backbone import (
     DenseNetBackbone as DenseNetBackbone,
 )
@@ -150,6 +157,9 @@ from keras_hub.src.models.densenet.densenet_image_classifier import (
 from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import (
     DenseNetImageClassifierPreprocessor as DenseNetImageClassifierPreprocessor,
 )
+from keras_hub.src.models.dinov2.dinov2_backbone import (
+    DINOV2Backbone as DINOV2Backbone,
+)
 from keras_hub.src.models.distil_bert.distil_bert_backbone import (
     DistilBertBackbone as DistilBertBackbone,
 )
@@ -189,6 +199,22 @@ from keras_hub.src.models.electra.electra_backbone import (
 from keras_hub.src.models.electra.electra_tokenizer import (
     ElectraTokenizer as ElectraTokenizer,
 )
+from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone
+from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESMBackbone
+from keras_hub.src.models.esm.esm_classifier import (
+    ESMProteinClassifier as ESMProteinClassifier,
+)
+from keras_hub.src.models.esm.esm_classifier_preprocessor import (
+    ESMProteinClassifierPreprocessor as ESMProteinClassifierPreprocessor,
+)
+from keras_hub.src.models.esm.esm_masked_plm import (
+    ESMMaskedPLM as ESM2MaskedPLM,
+)
+from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESMMaskedPLM
+from keras_hub.src.models.esm.esm_masked_plm_preprocessor import (
+    ESMMaskedPLMPreprocessor as ESMMaskedPLMPreprocessor,
+)
+from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer
 from keras_hub.src.models.f_net.f_net_backbone import (
     FNetBackbone as FNetBackbone,
 )
@@ -287,6 +313,15 @@ from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import (
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
     GPTNeoXTokenizer as GPTNeoXTokenizer,
 )
+from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
+    HGNetV2Backbone as HGNetV2Backbone,
+)
+from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier import (
+    HGNetV2ImageClassifier as HGNetV2ImageClassifier,
+)
+from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier_preprocessor import (
+    HGNetV2ImageClassifierPreprocessor as HGNetV2ImageClassifierPreprocessor,
+)
 from keras_hub.src.models.image_classifier import (
     ImageClassifier as ImageClassifier,
 )
@@ -444,6 +479,18 @@ from keras_hub.src.models.qwen.qwen_tokenizer import (
 from keras_hub.src.models.qwen.qwen_tokenizer import (
     QwenTokenizer as QwenTokenizer,
 )
+from keras_hub.src.models.qwen3.qwen3_backbone import (
+    Qwen3Backbone as Qwen3Backbone,
+)
+from keras_hub.src.models.qwen3.qwen3_causal_lm import (
+    Qwen3CausalLM as Qwen3CausalLM,
+)
+from keras_hub.src.models.qwen3.qwen3_causal_lm_preprocessor import (
+    Qwen3CausalLMPreprocessor as Qwen3CausalLMPreprocessor,
+)
+from keras_hub.src.models.qwen3.qwen3_tokenizer import (
+    Qwen3Tokenizer as Qwen3Tokenizer,
+)
 from keras_hub.src.models.qwen_moe.qwen_moe_backbone import (
     QwenMoeBackbone as QwenMoeBackbone,
 )
```
keras_hub/src/layers/modeling/transformer_encoder.py
CHANGED
```diff
@@ -16,9 +16,12 @@ class TransformerEncoder(keras.layers.Layer):
     paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
     can instantiate multiple instances of this class to stack up an encoder.
 
-    This layer will correctly compute an attention mask from an implicit
-    Keras padding mask (for example, by passing `mask_zero=True` to a
-    `keras.layers.Embedding` layer). See the Masking and Padding
+    This layer will compute an attention mask, prioritizing explicitly provided
+    masks (a `padding_mask` or a custom `attention_mask`) over an implicit Keras
+    padding mask (for example, by passing `mask_zero=True` to a
+    `keras.layers.Embedding` layer). If both a `padding_mask` and a
+    `attention_mask` are provided, they will be combined to determine the final
+    mask. See the Masking and Padding
     [guide](https://keras.io/guides/understanding_masking_and_padding/)
     for more details.
```
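
The revised docstring above describes the mask-priority rules. A minimal sketch of passing an explicit mask, assuming the public `keras_hub.layers.TransformerEncoder` export; shapes and values are illustrative:

```python
import numpy as np
import keras_hub

encoder = keras_hub.layers.TransformerEncoder(intermediate_dim=64, num_heads=2)
x = np.random.uniform(size=(2, 10, 16)).astype("float32")
# Mark the last 4 positions of each sequence as padding. An explicit
# `padding_mask` like this takes priority over any implicit Keras mask,
# and is combined with a custom `attention_mask` if both are given.
padding_mask = np.array([[True] * 6 + [False] * 4] * 2)
outputs = encoder(x, padding_mask=padding_mask)  # shape (2, 10, 16)
```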
keras_hub/src/layers/preprocessing/multi_segment_packer.py
CHANGED
```diff
@@ -3,6 +3,7 @@ from keras_hub.src.layers.preprocessing.preprocessing_layer import (
     PreprocessingLayer,
 )
 from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import pad
 from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 try:
@@ -66,6 +67,8 @@ class MultiSegmentPacker(PreprocessingLayer):
             "waterfall" algorithm that allocates quota in a
             left-to-right manner and fills up the buckets until we run
             out of budget. It support arbitrary number of segments.
+        padding_side: str. Whether to pad the input on the "left" or "right".
+            Defaults to "right".
 
     Returns:
         A tuple with two elements. The first is the dense, packed token
@@ -124,6 +127,7 @@ class MultiSegmentPacker(PreprocessingLayer):
         sep_value=None,
         pad_value=None,
         truncate="round_robin",
+        padding_side="right",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -162,6 +166,7 @@ class MultiSegmentPacker(PreprocessingLayer):
         self.end_value = end_value
 
         self.pad_value = pad_value
+        self.padding_side = padding_side
 
     def get_config(self):
         config = super().get_config()
@@ -173,6 +178,7 @@ class MultiSegmentPacker(PreprocessingLayer):
                 "sep_value": self._sep_value,
                 "pad_value": self.pad_value,
                 "truncate": self.truncate,
+                "padding_side": self.padding_side,
             }
         )
         return config
@@ -287,10 +293,18 @@ class MultiSegmentPacker(PreprocessingLayer):
         # Pad to dense tensor output.
         sequence_length = sequence_length or self.sequence_length
         shape = tf.cast([-1, sequence_length], "int64")
-        token_ids = token_ids.to_tensor(
-            shape=shape, default_value=self.pad_value
-        )
-        segment_ids = segment_ids.to_tensor(shape=shape)
+        token_ids = pad(
+            token_ids,
+            shape=shape,
+            padding_side=self.padding_side,
+            pad_value=self.pad_value,
+        )
+        segment_ids = pad(
+            segment_ids,
+            shape=shape,
+            padding_side=self.padding_side,
+            pad_value=0,
+        )
         # Remove the batch dim if added.
         if unbatched:
             token_ids = tf.squeeze(token_ids, 0)
```
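
A minimal sketch of the new `padding_side` option, assuming the public `keras_hub.layers.MultiSegmentPacker` export; the BERT-style special token ids are illustrative:

```python
import keras_hub

packer = keras_hub.layers.MultiSegmentPacker(
    sequence_length=8,
    start_value=101,  # e.g. a [CLS] token id
    end_value=102,    # e.g. a [SEP] token id
    pad_value=0,
    padding_side="left",
)
# The segments pack to [101, 1, 2, 102, 3, 4, 102]; with "left" padding
# the pad value now goes in front rather than at the end:
token_ids, segment_ids = packer(([1, 2], [3, 4]))
# token_ids   -> [0, 101, 1, 2, 102, 3, 4, 102]
# segment_ids -> [0,   0, 0, 0,   0, 1, 1,   1]
```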
keras_hub/src/layers/preprocessing/start_end_packer.py
CHANGED
```diff
@@ -3,6 +3,7 @@ from keras_hub.src.layers.preprocessing.preprocessing_layer import (
     PreprocessingLayer,
 )
 from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import pad
 from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 try:
@@ -39,6 +40,8 @@ class StartEndPacker(PreprocessingLayer):
             0 or "" will be added depending on the dtype of the input tensor.
         return_padding_mask: bool. Whether to return a boolean padding mask of
             all locations that are filled in with the `pad_value`.
+        padding_side: str. Whether to pad the input on the "left" or "right".
+            Defaults to "right".
 
     Call arguments:
         inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings.
@@ -111,6 +114,7 @@ class StartEndPacker(PreprocessingLayer):
         pad_value=None,
         return_padding_mask=False,
         name=None,
+        padding_side="right",
         **kwargs,
     ):
         super().__init__(name=name, **kwargs)
@@ -139,6 +143,7 @@ class StartEndPacker(PreprocessingLayer):
 
         self.pad_value = pad_value
         self.return_padding_mask = return_padding_mask
+        self.padding_side = padding_side
 
     @preprocessing_function
     def call(
@@ -154,6 +159,13 @@ class StartEndPacker(PreprocessingLayer):
         batch_size = tf.shape(x)[0]
         sequence_length = sequence_length or self.sequence_length
         dtype = inputs.dtype
+        # Truncate.
+        truncation_length = sequence_length
+        if add_start_value and self.start_value is not None:
+            truncation_length -= len(self.start_value)
+        if add_end_value and self.end_value is not None:
+            truncation_length -= len(self.end_value)
+        x = x[..., :truncation_length]
 
         # Concatenate start and end tokens.
         if add_start_value and self.start_value is not None:
@@ -167,23 +179,28 @@ class StartEndPacker(PreprocessingLayer):
             end_token_id_tensor = tf.repeat(
                 end_value[tf.newaxis, :], repeats=batch_size, axis=0
             )
-            # Trim to leave room for end token.
-            x = x[..., : sequence_length - len(self.end_value)]
             x = tf.concat([x, end_token_id_tensor], axis=-1)
 
         # Pad to desired length.
-        outputs = x.to_tensor(
-            default_value=self.pad_value,
+        outputs = pad(
+            x,
+            pad_value=self.pad_value,
+            padding_side=self.padding_side,
             shape=(batch_size, sequence_length),
         )
         outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
 
         if self.return_padding_mask:
             mask = tf.ones_like(x, dtype="bool")
-            mask = mask.to_tensor(shape=(batch_size, sequence_length))
+
+            mask = pad(
+                mask,
+                pad_value=False,
+                padding_side=self.padding_side,
+                shape=(batch_size, sequence_length),
+            )
             mask = tf.squeeze(mask, axis=0) if unbatched else mask
             return outputs, mask
-
         return outputs
 
     def get_config(self):
@@ -195,6 +212,7 @@ class StartEndPacker(PreprocessingLayer):
                 "end_value": self._end_value,
                 "pad_value": self.pad_value,
                 "return_padding_mask": self.return_padding_mask,
+                "padding_side": self.padding_side,
             }
         )
         return config
```
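
Note that trimming now happens up front, leaving room for both the start and end values before concatenation, rather than only for the end token. A minimal sketch of `padding_side`, assuming the public `keras_hub.layers.StartEndPacker` export:

```python
import keras_hub

packer = keras_hub.layers.StartEndPacker(
    sequence_length=6,
    start_value=1,  # e.g. a [BOS] token id
    end_value=2,    # e.g. an [EOS] token id
    pad_value=0,
    padding_side="left",
)
# The input packs to [1, 5, 6, 7, 2] and is then left-padded to length 6:
packer([5, 6, 7])  # -> [0, 1, 5, 6, 7, 2]
```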
keras_hub/src/models/backbone.py
CHANGED
```diff
@@ -189,23 +189,26 @@ class Backbone(keras.Model):
         saver = get_preset_saver(preset_dir)
         saver.save_backbone(self, max_shard_size=max_shard_size)
 
-    def get_lora_target_names(self):
-        """Returns list of layer names which are to be LoRA-fied.
-
-        Subclasses can override this method if the names of layers to be
-        LoRa-fied are different.
-        """
+    def default_lora_layer_names(self):
+        """Returns list of layer names which are to be LoRA-fied."""
         return ["query_dense", "value_dense", "query", "value"]
 
-    def enable_lora(self, rank, target_names=None):
+    def enable_lora(self, rank, target_layer_names=None):
         """Enable Lora on the backbone.
 
         Calling this method will freeze all weights on the backbone,
         while enabling Lora on the query & value `EinsumDense` layers
         of the attention layers.
+
+        Args:
+            rank: The rank of the LoRA factorization.
+            target_layer_names: A list of strings, the names of the layers to
+                apply LoRA to. If `None`, this will be populated with the
+                default LoRA layer names as returned by
+                `backbone.default_lora_layer_names()`.
         """
-        if target_names is None:
-            target_names = self.get_lora_target_names()
+        if target_layer_names is None:
+            target_layer_names = self.default_lora_layer_names()
         self.trainable = True
         self._lora_enabled_layers = []
         self._lora_rank = rank
@@ -214,7 +217,7 @@ class Backbone(keras.Model):
         all_layers = self._flatten_layers(include_self=False)
         all_layers = [lyr for lyr in all_layers if lyr.weights]
         for i, layer in enumerate(all_layers):
-            for name in target_names:
+            for name in target_layer_names:
                 if layer.name == name:
                     if hasattr(layer, "enable_lora"):
                         layer.trainable = True
```
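
A minimal usage sketch of the renamed argument, assuming a locally available Gemma preset ("gemma_2b_en" is one of the published KerasHub presets):

```python
import keras_hub

backbone = keras_hub.models.GemmaBackbone.from_preset("gemma_2b_en")
# With no argument, the target layers come from
# `backbone.default_lora_layer_names()`, i.e.
# ["query_dense", "value_dense", "query", "value"].
backbone.enable_lora(rank=4)
# Alternatively, restrict LoRA to an explicit subset of layer names:
# backbone.enable_lora(rank=4, target_layer_names=["query", "value"])
```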
keras_hub/src/models/clip/clip_backbone.py
CHANGED
```diff
@@ -1,109 +1,10 @@
-import math
-
 from keras import layers
-from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.backbone import Backbone
-
-
-class CLIPVisionPooler(layers.Layer):
-    """The vision pooler layer of CLIP.
-
-    `CLIPVisionPooler` will extracts the first token (index `0`) from the
-    sequence of the vision embeddings as the pooled outputs.
-
-    Call arguments:
-        vision_embeddings: A tensor of shape
-            `(batch_size, sequence_length, hidden_dim)`.
-    """
-
-    def call(self, vision_embeddings):
-        return vision_embeddings[:, 0, :]
-
-    def compute_output_shape(self, input_shape):
-        return (input_shape[0], input_shape[-1])
-
-
-class CLIPTextPooler(layers.Layer):
-    """The text pooler layer of CLIP.
-
-    `CLIPTextPooler` extracts the text embeddings at the positions of EOS tokens
-    as the pooled outputs.
-
-    Call arguments:
-        text_embeddings: A tensor of shape
-            `(batch_size, sequence_length, hidden_dim)`.
-        token_ids: A tensor of shape `(batch_size, max_tokens)`, used to
-            identify the positions of EOS tokens.
-    """
-
-    def call(self, text_embeddings, token_ids):
-        # `keepdims` is not supported in `keras<=3.1`.
-        eos_index = ops.argmax(token_ids, axis=-1)
-        eos_index = ops.expand_dims(eos_index, axis=-1)
-        eos_index = ops.expand_dims(eos_index, axis=-1)
-        pooled_outputs = ops.take_along_axis(text_embeddings, eos_index, axis=1)
-        return ops.squeeze(pooled_outputs, axis=1)
-
-    def compute_output_shape(self, input_shape):
-        return (input_shape[0], input_shape[-1])
-
-
-class CLIPHead(layers.Layer):
-    """The head layer of CLIP.
-
-    `CLIPHead` takes `vision_embedding` and `text_embedding` as inputs to
-    compute the corresponding logits. Both embeddings are L2 normalized and used
-    to compute pairwise cosine similarity. The resulting logits are then scaled
-    by a learnable `logit_scale` parameter.
-
-    Call arguments:
-        vision_embedding: A tensor of shape `(batch_size, hidden_dim)`.
-        text_embedding: A tensor of shape `(batch_size, hidden_dim)`.
-    """
-
-    def build(self, input_shape):
-        self.logit_scale = self.add_weight(
-            shape=(),
-            initializer=lambda *a, **kw: math.log(1 / 0.07),
-            trainable=True,
-            dtype=self.variable_dtype,
-            name="logit_scale",
-        )
-
-    def call(self, vision_embedding, text_embedding):
-        normalized_vision_embedding = ops.sqrt(
-            ops.sum(ops.power(vision_embedding, 2), axis=-1, keepdims=True)
-        )
-        normalized_text_embedding = ops.sqrt(
-            ops.sum(ops.power(text_embedding, 2), axis=-1, keepdims=True)
-        )
-        vision_embedding = vision_embedding / normalized_vision_embedding
-        text_embedding = text_embedding / normalized_text_embedding
-        logit_scale = ops.exp(self.logit_scale)
-        text_logits = (
-            ops.matmul(
-                text_embedding,
-                ops.transpose(vision_embedding),
-            )
-            * logit_scale
-        )
-        vision_logits = ops.transpose(text_logits)
-        return vision_logits, text_logits
-
-    def compute_output_shape(
-        self, vision_embedding_shape, text_embedding_shape
-    ):
-        vision_logits_shape = (
-            vision_embedding_shape[0],
-            text_embedding_shape[0],
-        )
-        text_logits_shape = (
-            text_embedding_shape[0],
-            vision_embedding_shape[0],
-        )
-        return vision_logits_shape, text_logits_shape
+from keras_hub.src.models.clip.clip_layers import CLIPHead
+from keras_hub.src.models.clip.clip_layers import CLIPTextPooler
+from keras_hub.src.models.clip.clip_layers import CLIPVisionPooler
 
 
 @keras_hub_export("keras_hub.models.CLIPBackbone")
```
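
The three layers removed above (`CLIPVisionPooler`, `CLIPTextPooler`, `CLIPHead`) move wholesale into the new `clip_layers.py`. For reference, a standalone NumPy sketch of the contrastive logit computation `CLIPHead` performs: L2-normalize both embeddings, then scale the pairwise cosine similarities by `exp(logit_scale)`:

```python
import numpy as np

def clip_head_logits(vision_embedding, text_embedding, logit_scale):
    # L2-normalize each embedding along the feature axis.
    v = vision_embedding / np.linalg.norm(vision_embedding, axis=-1, keepdims=True)
    t = text_embedding / np.linalg.norm(text_embedding, axis=-1, keepdims=True)
    # Pairwise cosine similarity, scaled by the learned temperature.
    text_logits = (t @ v.T) * np.exp(logit_scale)
    vision_logits = text_logits.T
    return vision_logits, text_logits

# `logit_scale` is initialized to log(1 / 0.07), as in the layer above.
v = np.random.randn(4, 8)
t = np.random.randn(4, 8)
vision_logits, text_logits = clip_head_logits(v, t, np.log(1 / 0.07))
```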