keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/causal_lm.py

```diff
@@ -274,6 +274,7 @@ class CausalLM(Task):
         inputs,
         max_length=None,
         stop_token_ids="auto",
+        strip_prompt=False,
     ):
         """Generate text given prompt `inputs`.
 
@@ -302,13 +303,18 @@ class CausalLM(Task):
                 `preprocessor`. If `preprocessor` is `None`, `inputs` should be
                 should be padded to the desired maximum length and this argument
                 will be ignored.
-            stop_token_ids: Optional. `None`, "auto", or tuple of token ids.
-                to "auto" which uses the
-                Not specifying a
-
-
-
-
+            stop_token_ids: Optional. `None`, "auto", or tuple of token ids.
+                Defaults to "auto" which uses the
+                `preprocessor.tokenizer.end_token_id`. Not specifying a
+                processor will produce an error. None stops generation after
+                generating `max_length` tokens. You may also specify a list of
+                token id's the model should stop on. Note that sequences of
+                tokens will each be interpreted as a stop token, multi-token
+                stop sequences are not supported.
+            strip_prompt: Optional. By default, generate() returns the full
+                prompt followed by its completion generated by the model. If
+                this option is set to True, only the newly generated text is
+                returned.
         """
         # Setup our three main passes.
         # 1. Optionally preprocessing strings to dense integer tensors.
@@ -318,14 +324,19 @@ class CausalLM(Task):
 
         if self.preprocessor is None and stop_token_ids == "auto":
             raise ValueError(
-
-                "Currently `preprocessor=None`. To
-                "
-                "`
+                "A `preprocessor` must be attached to the model if "
+                '`stop_token_ids="auto"`. Currently `preprocessor=None`. To '
+                "call `generate()` with preprocessing detached, either pass "
+                "`stop_token_ids=None` to always generate until `max_length` "
+                "or pass a tuple of token ids that should terminate generation "
                 "as `stop_token_ids`."
             )
         elif stop_token_ids == "auto":
            stop_token_ids = [self.preprocessor.tokenizer.end_token_id]
+            # Some models like Llama3 use two end tokens: <|eot_id|> in
+            # "instruct" versions and <|end_of_text|> in others.
+            if hasattr(self.preprocessor.tokenizer, "end_token2_id"):
+                stop_token_ids.append(self.preprocessor.tokenizer.end_token2_id)
 
         def preprocess(x):
             return self.preprocessor.generate_preprocess(
@@ -335,6 +346,34 @@ class CausalLM(Task):
         def generate(x):
             return generate_function(x, stop_token_ids=stop_token_ids)
 
+        def strip_prompt_function(x, prompt):
+            # This function removes the prompt from the generated
+            # response, in a batch-friendly fashion.
+            y = {}
+            prompt_mask = prompt["padding_mask"]
+            seq_len = prompt_mask.shape[1]
+
+            # We need to shift every output sequence by the size of the prompt.
+            shifts = -ops.sum(ops.cast(prompt_mask, "int"), axis=1) % seq_len
+            ix = ops.arange(seq_len, dtype="int")
+            ix = ops.expand_dims(ix, axis=0) - ops.expand_dims(shifts, axis=1)
+
+            # This produces the desired shift (in fact a rollover).
+            def roll_sequence(seq):
+                return ops.take_along_axis(seq, ix, axis=1)
+
+            # The shifting rolls the content over so the prompt is at the end of
+            # the sequence and the generated text is at the beginning. We mask
+            # it to retain the generated text only.
+            y["padding_mask"] = ops.logical_xor(
+                roll_sequence(prompt_mask), roll_sequence(x["padding_mask"])
+            )
+            # we assume the mask is enough and there is no need to zero-out the
+            # values
+            y["token_ids"] = roll_sequence(x["token_ids"])
+
+            return y
+
         def postprocess(x):
             return self.preprocessor.generate_postprocess(x)
 
@@ -343,7 +382,12 @@ class CausalLM(Task):
 
         if self.preprocessor is not None:
             inputs = [preprocess(x) for x in inputs]
-
+
+        if strip_prompt:
+            outputs = [strip_prompt_function(generate(x), x) for x in inputs]
+        else:
+            outputs = [generate(x) for x in inputs]
+
         if self.preprocessor is not None:
             outputs = [postprocess(x) for x in outputs]
 
```
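The new `strip_prompt` flag and the expanded `stop_token_ids` documentation above translate to the following usage. A minimal sketch, assuming a standard `CausalLM` preset with an attached preprocessor (the `gpt2_base_en` preset name is used only for illustration):

```python
import keras_hub

causal_lm = keras_hub.models.CausalLM.from_preset("gpt2_base_en")

# Default: the output is the prompt followed by the model's completion.
full_text = causal_lm.generate("The quick brown fox", max_length=32)

# New in this release: return only the newly generated text.
completion = causal_lm.generate(
    "The quick brown fox", max_length=32, strip_prompt=True
)

# stop_token_ids="auto" (the default) stops on the tokenizer's end token(s);
# pass None to always generate up to max_length.
unbounded = causal_lm.generate(
    "The quick brown fox", max_length=32, stop_token_ids=None
)
```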
keras_hub/src/models/clip/clip_backbone.py

```diff
@@ -0,0 +1,286 @@
+import math
+
+from keras import layers
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+
+
+class CLIPVisionPooler(layers.Layer):
+    """The vision pooler layer of CLIP.
+
+    `CLIPVisionPooler` will extracts the first token (index `0`) from the
+    sequence of the vision embeddings as the pooled outputs.
+
+    Call arguments:
+        vision_embeddings: A tensor of shape
+            `(batch_size, sequence_length, hidden_dim)`.
+    """
+
+    def call(self, vision_embeddings):
+        return vision_embeddings[:, 0, :]
+
+    def compute_output_shape(self, input_shape):
+        return (input_shape[0], input_shape[-1])
+
+
+class CLIPTextPooler(layers.Layer):
+    """The text pooler layer of CLIP.
+
+    `CLIPTextPooler` extracts the text embeddings at the positions of EOS tokens
+    as the pooled outputs.
+
+    Call arguments:
+        text_embeddings: A tensor of shape
+            `(batch_size, sequence_length, hidden_dim)`.
+        token_ids: A tensor of shape `(batch_size, max_tokens)`, used to
+            identify the positions of EOS tokens.
+    """
+
+    def call(self, text_embeddings, token_ids):
+        # `keepdims` is not supported in `keras<=3.1`.
+        eos_index = ops.argmax(token_ids, axis=-1)
+        eos_index = ops.expand_dims(eos_index, axis=-1)
+        eos_index = ops.expand_dims(eos_index, axis=-1)
+        pooled_outputs = ops.take_along_axis(text_embeddings, eos_index, axis=1)
+        return ops.squeeze(pooled_outputs, axis=1)
+
+    def compute_output_shape(self, input_shape):
+        return (input_shape[0], input_shape[-1])
+
+
+class CLIPHead(layers.Layer):
+    """The head layer of CLIP.
+
+    `CLIPHead` takes `vision_embedding` and `text_embedding` as inputs to
+    compute the corresponding logits. Both embeddings are L2 normalized and used
+    to compute pairwise cosine similarity. The resulting logits are then scaled
+    by a learnable `logit_scale` parameter.
+
+    Call arguments:
+        vision_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+        text_embedding: A tensor of shape `(batch_size, hidden_dim)`.
+    """
+
+    def build(self, input_shape):
+        self.logit_scale = self.add_weight(
+            shape=(),
+            initializer=lambda *a, **kw: math.log(1 / 0.07),
+            trainable=True,
+            dtype=self.variable_dtype,
+            name="logit_scale",
+        )
+
+    def call(self, vision_embedding, text_embedding):
+        normalized_vision_embedding = ops.sqrt(
+            ops.sum(ops.power(vision_embedding, 2), axis=-1, keepdims=True)
+        )
+        normalized_text_embedding = ops.sqrt(
+            ops.sum(ops.power(text_embedding, 2), axis=-1, keepdims=True)
+        )
+        vision_embedding = vision_embedding / normalized_vision_embedding
+        text_embedding = text_embedding / normalized_text_embedding
+        logit_scale = ops.exp(self.logit_scale)
+        text_logits = (
+            ops.matmul(
+                text_embedding,
+                ops.transpose(vision_embedding),
+            )
+            * logit_scale
+        )
+        vision_logits = ops.transpose(text_logits)
+        return vision_logits, text_logits
+
+    def compute_output_shape(
+        self, vision_embedding_shape, text_embedding_shape
+    ):
+        vision_logits_shape = (
+            vision_embedding_shape[0],
+            text_embedding_shape[0],
+        )
+        text_logits_shape = (
+            text_embedding_shape[0],
+            vision_embedding_shape[0],
+        )
+        return vision_logits_shape, text_logits_shape
+
+
+@keras_hub_export("keras_hub.models.CLIPBackbone")
+class CLIPBackbone(Backbone):
+    """CLIP core network with hyperparameters.
+
+    This backbone implements the base architecture for Contrastive
+    Language-Image Pretraining (CLIP) model. It includes a vision and text
+    encoders and the corresponding projection layers. This backbone will output
+    the final logit scores corresponding to each image and token input. These
+    values are cosine similarities between the corresponding image and text
+    features.
+
+    The default constructor gives a fully customizable, randomly initialized
+    CLIP model with any number of layers, heads, and embedding dimensions. To
+    load preset architectures and weights, use the `from_preset` constructor.
+
+    Args:
+        vision_encoder: The CLIP vision encoder for encoding the input images.
+        text_encoder: The CLIP text encoder for encoding the input tokens.
+        projection_dim: int. The size of the projection layer.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for the models computations and weights. Note that some
+            computations, such as softmax and layer normalization will always
+            be done a float32 precision regardless of dtype.
+
+    Example:
+    ```python
+    input_data = {
+        "images": np.ones(shape=(1, 224, 224, 3), dtype="float32"),
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+    }
+
+    # Pretrained CLIP model.
+    model = keras_hub.models.CLIPBackbone.from_preset("clip_vit_base_patch32")
+    model(input_data)
+
+    # Randomly initialized CLIP model with custom config.
+    vision_encoder = keras_hub.models.CLIPVisionEncoder(
+        patch_size=32,
+        hidden_dim=768,
+        num_layers=8,
+        num_heads=8,
+        intermediate_dim=2048,
+        image_shape=(384, 384, 3),
+    )
+    text_encoder = keras_hub.models.CLIPTextEncoder(
+        vocabulary_size=49408,
+        embedding_dim=768,
+        hidden_dim=768,
+        num_layers=8,
+        num_heads=8,
+        intermediate_dim=2048,
+    )
+    model = keras_hub.models.CLIPBackbone(
+        vision_encoder=vision_encoder,
+        text_encoder=text_encoder,
+        projection_dim=256,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vision_encoder,
+        text_encoder,
+        projection_dim,
+        dtype=None,
+        name=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.vision_encoder = vision_encoder
+        self.text_encoder = text_encoder
+        self.vision_pooler = CLIPVisionPooler(dtype=dtype, name="vision_pooler")
+        self.text_pooler = CLIPTextPooler(dtype=dtype, name="text_pooler")
+        self.vision_projection = layers.Dense(
+            projection_dim,
+            use_bias=False,
+            dtype=dtype,
+            name="vision_projection",
+        )
+        self.text_projection = layers.Dense(
+            projection_dim,
+            use_bias=False,
+            dtype=dtype,
+            name="text_projection",
+        )
+        self.clip_head = CLIPHead(dtype=dtype, name="clip_head")
+
+        # === Functional Model ===
+        image_input = layers.Input(
+            shape=self.vision_encoder.image_shape, name="images"
+        )
+        token_id_input = layers.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        vision_embeddings = self.get_vision_embeddings(image_input)
+        text_embeddings = self.get_text_embeddings(token_id_input)
+        vision_logits, text_logits = self.clip_head(
+            vision_embeddings, text_embeddings
+        )
+
+        super().__init__(
+            inputs={
+                "images": image_input,
+                "token_ids": token_id_input,
+            },
+            outputs={
+                "vision_logits": vision_logits,
+                "text_logits": text_logits,
+            },
+            dtype=dtype,
+            name=name,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.projection_dim = projection_dim
+
+    def get_vision_embeddings(self, images):
+        """Get the embeddings from the vision encoder.
+
+        Args:
+            images: The input tensor for the vision encoder.
+
+        Returns:
+            The output embeddings obtained by applying projection layer to the
+            pooled output of the vision encoder.
+        """
+        vision_outputs = self.vision_encoder({"images": images})
+        vision_outputs = self.vision_pooler(vision_outputs)
+        return self.vision_projection(vision_outputs)
+
+    def get_text_embeddings(self, token_ids):
+        """Get the embeddings from the text encoder.
+
+        Args:
+            token_ids: The input int tensor for the text encoder.
+
+        Returns:
+            The output embeddings obtained by applying projection layer to the
+            pooled output of the text encoder.
+        """
+        text_outputs = self.text_encoder({"token_ids": token_ids})
+        text_outputs = self.text_pooler(text_outputs, token_ids)
+        return self.text_projection(text_outputs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vision_encoder": layers.serialize(self.vision_encoder),
+                "text_encoder": layers.serialize(self.text_encoder),
+                "projection_dim": self.projection_dim,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        config = config.copy()
+
+        # Propagate `dtype` to submodels if needed.
+        if "dtype" in config and config["dtype"] is not None:
+            dtype_config = config["dtype"]
+            if "dtype" not in config["vision_encoder"]["config"]:
+                config["vision_encoder"]["config"]["dtype"] = dtype_config
+            if "dtype" not in config["text_encoder"]["config"]:
+                config["text_encoder"]["config"]["dtype"] = dtype_config
+
+        # We expect submodels to be instantiated.
+        config["vision_encoder"] = layers.deserialize(
+            config["vision_encoder"], custom_objects=custom_objects
+        )
+        config["text_encoder"] = layers.deserialize(
+            config["text_encoder"], custom_objects=custom_objects
+        )
+        return cls(**config)
```
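In addition to the docstring example embedded above, the new `get_vision_embeddings()` and `get_text_embeddings()` methods expose the projected features directly, while calling the backbone returns the similarity logits. A sketch reusing the dummy inputs and preset name from that docstring:

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.CLIPBackbone.from_preset("clip_vit_base_patch32")

images = np.ones(shape=(1, 224, 224, 3), dtype="float32")
token_ids = np.ones(shape=(1, 12), dtype="int32")

# Projected embeddings, before the CLIPHead cosine-similarity scaling.
image_features = backbone.get_vision_embeddings(images)  # (1, projection_dim)
text_features = backbone.get_text_embeddings(token_ids)  # (1, projection_dim)

# Calling the backbone returns the scaled similarity logits instead.
logits = backbone({"images": images, "token_ids": token_ids})
vision_logits = logits["vision_logits"]  # (num_images, num_texts)
text_logits = logits["text_logits"]      # (num_texts, num_images)
```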
keras_hub/src/models/clip/clip_encoder_block.py

```diff
@@ -7,6 +7,16 @@ def quick_gelu(x):
     return x * ops.sigmoid(1.702 * x)
 
 
+# TODO: Deprecate this in favor of `keras.layers.MultiHeadAttention` once the
+# dtype compatibility issue is resolved.
+class CLIPMultiHeadAttention(layers.MultiHeadAttention):
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        attention_scores = super()._masked_softmax(
+            attention_scores, attention_mask
+        )
+        return ops.cast(attention_scores, self._value_dense.compute_dtype)
+
+
 class CLIPEncoderBlock(layers.Layer):
     def __init__(
         self,
@@ -14,6 +24,7 @@ class CLIPEncoderBlock(layers.Layer):
         num_heads,
         intermediate_dim,
         intermediate_activation="quick_gelu",
+        use_causal_mask=True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -26,21 +37,22 @@ class CLIPEncoderBlock(layers.Layer):
         self.num_heads = num_heads
         self.intermediate_dim = intermediate_dim
         self.intermediate_activation = intermediate_activation
+        self.use_causal_mask = use_causal_mask
 
         if intermediate_activation == "quick_gelu":
             intermediate_activation = quick_gelu
 
         self.layer_norm_1 = layers.LayerNormalization(
-            epsilon=1e-5, dtype=
+            epsilon=1e-5, dtype=self.dtype_policy, name="layer_norm_1"
         )
-        self.attention =
+        self.attention = CLIPMultiHeadAttention(
             num_heads,
             hidden_dim // num_heads,
             dtype=self.dtype_policy,
             name="attention",
         )
         self.layer_norm_2 = layers.LayerNormalization(
-            epsilon=1e-5, dtype=
+            epsilon=1e-5, dtype=self.dtype_policy, name="layer_norm_2"
         )
         self.dense_1 = layers.Dense(
             self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
@@ -73,7 +85,9 @@ class CLIPEncoderBlock(layers.Layer):
     def call(self, x, training=None):
         residual = x
         x = self.layer_norm_1(x)
-        x = self.attention(
+        x = self.attention(
+            x, x, x, training=training, use_causal_mask=self.use_causal_mask
+        )
         x = ops.add(residual, x)
 
         residual = x
@@ -91,6 +105,7 @@ class CLIPEncoderBlock(layers.Layer):
                 "num_heads": self.num_heads,
                 "intermediate_dim": self.intermediate_dim,
                 "intermediate_activation": self.intermediate_activation,
+                "use_causal_mask": self.use_causal_mask,
             }
         )
         return config
```
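The new `use_causal_mask` flag (default `True`) lets the same block serve both the text tower (causal self-attention) and the vision tower (bidirectional attention over patches). A sketch, assuming the block's first constructor argument is `hidden_dim` (it is not shown in the hunks above):

```python
import numpy as np
from keras import ops

from keras_hub.src.models.clip.clip_encoder_block import CLIPEncoderBlock

# Text-style block: causal self-attention (the default).
text_block = CLIPEncoderBlock(hidden_dim=256, num_heads=4, intermediate_dim=512)

# Vision-style block: attends over all positions, no causal mask.
vision_block = CLIPEncoderBlock(
    hidden_dim=256, num_heads=4, intermediate_dim=512, use_causal_mask=False
)

x = ops.convert_to_tensor(np.random.rand(2, 10, 256).astype("float32"))
print(text_block(x).shape, vision_block(x).shape)  # both (2, 10, 256)
```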
keras_hub/src/models/clip/clip_image_converter.py

```diff
@@ -0,0 +1,8 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
+
+
+@keras_hub_export("keras_hub.layers.CLIPImageConverter")
+class CLIPImageConverter(ImageConverter):
+    backbone_cls = CLIPBackbone
```
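`CLIPImageConverter` is the standard KerasHub `ImageConverter` wired to `CLIPBackbone`, so resizing and rescaling settings can be pulled from a preset. A hedged sketch; it assumes the converter configuration is resolvable from the same preset name as the backbone presets listed below:

```python
import numpy as np
import keras_hub

# Assumption: the converter config ships with the backbone preset of the same name.
converter = keras_hub.layers.CLIPImageConverter.from_preset("clip_vit_base_patch32")

raw_images = np.random.uniform(0, 255, size=(2, 512, 512, 3)).astype("float32")
model_ready = converter(raw_images)  # resized and rescaled for the CLIP backbone
```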
keras_hub/src/models/clip/clip_presets.py

```diff
@@ -0,0 +1,93 @@
+"""CLIP model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "clip_vit_base_patch16": {
+        "metadata": {
+            "description": (
+                "150 million parameter, 12-layer for vision and 12-layer for "
+                "text, patch size of 16, CLIP model."
+            ),
+            "params": 149620934,
+            "path": "clip",
+        },
+        "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_base_patch16/2",
+    },
+    "clip_vit_base_patch32": {
+        "metadata": {
+            "description": (
+                "151 million parameter, 12-layer for vision and 12-layer for "
+                "text, patch size of 32, CLIP model."
+            ),
+            "params": 151277363,
+            "path": "clip",
+        },
+        "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_base_patch32/2",
+    },
+    "clip_vit_large_patch14": {
+        "metadata": {
+            "description": (
+                "428 million parameter, 24-layer for vision and 12-layer for "
+                "text, patch size of 14, CLIP model."
+            ),
+            "params": 427616770,
+            "path": "clip",
+        },
+        "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_large_patch14/2",
+    },
+    "clip_vit_large_patch14_336": {
+        "metadata": {
+            "description": (
+                "428 million parameter, 24-layer for vision and 12-layer for "
+                "text, patch size of 14, image size of 336, CLIP model."
+            ),
+            "params": 427944770,
+            "path": "clip",
+        },
+        "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_large_patch14_336/2",
+    },
+    "clip_vit_b_32_laion2b_s34b_b79k": {
+        "metadata": {
+            "description": (
+                "151 million parameter, 12-layer for vision and 12-layer for "
+                "text, patch size of 32, Open CLIP model."
+            ),
+            "params": 151277363,
+            "path": "clip",
+        },
+        "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_b_32_laion2b_s34b_b79k/2",
+    },
+    "clip_vit_h_14_laion2b_s32b_b79k": {
+        "metadata": {
+            "description": (
+                "986 million parameter, 32-layer for vision and 24-layer for "
+                "text, patch size of 14, Open CLIP model."
+            ),
+            "params": 986109698,
+            "path": "clip",
+        },
+        "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_h_14_laion2b_s32b_b79k/2",
+    },
+    "clip_vit_g_14_laion2b_s12b_b42k": {
+        "metadata": {
+            "description": (
+                "1.4 billion parameter, 40-layer for vision and 24-layer for "
+                "text, patch size of 14, Open CLIP model."
+            ),
+            "params": 1366678530,
+            "path": "clip",
+        },
+        "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_g_14_laion2b_s12b_b42k/2",
+    },
+    "clip_vit_bigg_14_laion2b_39b_b160k": {
+        "metadata": {
+            "description": (
+                "2.5 billion parameter, 48-layer for vision and 32-layer for "
+                "text, patch size of 14, Open CLIP model."
+            ),
+            "params": 2539567362,
+            "path": "clip",
+        },
+        "kaggle_handle": "kaggle://keras/clip/keras/clip_vit_bigg_14_laion2b_39b_b160k/2",
+    },
+}
```
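Each key in `backbone_presets` is a name that `from_preset()` resolves to the Kaggle handle shown. A sketch of browsing and loading them, assuming the registry is exposed through the standard KerasHub class-level `presets` mapping:

```python
import keras_hub

# Assumption: built-in presets are reachable via the class-level `presets` dict.
for name, preset in keras_hub.models.CLIPBackbone.presets.items():
    meta = preset["metadata"]
    print(f"{name}: {meta['params']:,} params - {meta['description']}")

# Loading by name downloads weights from the corresponding Kaggle handle above.
backbone = keras_hub.models.CLIPBackbone.from_preset("clip_vit_base_patch16")
```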
keras_hub/src/models/clip/clip_text_encoder.py

```diff
@@ -1,5 +1,6 @@
 from keras import layers
 
+from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.token_and_position_embedding import (
     TokenAndPositionEmbedding,
 )
@@ -7,6 +8,7 @@ from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.clip.clip_encoder_block import CLIPEncoderBlock
 
 
+@keras_hub_export("keras_hub.models.CLIPTextEncoder")
 class CLIPTextEncoder(Backbone):
     """CLIP text core network with hyperparameters.
 
@@ -80,7 +82,7 @@ class CLIPTextEncoder(Backbone):
             for i in range(num_layers)
         ]
         self.layer_norm = layers.LayerNormalization(
-            epsilon=1e-6, dtype=
+            epsilon=1e-6, dtype=dtype, name=f"{prefix}layer_norm"
         )
 
         # === Functional Model ===
@@ -106,6 +108,7 @@ class CLIPTextEncoder(Backbone):
         super().__init__(
             inputs={"token_ids": token_id_input},
             outputs=outputs,
+            dtype=dtype,
             name=name,
             **kwargs,
         )
```
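With the `keras_hub_export` decorator above, the text tower is now reachable as `keras_hub.models.CLIPTextEncoder`, and the added `dtype=dtype` forwarding means a policy passed to the constructor reaches the sublayers. A small randomly initialized sketch mirroring the constructor arguments from the `CLIPBackbone` docstring earlier in this diff (the reduced sizes and the `"bfloat16"` policy are illustrative):

```python
import numpy as np
import keras_hub

text_encoder = keras_hub.models.CLIPTextEncoder(
    vocabulary_size=49408,
    embedding_dim=256,
    hidden_dim=256,
    num_layers=2,
    num_heads=4,
    intermediate_dim=512,
    dtype="bfloat16",  # forwarded to the embedding, encoder blocks, and layer norm
)
token_ids = np.ones(shape=(1, 12), dtype="int32")
sequence_output = text_encoder({"token_ids": token_ids})
```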
keras_hub/src/models/clip/clip_tokenizer.py

```diff
@@ -1,4 +1,5 @@
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
 from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
 from keras_hub.src.tokenizers.byte_pair_tokenizer import convert_to_ragged_batch
 from keras_hub.src.tokenizers.byte_pair_tokenizer import split_strings_for_bpe
@@ -39,11 +40,25 @@ class CLIPTokenizer(BytePairTokenizer):
             should have one merge rule per line. Every merge rule contains
             merge entities separated by a space.
         pad_with_end_token: bool. Whether to pad the output with `end_token`.
-    """
 
-
+    Examples:
+
+    ```python
+    # Unbatched input.
+    tokenizer = keras_hub.models.CLIPTokenizer.from_preset(
+        "clip_vit_base_patch32"
+    )
+    tokenizer("The quick brown fox jumped.")
+
+    # Batched input.
+    tokenizer(["The quick brown fox jumped.", "The fox slept."])
+
+    # Detokenization.
+    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+    ```
+    """
 
-    backbone_cls =
+    backbone_cls = CLIPBackbone
 
     def __init__(
         self,
```