keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/vit/vit_image_classifier.py
@@ -0,0 +1,187 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_classifier import ImageClassifier
+ from keras_hub.src.models.task import Task
+ from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+ from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
+     ViTImageClassifierPreprocessor,
+ )
+
+
+ @keras_hub_export("keras_hub.models.ViTImageClassifier")
+ class ViTImageClassifier(ImageClassifier):
+     """ViT image classification task.
+
+     `ViTImageClassifier` tasks wrap a `keras_hub.models.ViTBackbone` and
+     a `keras_hub.models.Preprocessor` to create a model that can be used for
+     image classification. `ViTImageClassifier` tasks take an additional
+     `num_classes` argument, controlling the number of predicted output classes.
+     To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`,
+     where `x` is an image and `y` is an integer from `[0, num_classes)`.
+
+     Note that unlike `keras_hub.models.ImageClassifier`, `ViTImageClassifier`
+     plucks out the `cls_token`, the first token in the sequence returned by
+     the backbone.
+
+     Args:
+         backbone: A `keras_hub.models.ViTBackbone` instance or a `keras.Model`.
+         num_classes: int. The number of classes to predict.
+         preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+             a `keras.Layer` instance, or a callable. If `None` no preprocessing
+             will be applied to the inputs.
+         pooling: String specifying the classification strategy. The choice
+             impacts the dimensionality and nature of the feature vector used
+             for classification.
+             `"token"`: A single vector (class token) representing the
+                 overall image features.
+             `"gap"`: A single vector representing the average features
+                 across the spatial dimensions.
+         intermediate_dim: Optional dimensionality of the intermediate
+             representation layer before the final classification layer.
+             If `None`, the output of the transformer is directly used.
+             Defaults to `None`.
+         activation: `None`, str, or callable. The activation function to use
+             on the `Dense` layer. Set `activation=None` to return the output
+             logits. Defaults to `None`.
+         dropout: float. The dropout rate applied before the classification
+             head. Defaults to `0.0`.
+         head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
+             dtype to use for the classification head's computations and
+             weights.
+
+     Examples:
+
+     Call `predict()` to run inference.
+     ```python
+     # Load preset and run inference.
+     images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+     classifier = keras_hub.models.ViTImageClassifier.from_preset(
+         "vit_base_patch16_224"
+     )
+     classifier.predict(images)
+     ```
+
+     Call `fit()` on a single batch.
+     ```python
+     # Load preset and train.
+     images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+     labels = [0, 3]
+     classifier = keras_hub.models.ViTImageClassifier.from_preset(
+         "vit_base_patch16_224"
+     )
+     classifier.fit(x=images, y=labels, batch_size=2)
+     ```
+
+     Call `fit()` with custom loss, optimizer and backbone.
+     ```python
+     classifier = keras_hub.models.ViTImageClassifier.from_preset(
+         "vit_base_patch16_224"
+     )
+     classifier.compile(
+         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+         optimizer=keras.optimizers.Adam(5e-5),
+     )
+     classifier.backbone.trainable = False
+     classifier.fit(x=images, y=labels, batch_size=2)
+     ```
+
+     Custom backbone.
+     ```python
+     images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+     labels = [0, 3]
+     backbone = keras_hub.models.ViTBackbone(
+         image_shape=(224, 224, 3),
+         patch_size=16,
+         num_layers=6,
+         num_heads=3,
+         hidden_dim=768,
+         mlp_dim=2048,
+     )
+     classifier = keras_hub.models.ViTImageClassifier(
+         backbone=backbone,
+         num_classes=4,
+     )
+     classifier.fit(x=images, y=labels, batch_size=2)
+     ```
+     """
+
+     backbone_cls = ViTBackbone
+     preprocessor_cls = ViTImageClassifierPreprocessor
+
+     def __init__(
+         self,
+         backbone,
+         num_classes,
+         preprocessor=None,
+         pooling="token",
+         intermediate_dim=None,
+         activation=None,
+         dropout=0.0,
+         head_dtype=None,
+         **kwargs,
+     ):
+         head_dtype = head_dtype or backbone.dtype_policy
+
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         if intermediate_dim is not None:
+             self.intermediate_layer = keras.layers.Dense(
+                 intermediate_dim, activation="tanh", name="pre_logits"
+             )
+
+         self.dropout = keras.layers.Dropout(
+             rate=dropout,
+             dtype=head_dtype,
+             name="output_dropout",
+         )
+         self.output_dense = keras.layers.Dense(
+             num_classes,
+             activation=activation,
+             dtype=head_dtype,
+             name="predictions",
+         )
+
+         # === Functional Model ===
+         inputs = self.backbone.input
+         x = self.backbone(inputs)
+         if pooling == "token":
+             x = x[:, 0]
+         elif pooling == "gap":
+             ndim = len(ops.shape(x))
+             x = ops.mean(x, axis=list(range(1, ndim - 1)))  # (1,) or (1,2)
+
+         if intermediate_dim is not None:
+             x = self.intermediate_layer(x)
+
+         x = self.dropout(x)
+         outputs = self.output_dense(x)
+
+         # Skip the parent class functional model.
+         Task.__init__(
+             self,
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.num_classes = num_classes
+         self.pooling = pooling
+         self.intermediate_dim = intermediate_dim
+         self.activation = activation
+         self.dropout = dropout
+
+     def get_config(self):
+         # Backbone serialized in `super`
+         config = super().get_config()
+         config.update(
+             {
+                 "num_classes": self.num_classes,
+                 "pooling": self.pooling,
+                 "intermediate_dim": self.intermediate_dim,
+                 "activation": self.activation,
+                 "dropout": self.dropout,
+             }
+         )
+         return config
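
The only ViT-specific logic in this task is the pooling branch: `"token"` slices out the class-token embedding, while `"gap"` averages over every axis between batch and channels. A minimal sketch of the two reductions with plain `keras.ops`, using a random stand-in for the backbone output (the `(2, 197, 768)` shape is assumed for illustration, not read from a preset):

```python
import numpy as np
from keras import ops

# Stand-in for the backbone's sequence output: (batch, num_positions,
# hidden_dim), where position 0 is the class token.
features = ops.convert_to_tensor(
    np.random.rand(2, 197, 768).astype("float32")
)

# pooling="token": keep only the class-token embedding.
token_pooled = features[:, 0]  # shape (2, 768)

# pooling="gap": average over every axis between batch and channels.
ndim = len(ops.shape(features))
gap_pooled = ops.mean(features, axis=list(range(1, ndim - 1)))  # (2, 768)
```

Both reductions produce a `(batch, hidden_dim)` feature vector; `"token"` is the default because the original ViT classification head was trained on the class token.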
keras_hub/src/models/vit/vit_image_classifier_preprocessor.py
@@ -0,0 +1,12 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_classifier_preprocessor import (
+     ImageClassifierPreprocessor,
+ )
+ from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+ from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
+
+
+ @keras_hub_export("keras_hub.models.ViTImageClassifierPreprocessor")
+ class ViTImageClassifierPreprocessor(ImageClassifierPreprocessor):
+     backbone_cls = ViTBackbone
+     image_converter_cls = ViTImageConverter
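
The preprocessor file adds no logic of its own; it only binds `ViTBackbone` to `ViTImageConverter` so that `from_preset` can wire the two together. A hedged usage sketch, assuming a keras-hub build that ships these classes and the `vit_base_patch16_224` preset named in the docstrings above:

```python
import numpy as np
import keras_hub

# "vit_base_patch16_224" is reused from the docstring examples above;
# substitute whichever ViT preset your keras-hub build actually ships.
preprocessor = keras_hub.models.ViTImageClassifierPreprocessor.from_preset(
    "vit_base_patch16_224"
)
images = np.random.randint(0, 256, size=(2, 224, 224, 3))
x = preprocessor(images)  # resized, rescaled, and normalized batch
```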
keras_hub/src/models/vit/vit_image_converter.py
@@ -0,0 +1,73 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+ from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+ @keras_hub_export("keras_hub.layers.ViTImageConverter")
+ class ViTImageConverter(ImageConverter):
+     """Converts images to the format expected by a ViT model.
+
+     This layer performs image normalization using mean and standard deviation
+     values. By default, it uses the same normalization as the
+     "google/vit-large-patch16-224" model on Hugging Face:
+     `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
+     ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
+     These defaults are suitable for models pretrained using this normalization.
+
+     Args:
+         norm_mean: list or tuple of floats. Mean values for image
+             normalization. Defaults to `[0.5, 0.5, 0.5]`.
+         norm_std: list or tuple of floats. Standard deviation values for
+             image normalization. Defaults to `[0.5, 0.5, 0.5]`.
+         **kwargs: Additional keyword arguments passed to
+             `keras_hub.layers.ImageConverter`.
+
+     Examples:
+     ```python
+     import numpy as np
+     from keras_hub.layers import ViTImageConverter
+
+     # Example image (replace with your actual image data)
+     image = np.random.rand(1, 224, 224, 3)  # Example: (B, H, W, C)
+
+     # Create a ViTImageConverter instance
+     converter = ViTImageConverter(
+         image_size=(28, 28),
+         scale=1 / 255.0,
+     )
+     # Preprocess the image
+     preprocessed_image = converter(image)
+     ```
+     """
+
+     backbone_cls = ViTBackbone
+
+     def __init__(
+         self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.norm_mean = norm_mean
+         self.norm_std = norm_std
+
+     @preprocessing_function
+     def call(self, inputs):
+         x = super().call(inputs)
+         # Normalize using the configured mean and std (0.5 each by default).
+         if self.norm_mean:
+             x = x - self._expand_non_channel_dims(self.norm_mean, x)
+         if self.norm_std:
+             x = x / self._expand_non_channel_dims(self.norm_std, x)
+
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "norm_mean": self.norm_mean,
+                 "norm_std": self.norm_std,
+             }
+         )
+         return config
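
Combining the base `ImageConverter`'s rescaling with the `norm_mean`/`norm_std` step above, the end-to-end transform is effectively `x = (x * scale - mean) / std`. A library-independent numpy sketch of that arithmetic with the defaults (`scale=1/255`, mean and std of 0.5), which maps `[0, 255]` pixels to roughly `[-1, 1]`:

```python
import numpy as np

image = np.random.randint(0, 256, size=(224, 224, 3)).astype("float32")

scale = 1 / 255.0                 # rescaling applied by the base converter
mean = np.array([0.5, 0.5, 0.5])  # norm_mean default, broadcast over H and W
std = np.array([0.5, 0.5, 0.5])   # norm_std default

x = image * scale                 # [0, 255] -> [0, 1]
x = (x - mean) / std              # [0, 1] -> [-1, 1]
```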
keras_hub/src/models/vit/vit_layers.py
@@ -0,0 +1,391 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ class MLP(keras.layers.Layer):
+     """Multi-Layer Perceptron (MLP) block.
+
+     Args:
+         hidden_dim: int. Dimensionality of the hidden representations.
+         mlp_dim: int. Dimensionality of the intermediate MLP layer.
+         use_bias: bool. Whether to use bias in the dense layers. Defaults to
+             `True`.
+         dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`.
+         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+     """
+
+     def __init__(
+         self,
+         hidden_dim,
+         mlp_dim,
+         use_bias=True,
+         dropout_rate=0.0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # === Config ===
+         self.hidden_dim = hidden_dim
+         self.mlp_dim = mlp_dim
+         self.use_bias = use_bias
+         self.dropout_rate = dropout_rate
+
+     def build(self, input_shape):
+         self.dense_1 = keras.layers.Dense(
+             units=self.mlp_dim,
+             use_bias=self.use_bias,
+             activation="gelu",
+             bias_initializer=(
+                 keras.initializers.RandomNormal(stddev=1e-6)
+                 if self.use_bias
+                 else None
+             ),
+             dtype=self.dtype_policy,
+             name="dense_1",
+         )
+         self.dense_1.build(input_shape)
+         self.dense_2 = keras.layers.Dense(
+             units=self.hidden_dim,
+             use_bias=self.use_bias,
+             bias_initializer=(
+                 keras.initializers.RandomNormal(stddev=1e-6)
+                 if self.use_bias
+                 else None
+             ),
+             dtype=self.dtype_policy,
+             name="dense_2",
+         )
+         self.dense_2.build((None, None, self.mlp_dim))
+         self.dropout = keras.layers.Dropout(
+             self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+         )
+         self.built = True
+
+     def call(self, inputs):
+         x = self.dense_1(inputs)
+         x = self.dense_2(x)
+         out = self.dropout(x)
+         return out
+
+
+ class ViTPatchingAndEmbedding(keras.layers.Layer):
+     """Patches the image and embeds the patches.
+
+     Args:
+         image_size: int. Size of the input image (height or width).
+             Assumed to be square.
+         patch_size: int. Size of each image patch.
+         hidden_dim: int. Dimensionality of the patch embeddings.
+         num_channels: int. Number of channels in the input image. Defaults to
+             `3`.
+         data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
+             `None` (which uses `"channels_last"`).
+         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+     """
+
+     def __init__(
+         self,
+         image_size,
+         patch_size,
+         hidden_dim,
+         num_channels=3,
+         data_format=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         num_patches = (image_size // patch_size) ** 2
+         num_positions = num_patches + 1
+
+         # === Config ===
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.hidden_dim = hidden_dim
+         self.num_channels = num_channels
+         self.num_patches = num_patches
+         self.num_positions = num_positions
+         self.data_format = standardize_data_format(data_format)
+
+     def build(self, input_shape):
+         self.class_token = self.add_weight(
+             shape=(
+                 1,
+                 1,
+                 self.hidden_dim,
+             ),
+             initializer="random_normal",
+             dtype=self.variable_dtype,
+             name="class_token",
+         )
+         self.patch_embedding = keras.layers.Conv2D(
+             filters=self.hidden_dim,
+             kernel_size=self.patch_size,
+             strides=self.patch_size,
+             padding="valid",
+             activation=None,
+             dtype=self.dtype_policy,
+             data_format=self.data_format,
+             name="patch_embedding",
+         )
+         self.patch_embedding.build(input_shape)
+         self.position_embedding = keras.layers.Embedding(
+             self.num_positions,
+             self.hidden_dim,
+             dtype=self.dtype_policy,
+             embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02),
+             name="position_embedding",
+         )
+         self.position_embedding.build((1, self.num_positions))
+         self.position_ids = ops.expand_dims(
+             ops.arange(self.num_positions), axis=0
+         )
+         self.built = True
+
+     def call(self, inputs):
+         patch_embeddings = self.patch_embedding(inputs)
+         if self.data_format == "channels_first":
+             patch_embeddings = ops.transpose(
+                 patch_embeddings, axes=(0, 2, 3, 1)
+             )
+         embeddings_shape = ops.shape(patch_embeddings)
+         patch_embeddings = ops.reshape(
+             patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
+         )
+         class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
+         position_embeddings = self.position_embedding(self.position_ids)
+         embeddings = ops.concatenate([class_token, patch_embeddings], axis=1)
+         return ops.add(embeddings, position_embeddings)
+
+     def compute_output_shape(self, input_shape):
+         return (
+             input_shape[0],
+             self.num_positions,
+             self.hidden_dim,
+         )
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "image_size": self.image_size,
+                 "patch_size": self.patch_size,
+                 "hidden_dim": self.hidden_dim,
+                 "num_channels": self.num_channels,
+                 "num_patches": self.num_patches,
+                 "num_positions": self.num_positions,
+             }
+         )
+         return config
+
+
+ class ViTEncoderBlock(keras.layers.Layer):
+     """Transformer encoder block.
+
+     Args:
+         num_heads: int. Number of attention heads.
+         hidden_dim: int. Dimensionality of the hidden representations.
+         mlp_dim: int. Dimensionality of the intermediate MLP layer.
+         use_mha_bias: bool. Whether to use bias in the multi-head attention
+             layer. Defaults to `True`.
+         use_mlp_bias: bool. Whether to use bias in the MLP layer. Defaults to
+             `True`.
+         dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`.
+         attention_dropout: float. Dropout rate for the attention mechanism.
+             Between 0 and 1. Defaults to `0.0`.
+         layer_norm_epsilon: float. Small float value for layer normalization
+             stability. Defaults to `1e-6`.
+         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+     """
+
+     def __init__(
+         self,
+         num_heads,
+         hidden_dim,
+         mlp_dim,
+         use_mha_bias=True,
+         use_mlp_bias=True,
+         dropout_rate=0.0,
+         attention_dropout=0.0,
+         layer_norm_epsilon=1e-6,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         key_dim = hidden_dim // num_heads
+
+         # === Config ===
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.key_dim = key_dim
+         self.mlp_dim = mlp_dim
+         self.use_mha_bias = use_mha_bias
+         self.use_mlp_bias = use_mlp_bias
+         self.dropout_rate = dropout_rate
+         self.attention_dropout = attention_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+
+     def build(self, input_shape):
+         # Attention block
+         self.layer_norm_1 = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             name="ln_1",
+             dtype=self.dtype_policy,
+         )
+         self.layer_norm_1.build(input_shape)
+         self.mha = keras.layers.MultiHeadAttention(
+             num_heads=self.num_heads,
+             key_dim=self.key_dim,
+             use_bias=self.use_mha_bias,
+             dropout=self.attention_dropout,
+             name="mha",
+             dtype=self.dtype_policy,
+         )
+         self.mha.build(input_shape, input_shape)
+         self.dropout = keras.layers.Dropout(
+             self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+         )
+
+         # MLP block
+         self.layer_norm_2 = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             name="ln_2",
+             dtype=self.dtype_policy,
+         )
+         self.layer_norm_2.build((None, None, self.hidden_dim))
+         self.mlp = MLP(
+             hidden_dim=self.hidden_dim,
+             mlp_dim=self.mlp_dim,
+             use_bias=self.use_mlp_bias,
+             name="mlp",
+             dtype=self.dtype_policy,
+         )
+         self.mlp.build((None, None, self.hidden_dim))
+         self.built = True
+
+     def call(self, inputs):
+         x = self.layer_norm_1(inputs)
+         x = self.mha(x, x)
+         x = self.dropout(x)
+         x = x + inputs
+
+         y = self.layer_norm_2(x)
+         y = self.mlp(y)
+
+         return x + y
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "key_dim": self.key_dim,
+                 "mlp_dim": self.mlp_dim,
+                 "use_mha_bias": self.use_mha_bias,
+                 "use_mlp_bias": self.use_mlp_bias,
+                 "dropout_rate": self.dropout_rate,
+                 "attention_dropout": self.attention_dropout,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
+
+
+ class ViTEncoder(keras.layers.Layer):
+     """Vision Transformer (ViT) encoder.
+
+     Args:
+         num_layers: int. Number of Transformer encoder blocks.
+         num_heads: int. Number of attention heads.
+         hidden_dim: int. Dimensionality of the hidden representations.
+         mlp_dim: int. Dimensionality of the intermediate MLP layer.
+         use_mha_bias: bool. Whether to use bias in the multi-head attention
+             layers. Defaults to `True`.
+         use_mlp_bias: bool. Whether to use bias in the MLP layers. Defaults to
+             `True`.
+         dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`.
+         attention_dropout: float. Dropout rate for the attention mechanism.
+             Between 0 and 1. Defaults to `0.0`.
+         layer_norm_epsilon: float. Small float value for layer normalization
+             stability. Defaults to `1e-6`.
+         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+     """
+
+     def __init__(
+         self,
+         num_layers,
+         num_heads,
+         hidden_dim,
+         mlp_dim,
+         use_mha_bias=True,
+         use_mlp_bias=True,
+         dropout_rate=0.0,
+         attention_dropout=0.0,
+         layer_norm_epsilon=1e-6,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # === Config ===
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.mlp_dim = mlp_dim
+         self.use_mha_bias = use_mha_bias
+         self.use_mlp_bias = use_mlp_bias
+         self.dropout_rate = dropout_rate
+         self.attention_dropout = attention_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+
+     def build(self, input_shape):
+         self.encoder_layers = []
+         for i in range(self.num_layers):
+             encoder_block = ViTEncoderBlock(
+                 num_heads=self.num_heads,
+                 hidden_dim=self.hidden_dim,
+                 mlp_dim=self.mlp_dim,
+                 dropout_rate=self.dropout_rate,
+                 use_mha_bias=self.use_mha_bias,
+                 use_mlp_bias=self.use_mlp_bias,
+                 attention_dropout=self.attention_dropout,
+                 layer_norm_epsilon=self.layer_norm_epsilon,
+                 dtype=self.dtype_policy,
+                 name=f"tranformer_block_{i + 1}",
+             )
+             encoder_block.build((None, None, self.hidden_dim))
+             self.encoder_layers.append(encoder_block)
+         self.dropout = keras.layers.Dropout(
+             self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+         )
+         self.layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="ln",
+         )
+         self.layer_norm.build((None, None, self.hidden_dim))
+         self.built = True
+
+     def call(self, inputs):
+         x = self.dropout(inputs)
+         for i in range(self.num_layers):
+             x = self.encoder_layers[i](x)
+         x = self.layer_norm(x)
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_layers": self.num_layers,
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "mlp_dim": self.mlp_dim,
+                 "use_mha_bias": self.use_mha_bias,
+                 "use_mlp_bias": self.use_mlp_bias,
+                 "dropout_rate": self.dropout_rate,
+                 "attention_dropout": self.attention_dropout,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
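
As a sanity check on the shape bookkeeping in `ViTPatchingAndEmbedding`: the patching step is just a strided convolution, so a 224x224 input with `patch_size=16` yields `(224 // 16) ** 2 = 196` patches, plus one class token for 197 positions. A standalone sketch in plain Keras (sizes chosen to match the `vit_base_patch16_224` examples above, not pulled from the library):

```python
import numpy as np
import keras

image_size, patch_size, hidden_dim = 224, 16, 768

# Non-overlapping patches via a strided convolution, as in the layer above.
patchify = keras.layers.Conv2D(
    filters=hidden_dim, kernel_size=patch_size, strides=patch_size
)
images = np.random.rand(2, image_size, image_size, 3).astype("float32")
patches = patchify(images)  # (2, 14, 14, 768)

# Flatten the spatial grid into a sequence; the class token adds one more.
sequence = keras.ops.reshape(patches, (2, -1, hidden_dim))  # (2, 196, 768)
num_positions = (image_size // patch_size) ** 2 + 1         # 197
```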