keras-hub-nightly 0.19.0.dev202412120352__py3-none-any.whl → 0.19.0.dev202412140350__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +1 -0
- keras_hub/api/models/__init__.py +11 -6
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/converters.py +2 -2
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/rms_normalization.py +8 -6
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +3 -1
- keras_hub/src/metrics/bleu.py +1 -1
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/bert/bert_presets.py +4 -2
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/causal_lm.py +19 -15
- keras_hub/src/models/clip/clip_vision_embedding.py +1 -1
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +17 -13
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -3
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +1 -1
- keras_hub/src/models/densenet/densenet_backbone.py +3 -1
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -1
- keras_hub/src/models/densenet/densenet_presets.py +6 -6
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +2 -1
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/cba.py +1 -1
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +20 -8
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +1 -1
- keras_hub/src/models/efficientnet/efficientnet_presets.py +12 -11
- keras_hub/src/models/efficientnet/fusedmbconv.py +3 -5
- keras_hub/src/models/efficientnet/mbconv.py +1 -1
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/flux/flux_layers.py +46 -44
- keras_hub/src/models/flux/flux_maths.py +24 -17
- keras_hub/src/models/flux/flux_model.py +24 -19
- keras_hub/src/models/flux/flux_presets.py +2 -1
- keras_hub/src/models/flux/flux_text_to_image.py +7 -3
- keras_hub/src/models/gemma/gemma_backbone.py +27 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +9 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier_preprocessor.py +4 -1
- keras_hub/src/models/image_object_detector.py +2 -2
- keras_hub/src/models/image_object_detector_preprocessor.py +4 -4
- keras_hub/src/models/image_segmenter_preprocessor.py +2 -2
- keras_hub/src/models/llama/llama_backbone.py +34 -26
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +3 -3
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/mit_backbone.py +4 -3
- keras_hub/src/models/mit/mit_layers.py +2 -1
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +7 -7
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +5 -3
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +2 -2
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +2 -2
- keras_hub/src/models/retinanet/feature_pyramid.py +3 -2
- keras_hub/src/models/retinanet/prediction_head.py +2 -2
- keras_hub/src/models/retinanet/retinanet_backbone.py +2 -2
- keras_hub/src/models/retinanet/retinanet_image_converter.py +1 -1
- keras_hub/src/models/retinanet/retinanet_object_detector.py +5 -6
- keras_hub/src/models/retinanet/retinanet_presets.py +2 -1
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +4 -2
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/sam_backbone.py +2 -2
- keras_hub/src/models/sam/sam_image_segmenter.py +6 -5
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/segformer_backbone.py +18 -14
- keras_hub/src/models/segformer/segformer_image_segmenter.py +51 -38
- keras_hub/src/models/segformer/segformer_presets.py +24 -12
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +20 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +1 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +13 -6
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +2 -2
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +7 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/task.py +4 -2
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +5 -1
- keras_hub/src/models/vae/vae_layers.py +0 -1
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +49 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +4 -2
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +1 -3
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +2 -2
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +2 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +2 -8
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +7 -12
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +8 -5
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +7 -3
- keras_hub/src/utils/preset_utils.py +25 -18
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +2 -4
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/RECORD +148 -140
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/top_level.txt +0 -0
keras_hub/src/models/vit/vit_backbone.py
@@ -0,0 +1,152 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.vit.vit_layers import ViTEncoder
+from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+@keras_hub_export("keras_hub.models.ViTBackbone")
+class ViTBackbone(Backbone):
+    """Vision Transformer (ViT) backbone.
+
+    This backbone implements the Vision Transformer architecture as described in
+    [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).
+    It transforms the input image into a sequence of patches, embeds them, and
+    then processes them through a series of Transformer encoder layers.
+
+    Args:
+        image_shape: A tuple or list of 3 integers representing the shape of the
+            input image `(height, width, channels)`, `height` and `width` must
+            be equal.
+        patch_size: int. The size of each image patch, the input image will be
+            divided into patches of shape `(patch_size, patch_size)`.
+        num_layers: int. The number of transformer encoder layers.
+        num_heads: int. specifying the number of attention heads in each
+            Transformer encoder layer.
+        hidden_dim: int. The dimensionality of the hidden representations.
+        mlp_dim: int. The dimensionality of the intermediate MLP layer in
+            each Transformer encoder layer.
+        dropout_rate: float. The dropout rate for the Transformer encoder
+            layers.
+        attention_dropout: float. The dropout rate for the attention mechanism
+            in each Transformer encoder layer.
+        layer_norm_epsilon: float. Value used for numerical stability in
+            layer normalization.
+        use_mha_bias: bool. Whether to use bias in the multi-head
+            attention layers.
+        use_mlp_bias: bool. Whether to use bias in the MLP layers.
+        data_format: str. `"channels_last"` or `"channels_first"`, specifying
+            the data format for the input image. If `None`, defaults to
+            `"channels_last"`.
+        dtype: The dtype of the layer weights. Defaults to None.
+        **kwargs: Additional keyword arguments to be passed to the parent
+            `Backbone` class.
+    """
+
+    def __init__(
+        self,
+        image_shape,
+        patch_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        mlp_dim,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        use_mha_bias=True,
+        use_mlp_bias=True,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Laters ===
+        data_format = standardize_data_format(data_format)
+        h_axis, w_axis, channels_axis = (
+            (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
+        )
+        # Check that the input image is well specified.
+        if image_shape[h_axis] is None or image_shape[w_axis] is None:
+            raise ValueError(
+                f"Image shape must have defined height and width. Found `None` "
+                f"at index {h_axis} (height) or {w_axis} (width). "
+                f"Image shape: {image_shape}"
+            )
+        if image_shape[h_axis] != image_shape[w_axis]:
+            raise ValueError(
+                f"Image height and width must be equal. Found height: "
+                f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
+                f"indices {h_axis} and {w_axis} respectively. Image shape: "
+                f"{image_shape}"
+            )
+
+        num_channels = image_shape[channels_axis]
+
+        # === Functional Model ===
+        inputs = keras.layers.Input(shape=image_shape)
+
+        x = ViTPatchingAndEmbedding(
+            image_size=image_shape[h_axis],
+            patch_size=patch_size,
+            hidden_dim=hidden_dim,
+            num_channels=num_channels,
+            data_format=data_format,
+            dtype=dtype,
+            name="vit_patching_and_embedding",
+        )(inputs)
+
+        output = ViTEncoder(
+            num_layers=num_layers,
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            mlp_dim=mlp_dim,
+            dropout_rate=dropout_rate,
+            attention_dropout=attention_dropout,
+            layer_norm_epsilon=layer_norm_epsilon,
+            use_mha_bias=use_mha_bias,
+            use_mlp_bias=use_mlp_bias,
+            dtype=dtype,
+            name="vit_encoder",
+        )(x)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.image_shape = image_shape
+        self.patch_size = patch_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.mlp_dim = mlp_dim
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.use_mha_bias = use_mha_bias
+        self.use_mlp_bias = use_mlp_bias
+        self.data_format = data_format
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_shape": self.image_shape,
+                "patch_size": self.patch_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "mlp_dim": self.mlp_dim,
+                "dropout_rate": self.dropout_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "use_mha_bias": self.use_mha_bias,
+                "use_mlp_bias": self.use_mlp_bias,
+            }
+        )
+        return config
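The new `ViTBackbone` documents its constructor arguments above but ships without a usage example. A minimal sketch, assuming this nightly wheel is installed; the layer sizes below are arbitrary illustration values rather than a released preset, and the output shape assumes the patching layer prepends a class token (which is what the classifier's `"token"` pooling relies on):

```python
import numpy as np
import keras_hub

# Small, illustrative configuration (not a released preset).
backbone = keras_hub.models.ViTBackbone(
    image_shape=(224, 224, 3),
    patch_size=16,
    num_layers=4,
    num_heads=4,
    hidden_dim=64,
    mlp_dim=128,
)

# 224x224 images split into 16x16 patches -> 196 patches, plus the class
# token, each embedded into hidden_dim features.
images = np.random.rand(2, 224, 224, 3).astype("float32")
features = backbone(images)
print(features.shape)  # expected: (2, 197, 64)
```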
keras_hub/src/models/vit/vit_image_classifier.py
@@ -0,0 +1,187 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.models.task import Task
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
+    ViTImageClassifierPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.ViTImageClassifier")
+class ViTImageClassifier(ImageClassifier):
+    """ViT image classification task.
+
+    `ViTImageClassifier` tasks wrap a `keras_hub.models.ViTBackbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    image classification. `ViTImageClassifier` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    labels where `x` is a string and `y` is a integer from `[0, num_classes)`.
+
+    Not that unlike `keras_hub.model.ImageClassifier`, the `ViTImageClassifier`
+    we pluck out `cls_token` which is first seqence from the backbone.
+
+    Args:
+        backbone: A `keras_hub.models.ViTBackbone` instance or a `keras.Model`.
+        num_classes: int. The number of classes to predict.
+        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+            a `keras.Layer` instance, or a callable. If `None` no preprocessing
+            will be applied to the inputs.
+        pooling: String specifying the classification strategy. The choice
+            impacts the dimensionality and nature of the feature vector used for
+            classification.
+            `"token"`: A single vector (class token) representing the
+                overall image features.
+            `"gap"`: A single vector representing the average features
+                across the spatial dimensions.
+        intermediate_dim: Optional dimensionality of the intermediate
+            representation layer before the final classification layer.
+            If `None`, the output of the transformer is directly used.
+            Defaults to `None`.
+        activation: `None`, str, or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `"softmax"`.
+        head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
+            dtype to use for the classification head's computations and weights.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and train
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    classifier = keras_hub.models.ViTImageClassifier.from_preset(
+        "vgg_16_imagenet"
+    )
+    classifier.predict(images)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    # Load preset and train
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    labels = [0, 3]
+    classifier = keras_hub.models.VGGImageClassifier.from_preset(
+        "vit_base_patch16_224"
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Call `fit()` with custom loss, optimizer and backbone.
+    ```python
+    classifier = keras_hub.models.VGGImageClassifier.from_preset(
+        "vit_base_patch16_224"
+    )
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    classifier.backbone.trainable = False
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Custom backbone.
+    ```python
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    labels = [0, 3]
+    model = keras_hub.models.ViTBackbone(
+        image_shape = (224, 224, 3),
+        patch_size=16,
+        num_layers=6,
+        num_heads=3,
+        hidden_dim=768,
+        mlp_dim=2048
+    )
+    classifier = keras_hub.models.ViTImageClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = ViTBackbone
+    preprocessor_cls = ViTImageClassifierPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        pooling="token",
+        intermediate_dim=None,
+        activation=None,
+        dropout=0.0,
+        head_dtype=None,
+        **kwargs,
+    ):
+        head_dtype = head_dtype or backbone.dtype_policy
+
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        if intermediate_dim is not None:
+            self.intermediate_layer = keras.layers.Dense(
+                intermediate_dim, activation="tanh", name="pre_logits"
+            )
+
+        self.dropout = keras.layers.Dropout(
+            rate=dropout,
+            dtype=head_dtype,
+            name="output_dropout",
+        )
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            dtype=head_dtype,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        if pooling == "token":
+            x = x[:, 0]
+        elif pooling == "gap":
+            ndim = len(ops.shape(x))
+            x = ops.mean(x, axis=list(range(1, ndim - 1)))  # (1,) or (1,2)
+
+        if intermediate_dim is not None:
+            x = self.intermediate_layer(x)
+
+        x = self.dropout(x)
+        outputs = self.output_dense(x)
+
+        # Skip the parent class functional model.
+        Task.__init__(
+            self,
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === config ===
+        self.num_classes = num_classes
+        self.pooling = pooling
+        self.intermediate_dim = intermediate_dim
+        self.activation = activation
+        self.dropout = dropout
+
+    def get_config(self):
+        # Backbone serialized in `super`
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "pooling": self.pooling,
+                "intermediate_dim": self.intermediate_dim,
+                "activation": self.activation,
+                "dropout": self.dropout,
+            }
+        )
+        return config
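The docstring examples above focus on preset loading; the `pooling` and `intermediate_dim` arguments are only described in prose. A hedged sketch combining them with the small illustrative backbone from the previous snippet (again, arbitrary sizes, not a released preset):

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.ViTBackbone(
    image_shape=(224, 224, 3),
    patch_size=16,
    num_layers=4,
    num_heads=4,
    hidden_dim=64,
    mlp_dim=128,
)

# "gap" averages the encoded sequence instead of taking the class token;
# intermediate_dim adds the tanh "pre_logits" projection before the head.
classifier = keras_hub.models.ViTImageClassifier(
    backbone=backbone,
    num_classes=10,
    pooling="gap",
    intermediate_dim=32,
)

images = np.random.rand(2, 224, 224, 3).astype("float32")
preds = classifier.predict(images)  # logits of shape (2, 10), activation=None
print(preds.shape)
```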
keras_hub/src/models/vit/vit_image_classifier_preprocessor.py
@@ -0,0 +1,12 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier_preprocessor import (
+    ImageClassifierPreprocessor,
+)
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
+
+
+@keras_hub_export("keras_hub.models.ViTImageClassifierPreprocessor")
+class ViTImageClassifierPreprocessor(ImageClassifierPreprocessor):
+    backbone_cls = ViTBackbone
+    image_converter_cls = ViTImageConverter
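`ViTImageClassifierPreprocessor` only wires the ViT backbone and image converter classes into the generic `ImageClassifierPreprocessor`. A sketch of constructing it by hand rather than via `from_preset()`; it assumes `ImageClassifierPreprocessor` accepts an `image_converter` argument, and the resize/scale values are illustrative:

```python
import numpy as np
from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
    ViTImageClassifierPreprocessor,
)
from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter

# Hypothetical manual wiring; from_preset() would normally build this.
converter = ViTImageConverter(image_size=(224, 224), scale=1 / 255.0)
preprocessor = ViTImageClassifierPreprocessor(image_converter=converter)

# Raw pixel-valued images are resized, rescaled, and normalized for the model.
images = np.random.randint(0, 256, size=(2, 256, 256, 3)).astype("float32")
x = preprocessor(images)
print(x.shape)  # expected: (2, 224, 224, 3)
```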
keras_hub/src/models/vit/vit_image_converter.py
@@ -0,0 +1,73 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+@keras_hub_export("keras_hub.layers.ViTImageConverter")
+class ViTImageConverter(ImageConverter):
+    """Converts images to the format expected by a ViT model.
+
+    This layer performs image normalization using mean and standard deviation
+    values. By default, it uses the same normalization as the
+    "google/vit-large-patch16-224" model on Hugging Face:
+    `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
+    ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
+    These defaults are suitable for models pretrained using this normalization.
+
+    Args:
+        norm_mean: list or tuple of floats. Mean values for image normalization.
+            Defaults to `[0.5, 0.5, 0.5]`.
+        norm_std: list or tuple of floats. Standard deviation values for
+            image normalization. Defaults to `[0.5, 0.5, 0.5]`.
+        **kwargs: Additional keyword arguments passed to
+            `keras_hub.layers.preprocessing.ImageConverter`.
+
+    Examples:
+    ```python
+    import keras
+    import numpy as np
+    from keras_hub.src.layers import ViTImageConverter
+
+    # Example image (replace with your actual image data)
+    image = np.random.rand(1, 224, 224, 3)  # Example: (B, H, W, C)
+
+    # Create a ViTImageConverter instance
+    converter = ViTImageConverter(
+        image_size=(28,28),
+        scale=1/255.
+    )
+    # Preprocess the image
+    preprocessed_image = converter(image)
+    ```
+    """
+
+    backbone_cls = ViTBackbone
+
+    def __init__(
+        self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.norm_mean = norm_mean
+        self.norm_std = norm_std
+
+    @preprocessing_function
+    def call(self, inputs):
+        x = super().call(inputs)
+        # By default normalize using imagenet mean and std
+        if self.norm_mean:
+            x = x - self._expand_non_channel_dims(self.norm_mean, x)
+        if self.norm_std:
+            x = x / self._expand_non_channel_dims(self.norm_std, x)
+
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "norm_mean": self.norm_mean,
+                "norm_std": self.norm_std,
+            }
+        )
+        return config
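`ViTImageConverter.call()` applies the base converter's rescaling first and then the mean/std normalization, so with the default `norm_mean=[0.5, 0.5, 0.5]`, `norm_std=[0.5, 0.5, 0.5]` and an assumed base `scale=1/255` (a typical configuration, not mandated by the file), raw pixel values end up in `[-1, 1]`. A small arithmetic check:

```python
import numpy as np

# Effective per-channel transform: x -> (x * scale - norm_mean) / norm_std
pixels = np.array([0.0, 127.5, 255.0])
scale, norm_mean, norm_std = 1 / 255.0, 0.5, 0.5
normalized = (pixels * scale - norm_mean) / norm_std
print(normalized)  # [-1.  0.  1.]
```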