keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/vit/vit_image_classifier.py
@@ -0,0 +1,187 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.models.task import Task
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
+    ViTImageClassifierPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.ViTImageClassifier")
+class ViTImageClassifier(ImageClassifier):
+    """ViT image classification task.
+
+    `ViTImageClassifier` tasks wrap a `keras_hub.models.ViTBackbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    image classification. `ViTImageClassifier` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`,
+    where `x` is a batch of images and `y` is an integer from `[0, num_classes)`.
+
+    Note that unlike `keras_hub.models.ImageClassifier`, `ViTImageClassifier`
+    plucks out the `cls_token`, the first token of the backbone's output sequence.
+
+    Args:
+        backbone: A `keras_hub.models.ViTBackbone` instance or a `keras.Model`.
+        num_classes: int. The number of classes to predict.
+        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+            a `keras.Layer` instance, or a callable. If `None` no preprocessing
+            will be applied to the inputs.
+        pooling: String specifying the classification strategy. The choice
+            impacts the dimensionality and nature of the feature vector used for
+            classification.
+            `"token"`: A single vector (class token) representing the
+                overall image features.
+            `"gap"`: A single vector representing the average features
+                across the spatial dimensions.
+        intermediate_dim: Optional dimensionality of the intermediate
+            representation layer before the final classification layer.
+            If `None`, the output of the transformer is directly used.
+            Defaults to `None`.
+        activation: `None`, str, or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `None`.
+        head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
+            dtype to use for the classification head's computations and weights.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and run inference.
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    classifier = keras_hub.models.ViTImageClassifier.from_preset(
+        "vit_base_patch16_224"
+    )
+    classifier.predict(images)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    # Load preset and train.
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    labels = [0, 3]
+    classifier = keras_hub.models.ViTImageClassifier.from_preset(
+        "vit_base_patch16_224"
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Call `fit()` with custom loss, optimizer and backbone.
+    ```python
+    classifier = keras_hub.models.ViTImageClassifier.from_preset(
+        "vit_base_patch16_224"
+    )
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    classifier.backbone.trainable = False
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Custom backbone.
+    ```python
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    labels = [0, 3]
+    backbone = keras_hub.models.ViTBackbone(
+        image_shape=(224, 224, 3),
+        patch_size=16,
+        num_layers=6,
+        num_heads=3,
+        hidden_dim=768,
+        mlp_dim=2048,
+    )
+    classifier = keras_hub.models.ViTImageClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = ViTBackbone
+    preprocessor_cls = ViTImageClassifierPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        pooling="token",
+        intermediate_dim=None,
+        activation=None,
+        dropout=0.0,
+        head_dtype=None,
+        **kwargs,
+    ):
+        head_dtype = head_dtype or backbone.dtype_policy
+
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        if intermediate_dim is not None:
+            self.intermediate_layer = keras.layers.Dense(
+                intermediate_dim, activation="tanh", name="pre_logits"
+            )
+
+        self.dropout = keras.layers.Dropout(
+            rate=dropout,
+            dtype=head_dtype,
+            name="output_dropout",
+        )
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            dtype=head_dtype,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        if pooling == "token":
+            x = x[:, 0]
+        elif pooling == "gap":
+            ndim = len(ops.shape(x))
+            x = ops.mean(x, axis=list(range(1, ndim - 1)))  # (1,) or (1,2)
+
+        if intermediate_dim is not None:
+            x = self.intermediate_layer(x)
+
+        x = self.dropout(x)
+        outputs = self.output_dense(x)
+
+        # Skip the parent class functional model.
+        Task.__init__(
+            self,
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === config ===
+        self.num_classes = num_classes
+        self.pooling = pooling
+        self.intermediate_dim = intermediate_dim
+        self.activation = activation
+        self.dropout = dropout
+
+    def get_config(self):
+        # Backbone serialized in `super`
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "pooling": self.pooling,
+                "intermediate_dim": self.intermediate_dim,
+                "activation": self.activation,
+                "dropout": self.dropout,
+            }
+        )
+        return config
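A quick sketch of using the new task outside of a preset. It relies only on the constructor arguments shown in the hunk above; the hyperparameters are illustrative, not a published configuration, and it exercises the `pooling="gap"` and `intermediate_dim` options rather than the default class-token head.

```python
import numpy as np
import keras_hub

# Tiny illustrative backbone; real presets use much larger settings.
backbone = keras_hub.models.ViTBackbone(
    image_shape=(32, 32, 3),
    patch_size=8,
    num_layers=2,
    num_heads=2,
    hidden_dim=16,
    mlp_dim=32,
)
classifier = keras_hub.models.ViTImageClassifier(
    backbone=backbone,
    num_classes=4,
    pooling="gap",       # average over patch tokens instead of the class token
    intermediate_dim=8,  # adds the tanh "pre_logits" layer before the head
)
images = np.random.randint(0, 256, size=(2, 32, 32, 3)).astype("float32")
print(classifier.predict(images).shape)  # (2, 4)
```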
keras_hub/src/models/vit/vit_image_classifier_preprocessor.py
@@ -0,0 +1,12 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier_preprocessor import (
+    ImageClassifierPreprocessor,
+)
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
+
+
+@keras_hub_export("keras_hub.models.ViTImageClassifierPreprocessor")
+class ViTImageClassifierPreprocessor(ImageClassifierPreprocessor):
+    backbone_cls = ViTBackbone
+    image_converter_cls = ViTImageConverter
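The preprocessor is just this binding of backbone and converter classes; the actual work happens in `ImageClassifierPreprocessor` and the converter. A hedged wiring sketch, assuming the `image_converter` constructor argument that other keras-hub image preprocessors expose:

```python
import numpy as np
import keras_hub

# Build the pieces by hand instead of via `from_preset()` (assumed API).
converter = keras_hub.layers.ViTImageConverter(
    image_size=(224, 224),
    scale=1 / 255.0,
)
preprocessor = keras_hub.models.ViTImageClassifierPreprocessor(
    image_converter=converter,
)
images = np.random.randint(0, 256, size=(2, 300, 300, 3)).astype("float32")
x = preprocessor(images)  # resized to (2, 224, 224, 3), rescaled and normalized
```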
keras_hub/src/models/vit/vit_image_converter.py
@@ -0,0 +1,73 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+@keras_hub_export("keras_hub.layers.ViTImageConverter")
+class ViTImageConverter(ImageConverter):
+    """Converts images to the format expected by a ViT model.
+
+    This layer performs image normalization using mean and standard deviation
+    values. By default, it uses the same normalization as the
+    "google/vit-large-patch16-224" model on Hugging Face:
+    `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
+    ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
+    These defaults are suitable for models pretrained using this normalization.
+
+    Args:
+        norm_mean: list or tuple of floats. Mean values for image normalization.
+            Defaults to `[0.5, 0.5, 0.5]`.
+        norm_std: list or tuple of floats. Standard deviation values for
+            image normalization. Defaults to `[0.5, 0.5, 0.5]`.
+        **kwargs: Additional keyword arguments passed to
+            `keras_hub.layers.preprocessing.ImageConverter`.
+
+    Examples:
+    ```python
+    import keras
+    import numpy as np
+    from keras_hub.layers import ViTImageConverter
+
+    # Example image (replace with your actual image data)
+    image = np.random.rand(1, 224, 224, 3)  # Example: (B, H, W, C)
+
+    # Create a ViTImageConverter instance
+    converter = ViTImageConverter(
+        image_size=(28, 28),
+        scale=1 / 255.0,
+    )
+    # Preprocess the image
+    preprocessed_image = converter(image)
+    ```
+    """
+
+    backbone_cls = ViTBackbone
+
+    def __init__(
+        self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.norm_mean = norm_mean
+        self.norm_std = norm_std
+
+    @preprocessing_function
+    def call(self, inputs):
+        x = super().call(inputs)
+        # Normalize with the configured per-channel mean and std.
+        if self.norm_mean:
+            x = x - self._expand_non_channel_dims(self.norm_mean, x)
+        if self.norm_std:
+            x = x / self._expand_non_channel_dims(self.norm_std, x)
+
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "norm_mean": self.norm_mean,
+                "norm_std": self.norm_std,
+            }
+        )
+        return config
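The arithmetic the converter applies, written out with plain NumPy as a sanity check. This mirrors the `call()` above: the base `ImageConverter` resizes and rescales first, then the ViT-specific step subtracts `norm_mean` and divides by `norm_std` per channel.

```python
import numpy as np

image = np.random.randint(0, 256, size=(1, 224, 224, 3)).astype("float32")
norm_mean = np.array([0.5, 0.5, 0.5])
norm_std = np.array([0.5, 0.5, 0.5])

x = image * (1 / 255.0)         # what the base converter's `scale` does
x = (x - norm_mean) / norm_std  # ViT step; maps pixels to roughly [-1, 1]
print(x.min(), x.max())
```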
keras_hub/src/models/vit/vit_layers.py
@@ -0,0 +1,391 @@
+import keras
+from keras import ops
+
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class MLP(keras.layers.Layer):
+    """Multi-Layer Perceptron (MLP) block.
+
+    Args:
+        hidden_dim: int. Dimensionality of the hidden representations.
+        mlp_dim: int. Dimensionality of the intermediate MLP layer.
+        use_bias: bool. Whether to use bias in the dense layers. Defaults to
+            `True`.
+        dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`.
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        mlp_dim,
+        use_bias=True,
+        dropout_rate=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # === Config ===
+        self.hidden_dim = hidden_dim
+        self.mlp_dim = mlp_dim
+        self.use_bias = use_bias
+        self.dropout_rate = dropout_rate
+
+    def build(self, input_shape):
+        self.dense_1 = keras.layers.Dense(
+            units=self.mlp_dim,
+            use_bias=self.use_bias,
+            activation="gelu",
+            bias_initializer=(
+                keras.initializers.RandomNormal(stddev=1e-6)
+                if self.use_bias
+                else None
+            ),
+            dtype=self.dtype_policy,
+            name="dense_1",
+        )
+        self.dense_1.build(input_shape)
+        self.dense_2 = keras.layers.Dense(
+            units=self.hidden_dim,
+            use_bias=self.use_bias,
+            bias_initializer=(
+                keras.initializers.RandomNormal(stddev=1e-6)
+                if self.use_bias
+                else None
+            ),
+            dtype=self.dtype_policy,
+            name="dense_2",
+        )
+        self.dense_2.build((None, None, self.mlp_dim))
+        self.dropout = keras.layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+        self.built = True
+
+    def call(self, inputs):
+        x = self.dense_1(inputs)
+        x = self.dense_2(x)
+        out = self.dropout(x)
+        return out
+
+
+class ViTPatchingAndEmbedding(keras.layers.Layer):
+    """Patches the image and embeds the patches.
+
+    Args:
+        image_size: int. Size of the input image (height or width).
+            Assumed to be square.
+        patch_size: int. Size of each image patch.
+        hidden_dim: int. Dimensionality of the patch embeddings.
+        num_channels: int. Number of channels in the input image. Defaults to
+            `3`.
+        data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
+            `None` (which uses `"channels_last"`).
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        image_size,
+        patch_size,
+        hidden_dim,
+        num_channels=3,
+        data_format=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        num_patches = (image_size // patch_size) ** 2
+        num_positions = num_patches + 1
+
+        # === Config ===
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.hidden_dim = hidden_dim
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+        self.num_positions = num_positions
+        self.data_format = standardize_data_format(data_format)
+
+    def build(self, input_shape):
+        self.class_token = self.add_weight(
+            shape=(
+                1,
+                1,
+                self.hidden_dim,
+            ),
+            initializer="random_normal",
+            dtype=self.variable_dtype,
+            name="class_token",
+        )
+        self.patch_embedding = keras.layers.Conv2D(
+            filters=self.hidden_dim,
+            kernel_size=self.patch_size,
+            strides=self.patch_size,
+            padding="valid",
+            activation=None,
+            dtype=self.dtype_policy,
+            data_format=self.data_format,
+            name="patch_embedding",
+        )
+        self.patch_embedding.build(input_shape)
+        self.position_embedding = keras.layers.Embedding(
+            self.num_positions,
+            self.hidden_dim,
+            dtype=self.dtype_policy,
+            embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02),
+            name="position_embedding",
+        )
+        self.position_embedding.build((1, self.num_positions))
+        self.position_ids = keras.ops.expand_dims(
+            keras.ops.arange(self.num_positions), axis=0
+        )
+        self.built = True
+
+    def call(self, inputs):
+        patch_embeddings = self.patch_embedding(inputs)
+        if self.data_format == "channels_first":
+            patch_embeddings = ops.transpose(
+                patch_embeddings, axes=(0, 2, 3, 1)
+            )
+        embeddings_shape = ops.shape(patch_embeddings)
+        patch_embeddings = ops.reshape(
+            patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
+        )
+        class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
+        position_embeddings = self.position_embedding(self.position_ids)
+        embeddings = ops.concatenate([class_token, patch_embeddings], axis=1)
+        return ops.add(embeddings, position_embeddings)
+
+    def compute_output_shape(self, input_shape):
+        return (
+            input_shape[0],
+            self.num_positions,
+            self.hidden_dim,
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_size": self.image_size,
+                "patch_size": self.patch_size,
+                "hidden_dim": self.hidden_dim,
+                "num_channels": self.num_channels,
+                "num_patches": self.num_patches,
+                "num_positions": self.num_positions,
+            }
+        )
+        return config
+
+
+class ViTEncoderBlock(keras.layers.Layer):
+    """Transformer encoder block.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. Dimensionality of the hidden representations.
+        mlp_dim: int. Dimensionality of the intermediate MLP layer.
+        use_mha_bias: bool. Whether to use bias in the multi-head attention
+            layer. Defaults to `True`.
+        use_mlp_bias: bool. Whether to use bias in the MLP layer. Defaults to
+            `True`.
+        dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`.
+        attention_dropout: float. Dropout rate for the attention mechanism.
+            Between 0 and 1. Defaults to `0.0`.
+        layer_norm_epsilon: float. Small float value for layer normalization
+            stability. Defaults to `1e-6`.
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        num_heads,
+        hidden_dim,
+        mlp_dim,
+        use_mha_bias=True,
+        use_mlp_bias=True,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        key_dim = hidden_dim // num_heads
+
+        # === Config ===
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.key_dim = key_dim
+        self.mlp_dim = mlp_dim
+        self.use_mha_bias = use_mha_bias
+        self.use_mlp_bias = use_mlp_bias
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+    def build(self, input_shape):
+        # Attention block
+        self.layer_norm_1 = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            name="ln_1",
+            dtype=self.dtype_policy,
+        )
+        self.layer_norm_1.build(input_shape)
+        self.mha = keras.layers.MultiHeadAttention(
+            num_heads=self.num_heads,
+            key_dim=self.key_dim,
+            use_bias=self.use_mha_bias,
+            dropout=self.attention_dropout,
+            name="mha",
+            dtype=self.dtype_policy,
+        )
+        self.mha.build(input_shape, input_shape)
+        self.dropout = keras.layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+
+        # MLP block
+        self.layer_norm_2 = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            name="ln_2",
+            dtype=self.dtype_policy,
+        )
+        self.layer_norm_2.build((None, None, self.hidden_dim))
+        self.mlp = MLP(
+            hidden_dim=self.hidden_dim,
+            mlp_dim=self.mlp_dim,
+            use_bias=self.use_mlp_bias,
+            name="mlp",
+            dtype=self.dtype_policy,
+        )
+        self.mlp.build((None, None, self.hidden_dim))
+        self.built = True
+
+    def call(self, inputs):
+        x = self.layer_norm_1(inputs)
+        x = self.mha(x, x)
+        x = self.dropout(x)
+        x = x + inputs
+
+        y = self.layer_norm_2(x)
+        y = self.mlp(y)
+
+        return x + y
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "key_dim": self.key_dim,
+                "mlp_dim": self.mlp_dim,
+                "use_mha_bias": self.use_mha_bias,
+                "use_mlp_bias": self.use_mlp_bias,
+                "dropout_rate": self.dropout_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+
+class ViTEncoder(keras.layers.Layer):
+    """Vision Transformer (ViT) encoder.
+
+    Args:
+        num_layers: int. Number of Transformer encoder blocks.
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. Dimensionality of the hidden representations.
+        mlp_dim: int. Dimensionality of the intermediate MLP layer.
+        use_mha_bias: bool. Whether to use bias in the multi-head attention
+            layers. Defaults to `True`.
+        use_mlp_bias: bool. Whether to use bias in the MLP layers. Defaults to
+            `True`.
+        dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`.
+        attention_dropout: float. Dropout rate for the attention mechanism.
+            Between 0 and 1. Defaults to `0.0`.
+        layer_norm_epsilon: float. Small float value for layer normalization
+            stability. Defaults to `1e-6`.
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        mlp_dim,
+        use_mha_bias=True,
+        use_mlp_bias=True,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # === config ===
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.mlp_dim = mlp_dim
+        self.use_mha_bias = use_mha_bias
+        self.use_mlp_bias = use_mlp_bias
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+    def build(self, input_shape):
+        self.encoder_layers = []
+        for i in range(self.num_layers):
+            encoder_block = ViTEncoderBlock(
+                num_heads=self.num_heads,
+                hidden_dim=self.hidden_dim,
+                mlp_dim=self.mlp_dim,
+                dropout_rate=self.dropout_rate,
+                use_mha_bias=self.use_mha_bias,
+                use_mlp_bias=self.use_mlp_bias,
+                attention_dropout=self.attention_dropout,
+                layer_norm_epsilon=self.layer_norm_epsilon,
+                dtype=self.dtype_policy,
+                name=f"tranformer_block_{i + 1}",
+            )
+            encoder_block.build((None, None, self.hidden_dim))
+            self.encoder_layers.append(encoder_block)
+        self.dropout = keras.layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+        self.layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="ln",
+        )
+        self.layer_norm.build((None, None, self.hidden_dim))
+        self.built = True
+
+    def call(self, inputs):
+        x = self.dropout(inputs)
+        for i in range(self.num_layers):
+            x = self.encoder_layers[i](x)
+        x = self.layer_norm(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "mlp_dim": self.mlp_dim,
+                "use_mha_bias": self.use_mha_bias,
+                "use_mlp_bias": self.use_mlp_bias,
+                "dropout_rate": self.dropout_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
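Shape bookkeeping for `ViTPatchingAndEmbedding`, following the formulas in its `__init__` above and assuming a square 224x224 input with 16x16 patches (the sizes used in the docstring examples); this is a worked sketch, not code from the diff.

```python
image_size, patch_size, hidden_dim = 224, 16, 768
num_patches = (image_size // patch_size) ** 2  # 196 non-overlapping patches
num_positions = num_patches + 1                # 197: patches plus the class token
# A batch of B images therefore becomes a (B, 197, 768) token sequence, which is
# what ViTEncoder consumes and what ViTImageClassifier's "token" pooling reads
# back out at index [:, 0].
print(num_patches, num_positions)  # 196 197
```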