keras-hub-nightly 0.19.0.dev202412120352__py3-none-any.whl → 0.19.0.dev202412140350__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. keras_hub/api/layers/__init__.py +1 -0
  2. keras_hub/api/models/__init__.py +11 -6
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/converters.py +2 -2
  5. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  6. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  7. keras_hub/src/layers/modeling/rms_normalization.py +8 -6
  8. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  9. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  10. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  11. keras_hub/src/layers/modeling/transformer_encoder.py +3 -1
  12. keras_hub/src/metrics/bleu.py +1 -1
  13. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  14. keras_hub/src/models/bart/bart_backbone.py +4 -4
  15. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  16. keras_hub/src/models/bert/bert_presets.py +4 -2
  17. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  18. keras_hub/src/models/causal_lm.py +19 -15
  19. keras_hub/src/models/clip/clip_vision_embedding.py +1 -1
  20. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  21. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  22. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  23. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  24. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  25. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  26. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +17 -13
  27. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -3
  28. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +1 -1
  29. keras_hub/src/models/densenet/densenet_backbone.py +3 -1
  30. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -1
  31. keras_hub/src/models/densenet/densenet_presets.py +6 -6
  32. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  33. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  34. keras_hub/src/models/distil_bert/distil_bert_presets.py +2 -1
  35. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  36. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  37. keras_hub/src/models/efficientnet/cba.py +1 -1
  38. keras_hub/src/models/efficientnet/efficientnet_backbone.py +20 -8
  39. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +1 -1
  40. keras_hub/src/models/efficientnet/efficientnet_presets.py +12 -11
  41. keras_hub/src/models/efficientnet/fusedmbconv.py +3 -5
  42. keras_hub/src/models/efficientnet/mbconv.py +1 -1
  43. keras_hub/src/models/electra/electra_backbone.py +2 -2
  44. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  45. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  46. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  47. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  48. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  49. keras_hub/src/models/flux/flux_layers.py +46 -44
  50. keras_hub/src/models/flux/flux_maths.py +24 -17
  51. keras_hub/src/models/flux/flux_model.py +24 -19
  52. keras_hub/src/models/flux/flux_presets.py +2 -1
  53. keras_hub/src/models/flux/flux_text_to_image.py +7 -3
  54. keras_hub/src/models/gemma/gemma_backbone.py +27 -20
  55. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  56. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  57. keras_hub/src/models/gemma/gemma_presets.py +9 -3
  58. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  59. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  60. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  61. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  62. keras_hub/src/models/image_classifier_preprocessor.py +4 -1
  63. keras_hub/src/models/image_object_detector.py +2 -2
  64. keras_hub/src/models/image_object_detector_preprocessor.py +4 -4
  65. keras_hub/src/models/image_segmenter_preprocessor.py +2 -2
  66. keras_hub/src/models/llama/llama_backbone.py +34 -26
  67. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  68. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  69. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  70. keras_hub/src/models/mistral/mistral_causal_lm.py +3 -3
  71. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  72. keras_hub/src/models/mit/mit_backbone.py +4 -3
  73. keras_hub/src/models/mit/mit_layers.py +2 -1
  74. keras_hub/src/models/mobilenet/mobilenet_backbone.py +7 -7
  75. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  76. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +5 -3
  77. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +2 -2
  78. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  79. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  80. keras_hub/src/models/preprocessor.py +2 -2
  81. keras_hub/src/models/retinanet/feature_pyramid.py +3 -2
  82. keras_hub/src/models/retinanet/prediction_head.py +2 -2
  83. keras_hub/src/models/retinanet/retinanet_backbone.py +2 -2
  84. keras_hub/src/models/retinanet/retinanet_image_converter.py +1 -1
  85. keras_hub/src/models/retinanet/retinanet_object_detector.py +5 -6
  86. keras_hub/src/models/retinanet/retinanet_presets.py +2 -1
  87. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  88. keras_hub/src/models/roberta/roberta_presets.py +4 -2
  89. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  90. keras_hub/src/models/sam/sam_backbone.py +2 -2
  91. keras_hub/src/models/sam/sam_image_segmenter.py +6 -5
  92. keras_hub/src/models/sam/sam_layers.py +5 -3
  93. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  94. keras_hub/src/models/sam/sam_transformer.py +5 -4
  95. keras_hub/src/models/segformer/segformer_backbone.py +18 -14
  96. keras_hub/src/models/segformer/segformer_image_segmenter.py +51 -38
  97. keras_hub/src/models/segformer/segformer_presets.py +24 -12
  98. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  99. keras_hub/src/models/stable_diffusion_3/mmdit.py +20 -1
  100. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +1 -1
  101. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +13 -6
  102. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +2 -2
  103. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +7 -3
  104. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  105. keras_hub/src/models/task.py +4 -2
  106. keras_hub/src/models/text_classifier.py +2 -2
  107. keras_hub/src/models/text_to_image.py +5 -1
  108. keras_hub/src/models/vae/vae_layers.py +0 -1
  109. keras_hub/src/models/vit/__init__.py +5 -0
  110. keras_hub/src/models/vit/vit_backbone.py +152 -0
  111. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  112. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  113. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  114. keras_hub/src/models/vit/vit_layers.py +391 -0
  115. keras_hub/src/models/vit/vit_presets.py +49 -0
  116. keras_hub/src/models/vit_det/vit_det_backbone.py +4 -2
  117. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  118. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -3
  119. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  120. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  121. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  122. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  123. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  124. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  125. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  126. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  127. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  128. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  129. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  130. keras_hub/src/samplers/sampler.py +2 -1
  131. keras_hub/src/tests/test_case.py +2 -2
  132. keras_hub/src/tokenizers/byte_pair_tokenizer.py +2 -2
  133. keras_hub/src/tokenizers/byte_tokenizer.py +2 -8
  134. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  135. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +7 -12
  136. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +8 -5
  137. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +7 -3
  138. keras_hub/src/utils/preset_utils.py +25 -18
  139. keras_hub/src/utils/tensor_utils.py +4 -4
  140. keras_hub/src/utils/timm/convert_efficientnet.py +2 -4
  141. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  142. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  143. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  144. keras_hub/src/version_utils.py +1 -1
  145. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/METADATA +1 -1
  146. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/RECORD +148 -140
  147. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/WHEEL +0 -0
  148. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/top_level.txt +0 -0
keras_hub/src/models/vit/vit_backbone.py
@@ -0,0 +1,152 @@
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.vit.vit_layers import ViTEncoder
+ from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ @keras_hub_export("keras_hub.models.ViTBackbone")
+ class ViTBackbone(Backbone):
+     """Vision Transformer (ViT) backbone.
+
+     This backbone implements the Vision Transformer architecture as described in
+     [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).
+     It transforms the input image into a sequence of patches, embeds them, and
+     then processes them through a series of Transformer encoder layers.
+
+     Args:
+         image_shape: A tuple or list of 3 integers representing the shape of
+             the input image `(height, width, channels)`. `height` and `width`
+             must be equal.
+         patch_size: int. The size of each image patch; the input image will
+             be divided into patches of shape `(patch_size, patch_size)`.
+         num_layers: int. The number of transformer encoder layers.
+         num_heads: int. The number of attention heads in each Transformer
+             encoder layer.
+         hidden_dim: int. The dimensionality of the hidden representations.
+         mlp_dim: int. The dimensionality of the intermediate MLP layer in
+             each Transformer encoder layer.
+         dropout_rate: float. The dropout rate for the Transformer encoder
+             layers.
+         attention_dropout: float. The dropout rate for the attention mechanism
+             in each Transformer encoder layer.
+         layer_norm_epsilon: float. Value used for numerical stability in
+             layer normalization.
+         use_mha_bias: bool. Whether to use bias in the multi-head
+             attention layers.
+         use_mlp_bias: bool. Whether to use bias in the MLP layers.
+         data_format: str. `"channels_last"` or `"channels_first"`, specifying
+             the data format for the input image. If `None`, defaults to
+             `"channels_last"`.
+         dtype: The dtype of the layer weights. Defaults to `None`.
+         **kwargs: Additional keyword arguments to be passed to the parent
+             `Backbone` class.
+     """
+
+     def __init__(
+         self,
+         image_shape,
+         patch_size,
+         num_layers,
+         num_heads,
+         hidden_dim,
+         mlp_dim,
+         dropout_rate=0.0,
+         attention_dropout=0.0,
+         layer_norm_epsilon=1e-6,
+         use_mha_bias=True,
+         use_mlp_bias=True,
+         data_format=None,
+         dtype=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         data_format = standardize_data_format(data_format)
+         h_axis, w_axis, channels_axis = (
+             (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
+         )
+         # Check that the input image is well specified.
+         if image_shape[h_axis] is None or image_shape[w_axis] is None:
+             raise ValueError(
+                 f"Image shape must have defined height and width. Found `None` "
+                 f"at index {h_axis} (height) or {w_axis} (width). "
+                 f"Image shape: {image_shape}"
+             )
+         if image_shape[h_axis] != image_shape[w_axis]:
+             raise ValueError(
+                 f"Image height and width must be equal. Found height: "
+                 f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
+                 f"indices {h_axis} and {w_axis} respectively. Image shape: "
+                 f"{image_shape}"
+             )
+
+         num_channels = image_shape[channels_axis]
+
+         # === Functional Model ===
+         inputs = keras.layers.Input(shape=image_shape)
+
+         x = ViTPatchingAndEmbedding(
+             image_size=image_shape[h_axis],
+             patch_size=patch_size,
+             hidden_dim=hidden_dim,
+             num_channels=num_channels,
+             data_format=data_format,
+             dtype=dtype,
+             name="vit_patching_and_embedding",
+         )(inputs)
+
+         output = ViTEncoder(
+             num_layers=num_layers,
+             num_heads=num_heads,
+             hidden_dim=hidden_dim,
+             mlp_dim=mlp_dim,
+             dropout_rate=dropout_rate,
+             attention_dropout=attention_dropout,
+             layer_norm_epsilon=layer_norm_epsilon,
+             use_mha_bias=use_mha_bias,
+             use_mlp_bias=use_mlp_bias,
+             dtype=dtype,
+             name="vit_encoder",
+         )(x)
+
+         super().__init__(
+             inputs=inputs,
+             outputs=output,
+             dtype=dtype,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.image_shape = image_shape
+         self.patch_size = patch_size
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.mlp_dim = mlp_dim
+         self.dropout_rate = dropout_rate
+         self.attention_dropout = attention_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.use_mha_bias = use_mha_bias
+         self.use_mlp_bias = use_mlp_bias
+         self.data_format = data_format
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "image_shape": self.image_shape,
+                 "patch_size": self.patch_size,
+                 "num_layers": self.num_layers,
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "mlp_dim": self.mlp_dim,
+                 "dropout_rate": self.dropout_rate,
+                 "attention_dropout": self.attention_dropout,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "use_mha_bias": self.use_mha_bias,
+                 "use_mlp_bias": self.use_mlp_bias,
+             }
+         )
+         return config
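A minimal usage sketch for the backbone added above (not part of the diff), assuming this nightly build is installed. The small configuration is illustrative only, and the output-shape comment assumes the patch embedding prepends a class token, as the `"token"` pooling in the classifier below suggests.

```python
import numpy as np
import keras_hub

# Small configuration for illustration; the released presets use larger
# settings (see vit_presets.py in this diff).
backbone = keras_hub.models.ViTBackbone(
    image_shape=(224, 224, 3),
    patch_size=16,
    num_layers=2,
    num_heads=4,
    hidden_dim=64,
    mlp_dim=128,
)
images = np.random.rand(1, 224, 224, 3).astype("float32")
features = backbone(images)
# Assuming a prepended class token: (1, (224 // 16) ** 2 + 1, 64) == (1, 197, 64).
print(features.shape)
```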
keras_hub/src/models/vit/vit_image_classifier.py
@@ -0,0 +1,187 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_classifier import ImageClassifier
+ from keras_hub.src.models.task import Task
+ from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+ from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
+     ViTImageClassifierPreprocessor,
+ )
+
+
+ @keras_hub_export("keras_hub.models.ViTImageClassifier")
+ class ViTImageClassifier(ImageClassifier):
+     """ViT image classification task.
+
+     `ViTImageClassifier` tasks wrap a `keras_hub.models.ViTBackbone` and
+     a `keras_hub.models.Preprocessor` to create a model that can be used for
+     image classification. `ViTImageClassifier` tasks take an additional
+     `num_classes` argument, controlling the number of predicted output classes.
+
+     To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+     labels where `x` is an image and `y` is an integer from `[0, num_classes)`.
+
+     Note that unlike `keras_hub.models.ImageClassifier`, `ViTImageClassifier`
+     plucks out the `cls_token`, the first token in the backbone's output
+     sequence, before applying the classification head.
+
+     Args:
+         backbone: A `keras_hub.models.ViTBackbone` instance or a `keras.Model`.
+         num_classes: int. The number of classes to predict.
+         preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+             a `keras.Layer` instance, or a callable. If `None` no preprocessing
+             will be applied to the inputs.
+         pooling: String specifying the classification strategy. The choice
+             impacts the dimensionality and nature of the feature vector used
+             for classification.
+             `"token"`: A single vector (class token) representing the
+                 overall image features.
+             `"gap"`: A single vector representing the average features
+                 across the spatial dimensions.
+         intermediate_dim: Optional dimensionality of the intermediate
+             representation layer before the final classification layer.
+             If `None`, the output of the transformer is directly used.
+             Defaults to `None`.
+         activation: `None`, str, or callable. The activation function to use on
+             the `Dense` layer. Set `activation=None` to return the output
+             logits. Defaults to `None`.
+         dropout: float. The dropout rate applied before the final
+             classification layer. Defaults to `0.0`.
+         head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
+             dtype to use for the classification head's computations and weights.
+
+     Examples:
+
+     Call `predict()` to run inference.
+     ```python
+     # Load preset and run inference.
+     images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+     classifier = keras_hub.models.ViTImageClassifier.from_preset(
+         "vit_base_patch16_224"
+     )
+     classifier.predict(images)
+     ```
+
+     Call `fit()` on a single batch.
+     ```python
+     # Load preset and train.
+     images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+     labels = [0, 3]
+     classifier = keras_hub.models.ViTImageClassifier.from_preset(
+         "vit_base_patch16_224"
+     )
+     classifier.fit(x=images, y=labels, batch_size=2)
+     ```
+
+     Call `fit()` with custom loss, optimizer and backbone.
+     ```python
+     classifier = keras_hub.models.ViTImageClassifier.from_preset(
+         "vit_base_patch16_224"
+     )
+     classifier.compile(
+         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+         optimizer=keras.optimizers.Adam(5e-5),
+     )
+     classifier.backbone.trainable = False
+     classifier.fit(x=images, y=labels, batch_size=2)
+     ```
+
+     Custom backbone.
+     ```python
+     images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+     labels = [0, 3]
+     backbone = keras_hub.models.ViTBackbone(
+         image_shape=(224, 224, 3),
+         patch_size=16,
+         num_layers=6,
+         num_heads=3,
+         hidden_dim=768,
+         mlp_dim=2048,
+     )
+     classifier = keras_hub.models.ViTImageClassifier(
+         backbone=backbone,
+         num_classes=4,
+     )
+     classifier.fit(x=images, y=labels, batch_size=2)
+     ```
+     """
+
+     backbone_cls = ViTBackbone
+     preprocessor_cls = ViTImageClassifierPreprocessor
+
+     def __init__(
+         self,
+         backbone,
+         num_classes,
+         preprocessor=None,
+         pooling="token",
+         intermediate_dim=None,
+         activation=None,
+         dropout=0.0,
+         head_dtype=None,
+         **kwargs,
+     ):
+         head_dtype = head_dtype or backbone.dtype_policy
+
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         if intermediate_dim is not None:
+             self.intermediate_layer = keras.layers.Dense(
+                 intermediate_dim, activation="tanh", name="pre_logits"
+             )
+
+         self.dropout = keras.layers.Dropout(
+             rate=dropout,
+             dtype=head_dtype,
+             name="output_dropout",
+         )
+         self.output_dense = keras.layers.Dense(
+             num_classes,
+             activation=activation,
+             dtype=head_dtype,
+             name="predictions",
+         )
+
+         # === Functional Model ===
+         inputs = self.backbone.input
+         x = self.backbone(inputs)
+         if pooling == "token":
+             x = x[:, 0]
+         elif pooling == "gap":
+             ndim = len(ops.shape(x))
+             x = ops.mean(x, axis=list(range(1, ndim - 1)))  # (1,) or (1,2)
+
+         if intermediate_dim is not None:
+             x = self.intermediate_layer(x)
+
+         x = self.dropout(x)
+         outputs = self.output_dense(x)
+
+         # Skip the parent class functional model.
+         Task.__init__(
+             self,
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.num_classes = num_classes
+         self.pooling = pooling
+         self.intermediate_dim = intermediate_dim
+         self.activation = activation
+         self.dropout = dropout
+
+     def get_config(self):
+         # Backbone serialized in `super`.
+         config = super().get_config()
+         config.update(
+             {
+                 "num_classes": self.num_classes,
+                 "pooling": self.pooling,
+                 "intermediate_dim": self.intermediate_dim,
+                 "activation": self.activation,
+                 "dropout": self.dropout,
+             }
+         )
+         return config
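A small standalone sketch (not part of the diff) of the two pooling branches used in the functional model above, run on dummy backbone features. The sequence length and the class-token-at-index-0 layout are assumptions carried over from the code above.

```python
import numpy as np
from keras import ops

# Dummy backbone output: (batch, sequence, hidden_dim), with the class token
# assumed to sit at sequence index 0, as in the functional model above.
features = ops.convert_to_tensor(np.random.rand(2, 197, 64).astype("float32"))

# pooling="token": keep only the class token -> shape (2, 64).
token_pooled = features[:, 0]

# pooling="gap": average over all non-batch, non-channel axes -> shape (2, 64).
ndim = len(ops.shape(features))
gap_pooled = ops.mean(features, axis=list(range(1, ndim - 1)))

print(ops.shape(token_pooled), ops.shape(gap_pooled))
```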
keras_hub/src/models/vit/vit_image_classifier_preprocessor.py
@@ -0,0 +1,12 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_classifier_preprocessor import (
+     ImageClassifierPreprocessor,
+ )
+ from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+ from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
+
+
+ @keras_hub_export("keras_hub.models.ViTImageClassifierPreprocessor")
+ class ViTImageClassifierPreprocessor(ImageClassifierPreprocessor):
+     backbone_cls = ViTBackbone
+     image_converter_cls = ViTImageConverter
keras_hub/src/models/vit/vit_image_converter.py
@@ -0,0 +1,73 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+ from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+ @keras_hub_export("keras_hub.layers.ViTImageConverter")
+ class ViTImageConverter(ImageConverter):
+     """Converts images to the format expected by a ViT model.
+
+     This layer performs image normalization using mean and standard deviation
+     values. By default, it uses the same normalization as the
+     "google/vit-large-patch16-224" model on Hugging Face:
+     `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
+     ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
+     These defaults are suitable for models pretrained using this normalization.
+
+     Args:
+         norm_mean: list or tuple of floats. Mean values for image normalization.
+             Defaults to `[0.5, 0.5, 0.5]`.
+         norm_std: list or tuple of floats. Standard deviation values for
+             image normalization. Defaults to `[0.5, 0.5, 0.5]`.
+         **kwargs: Additional keyword arguments passed to
+             `keras_hub.layers.preprocessing.ImageConverter`.
+
+     Examples:
+     ```python
+     import numpy as np
+     from keras_hub.layers import ViTImageConverter
+
+     # Example image (replace with your actual image data).
+     image = np.random.rand(1, 224, 224, 3)  # Example: (B, H, W, C)
+
+     # Create a ViTImageConverter instance.
+     converter = ViTImageConverter(
+         image_size=(28, 28),
+         scale=1 / 255.0,
+     )
+     # Preprocess the image.
+     preprocessed_image = converter(image)
+     ```
+     """
+
+     backbone_cls = ViTBackbone
+
+     def __init__(
+         self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.norm_mean = norm_mean
+         self.norm_std = norm_std
+
+     @preprocessing_function
+     def call(self, inputs):
+         x = super().call(inputs)
+         # Normalize using the configured mean and std values (0.5 by default).
+         if self.norm_mean:
+             x = x - self._expand_non_channel_dims(self.norm_mean, x)
+         if self.norm_std:
+             x = x / self._expand_non_channel_dims(self.norm_std, x)
+
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "norm_mean": self.norm_mean,
+                 "norm_std": self.norm_std,
+             }
+         )
+         return config
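For reference (not part of the diff), the normalization in `call()` above reduces to simple elementwise arithmetic. A sketch in plain NumPy, assuming the converter is configured with `scale=1/255.` (applied by the base `ImageConverter`) and the default mean/std of 0.5:

```python
import numpy as np

# Mirror of the converter's math under the stated assumptions.
image = np.random.randint(0, 256, size=(1, 224, 224, 3)).astype("float32")
x = image / 255.0      # what scale=1/255. does in the base ImageConverter
x = (x - 0.5) / 0.5    # mean/std normalization added by ViTImageConverter
# Pixel values now lie in [-1, 1].
print(x.min(), x.max())
```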