keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,126 @@
+"""ViT model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "vit_base_patch16_224_imagenet": {
+        "metadata": {
+            "description": (
+                "ViT-B16 model pre-trained on the ImageNet 1k dataset with "
+                "image resolution of 224x224 "
+            ),
+            "params": 85798656,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/2",
+    },
+    "vit_base_patch16_384_imagenet": {
+        "metadata": {
+            "description": (
+                "ViT-B16 model pre-trained on the ImageNet 1k dataset with "
+                "image resolution of 384x384 "
+            ),
+            "params": 86090496,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/2",
+    },
+    "vit_large_patch16_224_imagenet": {
+        "metadata": {
+            "description": (
+                "ViT-L16 model pre-trained on the ImageNet 1k dataset with "
+                "image resolution of 224x224 "
+            ),
+            "params": 303301632,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/2",
+    },
+    "vit_large_patch16_384_imagenet": {
+        "metadata": {
+            "description": (
+                "ViT-L16 model pre-trained on the ImageNet 1k dataset with "
+                "image resolution of 384x384 "
+            ),
+            "params": 303690752,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/2",
+    },
+    "vit_base_patch32_384_imagenet": {
+        "metadata": {
+            "description": (
+                "ViT-B32 model pre-trained on the ImageNet 1k dataset with "
+                "image resolution of 384x384 "
+            ),
+            "params": 87528192,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_384_imagenet/1",
+    },
+    "vit_large_patch32_384_imagenet": {
+        "metadata": {
+            "description": (
+                "ViT-L32 model pre-trained on the ImageNet 1k dataset with "
+                "image resolution of 384x384 "
+            ),
+            "params": 305607680,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_384_imagenet/1",
+    },
+    "vit_base_patch16_224_imagenet21k": {
+        "metadata": {
+            "description": (
+                "ViT-B16 backbone pre-trained on the ImageNet 21k dataset with "
+                "image resolution of 224x224 "
+            ),
+            "params": 85798656,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet21k/1",
+    },
+    "vit_base_patch32_224_imagenet21k": {
+        "metadata": {
+            "description": (
+                "ViT-B32 backbone pre-trained on the ImageNet 21k dataset with "
+                "image resolution of 224x224 "
+            ),
+            "params": 87455232,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_224_imagenet21k/1",
+    },
+    "vit_huge_patch14_224_imagenet21k": {
+        "metadata": {
+            "description": (
+                "ViT-H14 backbone pre-trained on the ImageNet 21k dataset with "
+                "image resolution of 224x224 "
+            ),
+            "params": 630764800,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_huge_patch14_224_imagenet21k/1",
+    },
+    "vit_large_patch16_224_imagenet21k": {
+        "metadata": {
+            "description": (
+                "ViT-L16 backbone pre-trained on the ImageNet 21k dataset with "
+                "image resolution of 224x224 "
+            ),
+            "params": 303301632,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet21k/1",
+    },
+    "vit_large_patch32_224_imagenet21k": {
+        "metadata": {
+            "description": (
+                "ViT-L32 backbone pre-trained on the ImageNet 21k dataset with "
+                "image resolution of 224x224 "
+            ),
+            "params": 305510400,
+            "path": "vit",
+        },
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_224_imagenet21k/1",
+    },
+}
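As a rough usage sketch (not part of the diff; it assumes the standard KerasHub preset API and that the new ViT backbone is exported as `keras_hub.models.ViTBackbone`), a preset registered above would be loaded by name:

import keras_hub

# Hypothetical example: load the ViT-B16 ImageNet preset defined above.
backbone = keras_hub.models.ViTBackbone.from_preset("vit_base_patch16_224_imagenet")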
@@ -31,7 +31,7 @@ class ViTDetBackbone(Backbone):
         global_attention_layer_indices (list): Indexes for blocks using
             global attention.
         image_shape (tuple[int], optional): The size of the input image in
-            `(H, W, C)` format. Defaults to `(
+            `(H, W, C)` format. Defaults to `(None, None, 3)`.
         patch_size (int, optional): the patch size to be supplied to the
             Patching layer to turn input images into a flattened sequence of
             patches. Defaults to `16`.
@@ -79,7 +79,7 @@ class ViTDetBackbone(Backbone):
         intermediate_dim,
         num_heads,
         global_attention_layer_indices,
-        image_shape=(
+        image_shape=(None, None, 3),
         patch_size=16,
         num_output_channels=256,
         use_bias=True,
@@ -87,7 +87,7 @@ class ViTDetBackbone(Backbone):
         use_rel_pos=True,
         window_size=14,
         layer_norm_epsilon=1e-6,
-        **kwargs
+        **kwargs,
     ):
         # === Functional model ===
         img_input = keras.layers.Input(shape=image_shape, name="images")
@@ -179,7 +179,9 @@ class ViTDetBackbone(Backbone):
                 "use_abs_pos": self.use_abs_pos,
                 "use_rel_pos": self.use_rel_pos,
                 "window_size": self.window_size,
-                "global_attention_layer_indices":
+                "global_attention_layer_indices": (
+                    self.global_attention_layer_indices
+                ),
                 "layer_norm_epsilon": self.layer_norm_epsilon,
             }
         )
@@ -117,7 +117,7 @@ class AddRelativePositionalEmbedding(keras.layers.Layer):
         """Calculate decomposed Relative Positional Embeddings

         The code has been adapted based on
-        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

         Args:
             attention_map (tensor): Attention map.
@@ -193,7 +193,7 @@ class MultiHeadAttentionWithRelativePE(keras.layers.Layer):
         use_bias=True,
         use_rel_pos=False,
         input_size=None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.num_heads = num_heads
@@ -378,7 +378,7 @@ class WindowedTransformerEncoder(keras.layers.Layer):
         input_size=None,
         activation="gelu",
         layer_norm_epsilon=1e-6,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.project_dim = project_dim
@@ -39,7 +39,7 @@ class WhisperAudioConverter(AudioConverter):
     audio_tensor = tf.ones((8000,), dtype="float32")

     # Compute the log-mel spectrogram.
-    audio_converter = keras_hub.
+    audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
         "whisper_base_en",
     )
     audio_converter(audio_tensor)
@@ -172,9 +172,7 @@ class WhisperAudioConverter(AudioConverter):
         )

         def tf_log10(x):
-            """
-            Computes log base 10 of input tensor using TensorFlow's natural log operator.
-            """
+            """Computes log base 10 of input tensor using TensorFlow."""
             numerator = tf.math.log(x)
             denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
             return numerator / denominator
@@ -30,9 +30,10 @@ class WhisperBackbone(Backbone):
     It includes the embedding lookups and transformer layers, but not the head
     for predicting the next token.

-    The default constructor gives a fully customizable, randomly initialized
-    model with any number of layers, heads, and embedding dimensions.
-    preset architectures and weights, use the `from_preset()`
+    The default constructor gives a fully customizable, randomly initialized
+    Whisper model with any number of layers, heads, and embedding dimensions.
+    To load preset architectures and weights, use the `from_preset()`
+    constructor.

     Disclaimer: Pre-trained models are provided on an "as is" basis, without
     warranties or conditions of any kind. The underlying model is provided by a
@@ -53,8 +54,8 @@ class WhisperBackbone(Backbone):
         max_encoder_sequence_length: int. The maximum sequence length that the
             audio encoder can consume. Since the second convolutional layer in
             the encoder reduces the sequence length by half (stride of 2), we
-            use `max_encoder_sequence_length // 2` as the sequence length for
-            positional embedding layer.
+            use `max_encoder_sequence_length // 2` as the sequence length for
+            the positional embedding layer.
         max_decoder_sequence_length: int. The maximum sequence length that the
             text decoder can consume.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
@@ -14,11 +14,9 @@ class WhisperDecoder(TransformerDecoder):
     """Whisper decoder.

     Inherits from `keras_hub.layers.TransformerDecoder`, and overrides the
-    `build` method to use the
-    `
-
-    `keras_hub.models.whisper.whisper_cached_multi_head_attention.WhisperCachedMultiHeadAttention`
-    instead of `keras_hub.layers.cached_multi_head_attention.CachedMultiHeadAttention`.
+    `build` method to use the `WhisperMultiHeadAttention`
+    layer instead of `MultiHeadAttention` and `WhisperCachedMultiHeadAttention`
+    instead of `CachedMultiHeadAttention`.
     """

     def build(
@@ -7,11 +7,9 @@ backbone_presets = {
                 "English speech data."
             ),
             "params": 37184256,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
-        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/4",
     },
     "whisper_base_en": {
         "metadata": {
@@ -20,11 +18,9 @@ backbone_presets = {
                 "English speech data."
             ),
             "params": 124439808,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
-        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/4",
     },
     "whisper_small_en": {
         "metadata": {
@@ -33,11 +29,9 @@ backbone_presets = {
                 "English speech data."
             ),
             "params": 241734144,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
-        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/4",
     },
     "whisper_medium_en": {
         "metadata": {
@@ -46,11 +40,9 @@ backbone_presets = {
                 "English speech data."
             ),
             "params": 763856896,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
-        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/4",
     },
     "whisper_tiny_multi": {
         "metadata": {
@@ -59,11 +51,9 @@ backbone_presets = {
                 "multilingual speech data."
             ),
             "params": 37760640,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
-        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/4",
     },
     "whisper_base_multi": {
         "metadata": {
@@ -72,11 +62,9 @@ backbone_presets = {
                 "multilingual speech data."
             ),
             "params": 72593920,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
-        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/4",
     },
     "whisper_small_multi": {
         "metadata": {
@@ -85,11 +73,9 @@ backbone_presets = {
                 "multilingual speech data."
             ),
             "params": 241734912,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
-        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/4",
     },
     "whisper_medium_multi": {
         "metadata": {
@@ -98,11 +84,9 @@ backbone_presets = {
                 "multilingual speech data."
             ),
             "params": 763857920,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
-        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/4",
     },
     "whisper_large_multi": {
         "metadata": {
@@ -111,11 +95,9 @@ backbone_presets = {
                 "multilingual speech data."
             ),
             "params": 1543304960,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
-        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/4",
     },
     "whisper_large_multi_v2": {
         "metadata": {
@@ -125,10 +107,8 @@ backbone_presets = {
                 "of `whisper_large_multi`."
             ),
             "params": 1543304960,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
-        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/4",
     },
 }
@@ -9,7 +9,7 @@ from keras_hub.src.models.roberta.roberta_backbone import (
 from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
     XLMRobertaBackbone,
 )
-from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import (
+from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import (  # noqa: E501
     XLMRobertaMaskedLMPreprocessor,
 )

@@ -20,8 +20,8 @@ class XLMRobertaMaskedLMPreprocessor(MaskedLMPreprocessor):

     This preprocessing layer will prepare inputs for a masked language modeling
     task. It is primarily intended for use with the
-    `keras_hub.models.XLMRobertaMaskedLM` task model. Preprocessing will occur
-    multiple steps.
+    `keras_hub.models.XLMRobertaMaskedLM` task model. Preprocessing will occur
+    in multiple steps.

     1. Tokenize any number of input segments using the `tokenizer`.
     2. Pack the inputs together with the appropriate `"<s>"`, `"</s>"` and
@@ -8,11 +8,9 @@ backbone_presets = {
                 "Trained on CommonCrawl in 100 languages."
             ),
             "params": 277450752,
-            "official_name": "XLM-RoBERTa",
             "path": "xlm_roberta",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/xlmr/README.md",
         },
-        "kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_base_multi/
+        "kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_base_multi/3",
     },
     "xlm_roberta_large_multi": {
         "metadata": {
@@ -21,10 +19,8 @@ backbone_presets = {
                 "Trained on CommonCrawl in 100 languages."
             ),
             "params": 558837760,
-            "official_name": "XLM-RoBERTa",
             "path": "xlm_roberta",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/xlmr/README.md",
         },
-        "kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_large_multi/
+        "kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_large_multi/3",
     },
 }
@@ -8,7 +8,7 @@ from keras_hub.src.models.text_classifier import TextClassifier
 from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
     XLMRobertaBackbone,
 )
-from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import (
+from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import (  # noqa: E501
     XLMRobertaTextClassifierPreprocessor,
 )

@@ -40,9 +40,9 @@ class XLMRobertaTextClassifier(TextClassifier):
     Args:
         backbone: A `keras_hub.models.XLMRobertaBackbone` instance.
         num_classes: int. Number of classes to predict.
-        preprocessor: A `keras_hub.models.XLMRobertaTextClassifierPreprocessor`
-            `None`, this model will not apply preprocessing, and
-            be preprocessed before calling the model.
+        preprocessor: A `keras_hub.models.XLMRobertaTextClassifierPreprocessor`
+            or `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
         activation: Optional `str` or callable. The activation function to use
             on the model outputs. Set `activation="softmax"` to return output
             probabilities. Defaults to `None`.
@@ -177,7 +177,8 @@ class XLMRobertaTokenizer(SentencePieceTokenizer):
         # Shift the tokens IDs left by one.
         tokens = tf.subtract(tokens, 1)

-        # Correct `unk_token_id`, `end_token_id`, `start_token_id`,
+        # Correct `unk_token_id`, `end_token_id`, `start_token_id`,
+        # respectively.
         # Note: The `pad_token_id` is taken as 0 (`unk_token_id`) since the
         # proto does not contain `pad_token_id`. This mapping of the pad token
         # is done automatically by the above subtraction.
@@ -64,27 +64,28 @@ def _rel_shift(x, klen=-1):
 class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
     """Two-stream relative self-attention for XLNet.

-    In XLNet, each token has two associated vectors at each self-attention
-    the content stream (h) and the query stream (g). The content stream
-    self-attention stream as in Transformer XL and represents the context
-    content (the token itself). The query stream only has access to
-    information and the position, but not the content.
+    In XLNet, each token has two associated vectors at each self-attention
+    layer, the content stream (h) and the query stream (g). The content stream
+    is the self-attention stream as in Transformer XL and represents the context
+    and content (the token itself). The query stream only has access to
+    contextual information and the position, but not the content.

-    This layer shares the same build signature as
-    but has different input/output
+    This layer shares the same build signature as
+    `keras.layers.MultiHeadAttention` but has different input/output
+    projections.

     We use the notations `B`, `T`, `S`, `M`, `L`, `E`, `P`, `dim`, `num_heads`
-    below, where
-    `B` is the batch dimension, `T` is the target sequence length,
+    below, where `B` is the batch dimension, `T` is the target sequence length,
     `S` in the source sequence length, `M` is the length of the state or memory,
     `L` is the length of relative positional encoding, `E` is the last dimension
-    of query input, `P` is the number of predictions, `dim` is the
-    of the encoder layers. and `num_heads` is the number of
+    of query input, `P` is the number of predictions, `dim` is the
+    dimensionality of the encoder layers. and `num_heads` is the number of
+    attention heads.

     Args:
         content_stream: `Tensor` of shape `[B, T, dim]`.
-        content_attention_bias: Bias `Tensor` for content based attention of
-            `[num_heads, dim]`.
+        content_attention_bias: Bias `Tensor` for content based attention of
+            shape `[num_heads, dim]`.
         positional_attention_bias: Bias `Tensor` for position based attention of
             shape `[num_heads, dim]`.
         query_stream: `Tensor` of shape `[B, P, dim]`.
@@ -96,8 +97,8 @@ class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
         segment_encoding: Optional `Tensor` representing the segmentation
             encoding as used in XLNet of shape `[2, num_heads, dim]`.
         segment_attention_bias: Optional trainable bias parameter added to the
-            query had when calculating the segment-based attention score used
-
+            query had when calculating the segment-based attention score used in
+            XLNet of shape `[num_heads, dim]`.
         state: Optional `Tensor` of shape `[B, M, E]`.
             If passed, this is also attended over as in Transformer XL.
         content_attention_mask: a boolean mask of shape `[B, T, S]` that
@@ -336,11 +337,11 @@ class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
            dimension of query input.

        Args:
-            content_stream: The content representation, commonly referred to as
-                This serves a similar role to the standard hidden states in
+            content_stream: The content representation, commonly referred to as
+                h. This serves a similar role to the standard hidden states in
                 Transformer-XL.
-            content_attention_bias: A trainable bias parameter added to the
-                head when calculating the content-based attention score.
+            content_attention_bias: A trainable bias parameter added to the
+                query head when calculating the content-based attention score.
             positional_attention_bias: A trainable bias parameter added to the
                 query head when calculating the position-based attention score.
             query_stream: The query representation, commonly referred to as g.
@@ -49,8 +49,8 @@ class XLNetBackbone(Backbone):
             `[batch_size, sequence_length]`.
         segment_ids: Segment token indices to indicate first and second portions
             of the inputs of shape `[batch_size, sequence_length]`.
-        padding_mask: Mask to avoid performing attention on padding token
-            of shape `[batch_size, sequence_length]`.
+        padding_mask: Mask to avoid performing attention on padding token
+            indices of shape `[batch_size, sequence_length]`.

     Example:
     ```python
@@ -3,8 +3,7 @@ from keras import ops


 class ContentAndQueryEmbedding(keras.layers.Layer):
-    """
-    Content and Query Embedding.
+    """Content and Query Embedding.

     This class creates Content and Query Embeddings for XLNet model
     which is later used in XLNet Encoder.
@@ -20,9 +19,8 @@ class ContentAndQueryEmbedding(keras.layers.Layer):
         **kwargs: other keyword arguments.

     References:
-     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding]
-
-    """
+     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237)
+    """  # noqa: E501

     def __init__(
         self, vocabulary_size, hidden_dim, dropout, name=None, **kwargs
@@ -11,17 +11,16 @@ def xlnet_kernel_initializer(stddev=0.02):


 class XLNetEncoder(keras.layers.Layer):
-    """
-    XLNet Encoder.
+    """XLNet Encoder.

     This class follows the architecture of the transformer encoder layer in the
     paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
     can instantiate multiple instances of this class to stack up an encoder.

     Contrary to the single hidden state used in the paper mentioned above, this
-    Encoder uses two hidden states, Content State and Query State. Thus
-    Two Stream Relative Attention using both of the hidden states.
-    please check the reference.
+    Encoder uses two hidden states, Content State and Query State. Thus
+    calculates Two Stream Relative Attention using both of the hidden states.
+    To know more please check the reference.

     Args:
         num_heads: int, the number of heads in the
@@ -44,9 +43,8 @@ class XLNetEncoder(keras.layers.Layer):
         **kwargs: other keyword arguments.

     References:
-     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding]
-
-    """
+     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237)
+    """  # noqa: E501

     def __init__(
         self,
@@ -60,7 +58,7 @@ class XLNetEncoder(keras.layers.Layer):
         kernel_initializer_range=0.02,
         bias_initializer="zeros",
         name=None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(name=name, **kwargs)
         self.num_heads = num_heads
@@ -150,9 +150,8 @@ class ContrastiveSampler(Sampler):
             # The final score of each candidate token is weighted sum of
             # probability and similarity against previous tokens.
             accumulated_scores = (
-
-
-            )
+                1 - self.alpha
+            ) * next_token_probabilities - self.alpha * max_similarity_scores
             # Unflatten variables to shape [batch_size, self.k, ...] for
             # gather purpose.
             unflat_score = unflatten_beams(accumulated_scores)
@@ -95,7 +95,8 @@ class Sampler:
         def cond(prompt, cache, index):
             if stop_token_ids is None:
                 return True
-            # Stop if all sequences have produced a *new* id from
+            # Stop if all sequences have produced a *new* id from
+            # stop_token_ids.
             end_tokens = any_equal(prompt, stop_token_ids, ~mask)
             prompt_done = ops.any(end_tokens, axis=-1)
             return ops.logical_not(ops.all(prompt_done))