keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/api/layers/__init__.py
CHANGED
@@ -14,6 +14,7 @@ from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_hub.src.layers.modeling.reversible_embedding import (
     ReversibleEmbedding,
 )
+from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.layers.modeling.sine_position_encoding import (
     SinePositionEncoding,
@@ -33,22 +34,39 @@ from keras_hub.src.layers.preprocessing.multi_segment_packer import (
 )
 from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
 from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
-from keras_hub.src.layers.preprocessing.resizing_image_converter import (
-    ResizingImageConverter,
-)
 from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.basnet.basnet_image_converter import (
+    BASNetImageConverter,
+)
+from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
+    DeepLabV3ImageConverter,
+)
 from keras_hub.src.models.densenet.densenet_image_converter import (
     DenseNetImageConverter,
 )
+from keras_hub.src.models.efficientnet.efficientnet_image_converter import (
+    EfficientNetImageConverter,
+)
+from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter
 from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
     PaliGemmaImageConverter,
 )
 from keras_hub.src.models.resnet.resnet_image_converter import (
     ResNetImageConverter,
 )
+from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
+from keras_hub.src.models.retinanet.retinanet_image_converter import (
+    RetinaNetImageConverter,
+)
 from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
 from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
 from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
+from keras_hub.src.models.segformer.segformer_image_converter import (
+    SegFormerImageConverter,
+)
+from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
+from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
 from keras_hub.src.models.whisper.whisper_audio_converter import (
     WhisperAudioConverter,
 )
keras_hub/api/models/__init__.py
CHANGED
@@ -29,6 +29,9 @@ from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import (
     BartSeq2SeqLMPreprocessor,
 )
 from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
+from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter
+from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
+from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM
 from keras_hub.src.models.bert.bert_masked_lm_preprocessor import (
@@ -53,8 +56,11 @@ from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import (
 from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer
 from keras_hub.src.models.causal_lm import CausalLM
 from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
 from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor
+from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
 from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
+from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder
 from keras_hub.src.models.csp_darknet.csp_darknet_backbone import (
     CSPDarkNetBackbone,
 )
@@ -85,6 +91,15 @@ from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor imp
 from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import (
     DebertaV3Tokenizer,
 )
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
+    DeepLabV3ImageSegmenterPreprocessor,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
+    DeepLabV3ImageSegmenter,
+)
 from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
 from keras_hub.src.models.densenet.densenet_image_classifier import (
     DenseNetImageClassifier,
@@ -119,6 +134,12 @@ from keras_hub.src.models.distil_bert.distil_bert_tokenizer import (
 from keras_hub.src.models.efficientnet.efficientnet_backbone import (
     EfficientNetBackbone,
 )
+from keras_hub.src.models.efficientnet.efficientnet_image_classifier import (
+    EfficientNetImageClassifier,
+)
+from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import (
+    EfficientNetImageClassifierPreprocessor,
+)
 from keras_hub.src.models.electra.electra_backbone import ElectraBackbone
 from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer
 from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone
@@ -144,6 +165,11 @@ from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import (
 )
 from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
+from keras_hub.src.models.flux.flux_model import FluxBackbone
+from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage
+from keras_hub.src.models.flux.flux_text_to_image_preprocessor import (
+    FluxTextToImagePreprocessor,
+)
 from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
 from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
@@ -167,22 +193,28 @@ from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.models.image_classifier_preprocessor import (
     ImageClassifierPreprocessor,
 )
+from keras_hub.src.models.image_object_detector import ImageObjectDetector
+from keras_hub.src.models.image_object_detector_preprocessor import (
+    ImageObjectDetectorPreprocessor,
+)
 from keras_hub.src.models.image_segmenter import ImageSegmenter
 from keras_hub.src.models.image_segmenter_preprocessor import (
     ImageSegmenterPreprocessor,
 )
-from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
-from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
-from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
-    Llama3CausalLMPreprocessor,
-)
-from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
+from keras_hub.src.models.image_to_image import ImageToImage
+from keras_hub.src.models.inpaint import Inpaint
 from keras_hub.src.models.llama.llama_backbone import LlamaBackbone
 from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM
 from keras_hub.src.models.llama.llama_causal_lm_preprocessor import (
     LlamaCausalLMPreprocessor,
 )
 from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
+from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
+from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
+    Llama3CausalLMPreprocessor,
+)
+from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
 from keras_hub.src.models.masked_lm import MaskedLM
 from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor
 from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone
@@ -191,11 +223,10 @@ from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import (
     MistralCausalLMPreprocessor,
 )
 from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer
-from keras_hub.src.models.mix_transformer.mix_transformer_backbone import (
-    MiTBackbone,
-)
-from keras_hub.src.models.mix_transformer.mix_transformer_classifier import (
-    MiTImageClassifier,
+from keras_hub.src.models.mit.mit_backbone import MiTBackbone
+from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier
+from keras_hub.src.models.mit.mit_image_classifier_preprocessor import (
+    MiTImageClassifierPreprocessor,
 )
 from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
 from keras_hub.src.models.mobilenet.mobilenet_image_classifier import (
@@ -233,6 +264,13 @@ from keras_hub.src.models.resnet.resnet_image_classifier import (
 from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
     ResNetImageClassifierPreprocessor,
 )
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_object_detector import (
+    RetinaNetObjectDetector,
+)
+from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import (
+    RetinaNetObjectDetectorPreprocessor,
+)
 from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
 from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM
 from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import (
@@ -254,13 +292,26 @@ from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone
 from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
 from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
-    SAMImageSegmenterPreprocessor
+    SAMImageSegmenterPreprocessor,
+)
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+from keras_hub.src.models.segformer.segformer_image_segmenter import (
+    SegFormerImageSegmenter,
+)
+from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (
+    SegFormerImageSegmenterPreprocessor,
 )
 from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
 from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
     StableDiffusion3Backbone,
 )
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import (
+    StableDiffusion3ImageToImage,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import (
+    StableDiffusion3Inpaint,
+)
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import (
     StableDiffusion3TextToImage,
 )
@@ -279,6 +330,14 @@ from keras_hub.src.models.text_classifier_preprocessor import (
 from keras_hub.src.models.text_to_image import TextToImage
 from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
 from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
+from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
+    VGGImageClassifierPreprocessor,
+)
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier
+from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
+    ViTImageClassifierPreprocessor,
+)
 from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
 from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
 from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
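
The export list above is where most of this release's new surface area lands: whole task families (image object detection, image-to-image, inpainting) join the existing classifiers and segmenters. A minimal sketch of driving one of the new task classes follows; the preset id and the structure of the prediction output are illustrative assumptions, not something this diff confirms.

import numpy as np

import keras_hub

# Hypothetical usage of the newly exported detection task API.
# "retinanet_resnet50_fpn_coco" is an assumed preset id.
detector = keras_hub.models.ImageObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"
)
images = np.random.uniform(0, 255, size=(1, 640, 640, 3)).astype("float32")
# Assumed: returns per-image boxes, class labels, and confidence scores.
predictions = detector.predict(images)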
keras_hub/api/tokenizers/__init__.py
CHANGED
@@ -21,8 +21,8 @@ from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
 from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
-from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
 from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
 from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer
 from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer
 from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
keras_hub/src/bounding_box/converters.py
CHANGED
@@ -20,29 +20,74 @@ class RequiresImagesException(Exception):
 ALL_AXES = 4
 
 
-def _encode_box_to_deltas(
+def encode_box_to_deltas(
     anchors,
     boxes,
-    anchor_format: str,
-    box_format: str,
+    anchor_format,
+    box_format,
+    encoding_format="center_yxhw",
     variance=None,
     image_shape=None,
 ):
-    """Converts bounding_boxes from `center_yxhw` to delta format."""
+    """Encodes bounding boxes relative to anchors as deltas.
+
+    This function calculates the deltas that represent the difference between
+    bounding boxes and provided anchors. Deltas encode the offsets and scaling
+    factors to apply to anchors to obtain the target boxes.
+
+    Boxes and anchors are first converted to the specified `encoding_format`
+    (defaulting to `center_yxhw`) for consistent delta representation.
+
+    Args:
+        anchors: `Tensors`. Anchor boxes with shape of `(N, 4)` where N is the
+            number of anchors.
+        boxes: `Tensors` Bounding boxes to encode. Boxes can be of shape
+            `(B, N, 4)` or `(N, 4)`.
+        anchor_format: str. The format of the input `anchors`
+            (e.g., "xyxy", "xywh", etc.).
+        box_format: str. The format of the input `boxes`
+            (e.g., "xyxy", "xywh", etc.).
+        encoding_format: str. The intermediate format to which boxes and anchors
+            are converted before delta calculation. Defaults to "center_yxhw".
+        variance: `List[float]`. A 4-element array/tensor representing variance
+            factors to scale the box deltas. If provided, the calculated deltas
+            are divided by the variance. Defaults to None.
+        image_shape: `Tuple[int]`. The shape of the image (height, width, 3).
+            When using relative bounding box format for `box_format` the
+            `image_shape` is used for normalization.
+    Returns:
+        Encoded box deltas. The return type matches the `encode_format`.
+
+    Raises:
+        ValueError: If `variance` is not None and its length is not 4.
+        ValueError: If `encoding_format` is not `"center_xywh"` or
+            `"center_yxhw"`.
+
+    """
     if variance is not None:
         variance = ops.convert_to_tensor(variance, "float32")
         var_len = variance.shape[-1]
 
         if var_len != 4:
             raise ValueError(f"`variance` must be length 4, got {variance}")
+
+    if encoding_format not in ["center_xywh", "center_yxhw"]:
+        raise ValueError(
+            "`encoding_format` should be one of 'center_xywh' or "
+            f"'center_yxhw', got {encoding_format}"
+        )
+
     encoded_anchors = convert_format(
         anchors,
         source=anchor_format,
-        target="center_yxhw",
+        target=encoding_format,
         image_shape=image_shape,
     )
     boxes = convert_format(
-        boxes, source=box_format, target="center_yxhw", image_shape=image_shape
+        boxes,
+        source=box_format,
+        target=encoding_format,
+        image_shape=image_shape,
     )
     anchor_dimensions = ops.maximum(
         encoded_anchors[..., 2:], keras.backend.epsilon()
@@ -61,15 +106,54 @@ def _encode_box_to_deltas(
     return boxes_delta
 
 
-def _decode_deltas_to_boxes(
+def decode_deltas_to_boxes(
    anchors,
     boxes_delta,
-    anchor_format: str,
-    box_format: str,
+    anchor_format,
+    box_format,
+    encoded_format="center_yxhw",
     variance=None,
     image_shape=None,
 ):
-    """Converts bounding_boxes from delta format to `center_yxhw`."""
+    """Converts bounding boxes from delta format to the specified `box_format`.
+
+    This function decodes bounding box deltas relative to anchors to obtain the
+    final bounding box coordinates. The boxes are encoded in a specific
+    `encoded_format` (center_yxhw by default) during the decoding process.
+    This allows flexibility in how the deltas are applied to the anchors.
+
+    Args:
+        anchors: Can be `Tensors` or `Dict[Tensors]` where keys are level
+            indices and values are corresponding anchor boxes.
+            The shape of the array/tensor should be `(N, 4)` where N is the
+            number of anchors.
+        boxes_delta Can be `Tensors` or `Dict[Tensors]` Bounding box deltas
+            must have the same type and structure as `anchors`. The
+            shape of the array/tensor can be `(N, 4)` or `(B, N, 4)` where N is
+            the number of boxes.
+        anchor_format: str. The format of the input `anchors`.
+            (e.g., `"xyxy"`, `"xywh"`, etc.)
+        box_format: str. The desired format for the output boxes.
+            (e.g., `"xyxy"`, `"xywh"`, etc.)
+        encoded_format: str. Raw output format from regression head. Defaults
+            to `"center_yxhw"`.
+        variance: `List[floats]`. A 4-element array/tensor representing
+            variance factors to scale the box deltas. If provided, the deltas
+            are multiplied by the variance before being applied to the anchors.
+            Defaults to None.
+        image_shape: The shape of the image (height, width). This is needed
+            if normalization to image size is required when converting between
+            formats. Defaults to None.
+
+    Returns:
+        Decoded box coordinates. The return type matches the `box_format`.
+
+    Raises:
+        ValueError: If `variance` is not None and its length is not 4.
+        ValueError: If `encoded_format` is not `"center_xywh"` or
+            `"center_yxhw"`.
+
+    """
     if variance is not None:
         variance = ops.convert_to_tensor(variance, "float32")
         var_len = variance.shape[-1]
@@ -77,11 +161,17 @@ def _decode_deltas_to_boxes(
         if var_len != 4:
             raise ValueError(f"`variance` must be length 4, got {variance}")
 
+    if encoded_format not in ["center_xywh", "center_yxhw"]:
+        raise ValueError(
+            f"`encoded_format` should be 'center_xywh' or 'center_yxhw', "
+            f"but got '{encoded_format}'."
+        )
+
     def decode_single_level(anchor, box_delta):
         encoded_anchor = convert_format(
             anchor,
             source=anchor_format,
-            target="center_yxhw",
+            target=encoded_format,
             image_shape=image_shape,
         )
         if variance is not None:
@@ -97,7 +187,7 @@ def _decode_deltas_to_boxes(
         )
         box = convert_format(
             box,
-            source="center_yxhw",
+            source=encoded_format,
             target=box_format,
             image_shape=image_shape,
         )
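
The renamed `encode_box_to_deltas`/`decode_deltas_to_boxes` pair is symmetric: both sides convert anchors into the same intermediate center format, and the deltas are center offsets normalized by anchor size plus log-scaled size ratios, optionally divided (encode) or multiplied (decode) by `variance`. A minimal round-trip sketch under that reading; note the import path is internal (`keras_hub.src`), so it is not a stable public API:

import numpy as np

from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes
from keras_hub.src.bounding_box.converters import encode_box_to_deltas

anchors = np.array([[10.0, 10.0, 30.0, 30.0]], dtype="float32")  # xyxy
boxes = np.array([[12.0, 11.0, 32.0, 28.0]], dtype="float32")

deltas = encode_box_to_deltas(
    anchors, boxes, anchor_format="xyxy", box_format="xyxy"
)
decoded = decode_deltas_to_boxes(
    anchors, deltas, anchor_format="xyxy", box_format="xyxy"
)
# With identical settings and no variance, `decoded` should be
# numerically close to the original `boxes`.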
keras_hub/src/layers/modeling/masked_lm_head.py
CHANGED
@@ -34,7 +34,8 @@ class MaskedLMHead(keras.layers.Layer):
         token_embedding: Optional. A `keras_hub.layers.ReversibleEmbedding`
             instance. If passed, the layer will be used to project from the
             `hidden_dim` of the model to the output `vocabulary_size`.
-        intermediate_activation: The activation function of intermediate dense layer.
+        intermediate_activation: The activation function of intermediate dense
+            layer.
         activation: The activation function for the outputs of the layer.
             Usually either `None` (return logits), or `"softmax"`
             (return probabilities).
keras_hub/src/layers/modeling/reversible_embedding.py
CHANGED
@@ -1,9 +1,7 @@
 import keras
 from keras import ops
-from packaging.version import parse
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.utils.keras_utils import assert_quantization_support
 
 
 @keras_hub_export("keras_hub.layers.ReversibleEmbedding")
@@ -145,10 +143,6 @@ class ReversibleEmbedding(keras.layers.Embedding):
         if not self.built:
             return
         super().save_own_variables(store)
-        # Before Keras 3.2, the reverse weight is saved in the super() call.
-        # After Keras 3.2, the reverse weight must be saved manually.
-        if parse(keras.version()) < parse("3.2.0"):
-            return
         target_variables = []
         if not self.tie_weights:
             # Store the reverse embedding weights as the last weights.
@@ -239,9 +233,7 @@ class ReversibleEmbedding(keras.layers.Embedding):
 
     def quantize(self, mode, type_check=True):
         import gc
-        import inspect
 
-        assert_quantization_support()
         if type_check and type(self) is not ReversibleEmbedding:
             raise NotImplementedError(
                 f"Layer {self.__class__.__name__} does not have a `quantize()` "
@@ -250,14 +242,9 @@ class ReversibleEmbedding(keras.layers.Embedding):
         self._check_quantize_args(mode, self.compute_dtype)
 
         def abs_max_quantize(inputs, axis):
-            sig = inspect.signature(keras.quantizers.abs_max_quantize)
-            if "to_numpy" in sig.parameters:
-                return keras.quantizers.abs_max_quantize(
-                    inputs, axis=axis, to_numpy=True
-                )
-            else:
-                # `keras<=3.4.1` doesn't support `to_numpy`
-                return keras.quantizers.abs_max_quantize(inputs, axis=axis)
+            return keras.quantizers.abs_max_quantize(
+                inputs, axis=axis, to_numpy=True
+            )
 
         self._tracker.unlock()
         if mode == "int8":
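
With the version shims removed, `quantize()` now passes `to_numpy=True` unconditionally, so the layer assumes a Keras 3 release new enough to support that argument (newer than 3.4.1, per the deleted comment). A minimal sketch of the int8 path, assuming such a Keras version:

import numpy as np

from keras_hub.layers import ReversibleEmbedding

embedding = ReversibleEmbedding(input_dim=100, output_dim=16)
embedding.build((2, 8))
embedding.quantize("int8")  # abs-max int8 quantization of the embeddings

token_ids = np.array([[3, 1, 4]])
hidden = embedding(token_ids)  # forward pass: token ids -> embeddings
logits = embedding(hidden, reverse=True)  # reverse pass: embeddings -> logits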
keras_hub/src/layers/modeling/rms_normalization.py
ADDED
@@ -0,0 +1,36 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+
+
+@keras_hub_export("keras_hub.layers.RMSNormalization")
+class RMSNormalization(keras.layers.Layer):
+    """Root Mean Square (RMS) Normalization layer.
+
+    This layer normalizes the input tensor based on its RMS value and applies
+    a learned scaling factor.
+
+    Args:
+        input_dim: int. The dimensionality of the input tensor.
+    """
+
+    def __init__(self, input_dim):
+        super().__init__()
+        self.scale = self.add_weight(
+            name="scale", shape=(input_dim,), initializer="ones"
+        )
+
+    def call(self, x):
+        """Applies RMS normalization to the input tensor.
+
+        Args:
+            x: Input tensor of shape (batch_size, input_dim).
+
+        Returns:
+            The RMS-normalized tensor of the same shape (batch_size, input_dim),
+            scaled by the learned `scale` parameter.
+        """
+        x = ops.cast(x, float)
+        rrms = ops.rsqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + 1e-6)
+        return (x * rrms) * self.scale
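
The whole layer is new in this release and small enough to smoke-test directly from the code above. Note that `input_dim` is taken eagerly in `__init__` rather than inferred in `build()`, so it must match the input's last axis:

import keras
import numpy as np

from keras_hub.layers import RMSNormalization

layer = RMSNormalization(input_dim=8)
x = np.random.uniform(size=(2, 8)).astype("float32")
y = keras.ops.convert_to_numpy(layer(x))

# Each row is divided by its root mean square, so before the (initially
# all-ones) learned scale has been trained, the per-row RMS of `y` is ~1.
print(np.sqrt(np.mean(np.square(y), axis=-1)))  # ~[1. 1.]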
keras_hub/src/layers/modeling/rotary_embedding.py
CHANGED
@@ -11,7 +11,8 @@ class RotaryEmbedding(keras.layers.Layer):
     This layer encodes absolute positional information with a rotation
     matrix. It calculates the rotary encoding with a mix of sine and
     cosine functions with geometrically increasing wavelengths.
-    Defined and formulated in [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
+    Defined and formulated in
+    [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
     The input must be a tensor with shape a sequence dimension and a feature
     dimension. Typically, this will either an input with shape
     `(batch_size, sequence_length, feature_length)` or
@@ -65,7 +66,7 @@ class RotaryEmbedding(keras.layers.Layer):
         scaling_factor=1.0,
         sequence_axis=1,
         feature_axis=-1,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.max_wavelength = max_wavelength
keras_hub/src/layers/modeling/transformer_decoder.py
CHANGED
@@ -5,12 +5,13 @@ from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.cached_multi_head_attention import (
     CachedMultiHeadAttention,
 )
-from keras_hub.src.utils.keras_utils import clone_initializer
-
-from keras_hub.src.layers.modeling.transformer_layer_utils import (  # isort:skip
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
     compute_causal_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
     merge_padding_and_attention_mask,
 )
+from keras_hub.src.utils.keras_utils import clone_initializer
 
 
 @keras_hub_export("keras_hub.layers.TransformerDecoder")
@@ -265,13 +266,13 @@ class TransformerDecoder(keras.layers.Layer):
             `[batch_size, decoder_sequence_length]`.
         decoder_attention_mask: a boolean Tensor. Customized decoder
             sequence mask, must be of shape
-            `[batch_size, decoder_sequence_length, decoder_sequence_length]
+            `[batch_size, decoder_sequence_length, decoder_sequence_length]`
         encoder_padding_mask: a boolean Tensor, the padding mask of encoder
             sequence, must be of shape
             `[batch_size, encoder_sequence_length]`.
         encoder_attention_mask: a boolean Tensor. Customized encoder
             sequence mask, must be of shape
-            `[batch_size, encoder_sequence_length, encoder_sequence_length]
+            `[batch_size, encoder_sequence_length, encoder_sequence_length]`
         self_attention_cache: a dense float Tensor. The cache of key/values
             pairs in the self-attention layer. Has shape
             `[batch_size, 2, max_seq_len, num_heads, key_dims]`.
@@ -435,7 +436,8 @@ class TransformerDecoder(keras.layers.Layer):
             input_length = output_length = ops.shape(decoder_sequence)[1]
             # We need to handle a rectangular causal mask when doing cached
             # decoding. For generative inference, `decoder_sequence` will
-            # generally be length 1, and `cache` will be the full generation length.
+            # generally be length 1, and `cache` will be the full generation
+            # length.
             if self_attention_cache is not None:
                 input_length = ops.shape(self_attention_cache)[2]
 
keras_hub/src/layers/modeling/transformer_encoder.py
CHANGED
@@ -170,7 +170,12 @@ class TransformerEncoder(keras.layers.Layer):
         self.built = True
 
     def call(
-        self, inputs, padding_mask=None, attention_mask=None, training=None
+        self,
+        inputs,
+        padding_mask=None,
+        attention_mask=None,
+        training=None,
+        return_attention_scores=False,
     ):
         """Forward pass of the TransformerEncoder.
 
@@ -185,6 +190,9 @@ class TransformerEncoder(keras.layers.Layer):
                 [batch_size, sequence_length, sequence_length].
             training: a boolean indicating whether the layer should behave in
                 training mode or in inference mode.
+            return_attention_scores: a boolean indicating whether the output
+                should be `(attention_output, attention_scores)` if `True` or
+                `attention_output` if `False`. Defaults to `False`.
 
         Returns:
             A Tensor of the same shape as the `inputs`.
@@ -200,12 +208,23 @@ class TransformerEncoder(keras.layers.Layer):
         residual = x
         if self.normalize_first:
             x = self._self_attention_layer_norm(x)
-        x = self._self_attention_layer(
-            query=x,
-            value=x,
-            attention_mask=self_attention_mask,
-            training=training,
-        )
+
+        if return_attention_scores:
+            x, attention_scores = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                return_attention_scores=return_attention_scores,
+                training=training,
+            )
+        else:
+            x = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                training=training,
+            )
+
         x = self._self_attention_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
@@ -222,6 +241,9 @@ class TransformerEncoder(keras.layers.Layer):
         if not self.normalize_first:
             x = self._feedforward_layer_norm(x)
 
+        if return_attention_scores:
+            return x, attention_scores
+
         return x
 
     def get_config(self):
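
The new flag threads through to the underlying `keras.layers.MultiHeadAttention` call, so requesting scores changes the return type to an `(outputs, scores)` tuple with one score map per head. A minimal sketch:

import numpy as np

from keras_hub.layers import TransformerEncoder

encoder = TransformerEncoder(intermediate_dim=32, num_heads=2)
x = np.random.uniform(size=(1, 10, 16)).astype("float32")

outputs, scores = encoder(x, return_attention_scores=True)
print(outputs.shape)  # (1, 10, 16), same shape as the inputs
print(scores.shape)   # (1, 2, 10, 10): (batch, heads, query_len, key_len)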
keras_hub/src/layers/preprocessing/audio_converter.py
CHANGED
@@ -2,11 +2,10 @@ from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.preprocessing_layer import (
     PreprocessingLayer,
 )
-from keras_hub.src.utils.preset_utils import AUDIO_CONVERTER_CONFIG_FILE
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_preset_loader
-from keras_hub.src.utils.preset_utils import save_serialized_object
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 
 
@@ -101,8 +100,5 @@ class AudioConverter(PreprocessingLayer):
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        save_serialized_object(
-            self,
-            preset_dir,
-            config_file=AUDIO_CONVERTER_CONFIG_FILE,
-        )
+        saver = get_preset_saver(preset_dir)
+        saver.save_audio_converter(self)
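
From the caller's side the save path is unchanged: `save_to_preset()` still writes a directory that `from_preset()` can reload; only the internals now route through the shared preset saver. A round-trip sketch, where the built-in preset id is an assumption:

import keras_hub

# "whisper_tiny_en" is an assumed Whisper preset id; from_preset resolves
# the matching AudioConverter subclass for whichever preset is given.
converter = keras_hub.layers.AudioConverter.from_preset("whisper_tiny_en")
converter.save_to_preset("./my_audio_preset")
reloaded = keras_hub.layers.AudioConverter.from_preset("./my_audio_preset")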