keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0

keras_hub/src/models/vgg/__init__.py
@@ -0,0 +1,5 @@
+from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
+from keras_hub.src.models.vgg.vgg_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, VGGBackbone)

keras_hub/src/models/vgg/vgg_backbone.py
@@ -20,25 +20,14 @@ class VGGBackbone(Backbone):
         stackwise_num_filters: list of ints, filter size for convolutional
             blocks per VGG block. For both VGG16 and VGG19 this is [
            64, 128, 256, 512, 512].
-        image_shape: tuple, optional shape tuple, defaults to (224, 224, 3).
-        pooling: bool, Optional pooling mode for feature extraction
-            when `include_top` is `False`.
-            - `None` means that the output of the model will be
-                the 4D tensor output of the
-                last convolutional block.
-            - `avg` means that global average pooling
-                will be applied to the output of the
-                last convolutional block, and thus
-                the output of the model will be a 2D tensor.
-            - `max` means that global max pooling will
-                be applied.
+        image_shape: tuple, optional shape tuple, defaults to (None, None, 3).
 
     Examples:
     ```python
     input_data = np.ones((2, 224, 224, 3), dtype="float32")
 
     # Pretrained VGG backbone.
-    model = keras_hub.models.VGGBackbone.from_preset("vgg16")
+    model = keras_hub.models.VGGBackbone.from_preset("vgg_16_imagenet")
     model(input_data)
 
     # Randomly initialized VGG backbone with a custom config.
@@ -46,7 +35,6 @@ class VGGBackbone(Backbone):
         stackwise_num_repeats = [2, 2, 3, 3, 3],
         stackwise_num_filters = [64, 128, 256, 512, 512],
         image_shape = (224, 224, 3),
-        pooling = "avg",
     )
     model(input_data)
     ```
@@ -56,16 +44,14 @@ class VGGBackbone(Backbone):
         self,
         stackwise_num_repeats,
         stackwise_num_filters,
-        image_shape=(224, 224, 3),
-        pooling="avg",
+        image_shape=(None, None, 3),
         **kwargs,
     ):
-
         # === Functional Model ===
         img_input = keras.layers.Input(shape=image_shape)
         x = img_input
 
-        for stack_index in range(len(stackwise_num_repeats) - 1):
+        for stack_index in range(len(stackwise_num_repeats)):
             x = apply_vgg_block(
                 x=x,
                 num_layers=stackwise_num_repeats[stack_index],
@@ -76,10 +62,6 @@ class VGGBackbone(Backbone):
                 max_pool=True,
                 name=f"block{stack_index + 1}",
             )
-        if pooling == "avg":
-            x = layers.GlobalAveragePooling2D()(x)
-        elif pooling == "max":
-            x = layers.GlobalMaxPooling2D()(x)
 
         super().__init__(inputs=img_input, outputs=x, **kwargs)
 
@@ -87,14 +69,12 @@ class VGGBackbone(Backbone):
         self.stackwise_num_repeats = stackwise_num_repeats
         self.stackwise_num_filters = stackwise_num_filters
         self.image_shape = image_shape
-        self.pooling = pooling
 
     def get_config(self):
         return {
             "stackwise_num_repeats": self.stackwise_num_repeats,
             "stackwise_num_filters": self.stackwise_num_filters,
             "image_shape": self.image_shape,
-            "pooling": self.pooling,
         }
 

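Taken together, these changes drop the built-in `pooling` option from the backbone (pooling now lives in the classifier head), default `image_shape` to `(None, None, 3)`, and fix the block loop so all five VGG blocks run. A minimal sketch of the new constructor, assuming the `keras_hub` build from this wheel and `channels_last` data:

```python
import numpy as np
import keras_hub

# The new default image_shape=(None, None, 3) accepts variable
# input resolutions; no pooling argument exists anymore.
backbone = keras_hub.models.VGGBackbone(
    stackwise_num_repeats=[2, 2, 3, 3, 3],
    stackwise_num_filters=[64, 128, 256, 512, 512],
)

# Five max-pooled blocks downsample 224x224 by 2^5 = 32, so the output
# is the raw 4D feature map of the last convolutional block.
features = backbone(np.ones((2, 224, 224, 3), dtype="float32"))
print(features.shape)  # Expected: (2, 7, 7, 512)
```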

keras_hub/src/models/vgg/vgg_image_classifier.py
@@ -2,58 +2,93 @@ import keras
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.models.task import Task
 from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
+from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
+    VGGImageClassifierPreprocessor,
+)
 
 
 @keras_hub_export("keras_hub.models.VGGImageClassifier")
 class VGGImageClassifier(ImageClassifier):
-    """VGG16 image classifier task model.
+    """VGG image classification task.
 
-    Args:
-        backbone: A `keras_hub.models.VGGBackbone` instance.
-        num_classes: int, number of classes to predict.
-        pooling: str, type of pooling layer. Must be one of "avg", "max".
-        activation: Optional `str` or callable, defaults to "softmax". The
-            activation function to use on the Dense layer. Set `activation=None`
-            to return the output logits.
+    `VGGImageClassifier` tasks wrap a `keras_hub.models.VGGBackbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    image classification. `VGGImageClassifier` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
 
     To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
     labels where `x` is a string and `y` is a integer from `[0, num_classes)`.
-    All `ImageClassifier` tasks include a `from_preset()` constructor which can be
-    used to load a pre-trained config and weights.
+
+    Note that unlike `keras_hub.models.ImageClassifier`, `VGGImageClassifier`
+    allows and defaults to `pooling="flatten"`, where inputs are flattened and
+    passed through two intermediate dense layers before the final output
+    projection.
+
+    Args:
+        backbone: A `keras_hub.models.VGGBackbone` instance or a `keras.Model`.
+        num_classes: int. The number of classes to predict.
+        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+            a `keras.Layer` instance, or a callable. If `None` no preprocessing
+            will be applied to the inputs.
+        pooling: `"flatten"`, `"avg"`, or `"max"`. The type of pooling to apply
+            on backbone output. The default is flatten to match the original
+            VGG implementation, where backbone inputs will be flattened and
+            passed through two dense layers with a `"relu"` activation.
+        pooling_hidden_dim: the output feature size of the pooling dense layers.
+            This only applies when `pooling="flatten"`.
+        activation: `None`, str, or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `"softmax"`.
+        head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
+            dtype to use for the classification head's computations and weights.
+
 
     Examples:
-    Train from preset
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and train
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    classifier = keras_hub.models.VGGImageClassifier.from_preset(
+        "vgg_16_imagenet"
+    )
+    classifier.predict(images)
+    ```
+
+    Call `fit()` on a single batch.
     ```python
     # Load preset and train
-    images = np.ones((2, 224, 224, 3), dtype="float32")
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
     labels = [0, 3]
     classifier = keras_hub.models.VGGImageClassifier.from_preset(
-        'vgg_16_image_classifier')
+        "vgg_16_imagenet"
+    )
     classifier.fit(x=images, y=labels, batch_size=2)
+    ```
 
-    # Re-compile (e.g., with a new learning rate).
+    Call `fit()` with custom loss, optimizer and backbone.
+    ```python
+    classifier = keras_hub.models.VGGImageClassifier.from_preset(
+        "vgg_16_imagenet"
+    )
     classifier.compile(
         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
         optimizer=keras.optimizers.Adam(5e-5),
-        jit_compile=True,
     )
-
-    # Access backbone programmatically (e.g., to change `trainable`).
     classifier.backbone.trainable = False
-    # Fit again.
     classifier.fit(x=images, y=labels, batch_size=2)
     ```
-    Custom backbone
+
+    Custom backbone.
     ```python
-    images = np.ones((2, 224, 224, 3), dtype="float32")
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
     labels = [0, 3]
-
-    backbone = keras_hub.models.VGGBackbone(
+    model = keras_hub.models.VGGBackbone(
         stackwise_num_repeats = [2, 2, 3, 3, 3],
         stackwise_num_filters = [64, 128, 256, 512, 512],
         image_shape = (224, 224, 3),
-        pooling = "avg",
     )
     classifier = keras_hub.models.VGGImageClassifier(
         backbone=backbone,
@@ -64,31 +99,95 @@ class VGGImageClassifier(ImageClassifier):
     """
 
     backbone_cls = VGGBackbone
+    preprocessor_cls = VGGImageClassifierPreprocessor
 
     def __init__(
         self,
         backbone,
         num_classes,
-        activation="softmax",
-        preprocessor=None,  # adding this dummy arg for saved model test
-        # TODO: once preprocessor flow is figured out, this needs to be updated
+        preprocessor=None,
+        pooling="avg",
+        pooling_hidden_dim=4096,
+        activation=None,
+        dropout=0.0,
+        head_dtype=None,
         **kwargs,
     ):
+        head_dtype = head_dtype or backbone.dtype_policy
+        data_format = getattr(backbone, "data_format", None)
+
         # === Layers ===
         self.backbone = backbone
-        self.output_dense = keras.layers.Dense(
-            num_classes,
-            activation=activation,
-            name="predictions",
+        self.preprocessor = preprocessor
+        if pooling == "avg":
+            self.pooler = keras.layers.GlobalAveragePooling2D(
+                data_format,
+                dtype=head_dtype,
+                name="pooler",
+            )
+        elif pooling == "max":
+            self.pooler = keras.layers.GlobalMaxPooling2D(
+                data_format,
+                dtype=head_dtype,
+                name="pooler",
+            )
+        elif pooling == "flatten":
+            self.pooler = keras.Sequential(
+                [
+                    keras.layers.Flatten(name="flatten"),
+                    keras.layers.Dense(pooling_hidden_dim, activation="relu"),
+                    keras.layers.Dense(pooling_hidden_dim, activation="relu"),
+                ],
+                name="pooler",
+            )
+        else:
+            raise ValueError(
+                "Unknown `pooling` type. Pooling should be either `'avg'` or "
+                f"`'max'`. Received: pooling={pooling}."
+            )
+
+        self.head = keras.Sequential(
+            [
+                keras.layers.Conv2D(
+                    filters=4096,
+                    kernel_size=7,
+                    name="fc1",
+                    activation=activation,
+                    use_bias=True,
+                    padding="same",
+                ),
+                keras.layers.Dropout(
+                    rate=dropout,
+                    dtype=head_dtype,
+                    name="output_dropout",
+                ),
+                keras.layers.Conv2D(
+                    filters=4096,
+                    kernel_size=1,
+                    name="fc2",
+                    activation=activation,
+                    use_bias=True,
+                    padding="same",
+                ),
+                self.pooler,
+                keras.layers.Dense(
+                    num_classes,
+                    activation=activation,
+                    dtype=head_dtype,
+                    name="predictions",
+                ),
+            ],
+            name="head",
         )
 
         # === Functional Model ===
         inputs = self.backbone.input
         x = self.backbone(inputs)
-        outputs = self.output_dense(x)
+        outputs = self.head(x)
 
-        # Instantiate using Functional API Model constructor
-        super().__init__(
+        # Skip the parent class functional model.
+        Task.__init__(
+            self,
             inputs=inputs,
             outputs=outputs,
             **kwargs,
@@ -97,6 +196,10 @@ class VGGImageClassifier(ImageClassifier):
         # === Config ===
         self.num_classes = num_classes
         self.activation = activation
+        self.pooling = pooling
+        self.pooling_hidden_dim = pooling_hidden_dim
+        self.dropout = dropout
+        self.preprocessor = preprocessor
 
     def get_config(self):
         # Backbone serialized in `super`
@@ -104,7 +207,10 @@ class VGGImageClassifier(ImageClassifier):
         config.update(
             {
                 "num_classes": self.num_classes,
+                "pooling": self.pooling,
                 "activation": self.activation,
+                "pooling_hidden_dim": self.pooling_hidden_dim,
+                "dropout": self.dropout,
             }
        )
        return config
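
The rewritten task now owns pooling and a convolutional two-stage head. A usage sketch, assuming a locally built backbone as in the docstring above; `pooling` and `pooling_hidden_dim` are passed explicitly rather than relying on defaults:

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.VGGBackbone(
    stackwise_num_repeats=[2, 2, 3, 3, 3],
    stackwise_num_filters=[64, 128, 256, 512, 512],
    image_shape=(224, 224, 3),
)
# pooling="flatten" mirrors the original VGG head: features are flattened
# and passed through two `pooling_hidden_dim`-wide relu dense layers
# before the final `num_classes` projection. With activation=None (the
# new default) the model returns logits.
classifier = keras_hub.models.VGGImageClassifier(
    backbone=backbone,
    num_classes=10,
    pooling="flatten",
    pooling_hidden_dim=4096,
)
logits = classifier.predict(np.random.randint(0, 256, size=(2, 224, 224, 3)))
print(logits.shape)  # Expected: (2, 10)
```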

keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py
@@ -0,0 +1,12 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier_preprocessor import (
+    ImageClassifierPreprocessor,
+)
+from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
+from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
+
+
+@keras_hub_export("keras_hub.models.VGGImageClassifierPreprocessor")
+class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor):
+    backbone_cls = VGGBackbone
+    image_converter_cls = VGGImageConverter

keras_hub/src/models/vgg/vgg_image_converter.py
@@ -0,0 +1,8 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
+
+
+@keras_hub_export("keras_hub.layers.VGGImageConverter")
+class VGGImageConverter(ImageConverter):
+    backbone_cls = VGGBackbone

keras_hub/src/models/vgg/vgg_presets.py
@@ -0,0 +1,48 @@
+"""vgg preset configurations."""
+
+backbone_presets = {
+    "vgg_11_imagenet": {
+        "metadata": {
+            "description": (
+                "11-layer vgg model pre-trained on the ImageNet 1k dataset "
+                "at a 224x224 resolution."
+            ),
+            "params": 9220480,
+            "path": "vgg",
+        },
+        "kaggle_handle": "kaggle://keras/vgg/keras/vgg_11_imagenet/2",
+    },
+    "vgg_13_imagenet": {
+        "metadata": {
+            "description": (
+                "13-layer vgg model pre-trained on the ImageNet 1k dataset "
+                "at a 224x224 resolution."
+            ),
+            "params": 9404992,
+            "path": "vgg",
+        },
+        "kaggle_handle": "kaggle://keras/vgg/keras/vgg_13_imagenet/2",
+    },
+    "vgg_16_imagenet": {
+        "metadata": {
+            "description": (
+                "16-layer vgg model pre-trained on the ImageNet 1k dataset "
+                "at a 224x224 resolution."
+            ),
+            "params": 14714688,
+            "path": "vgg",
+        },
+        "kaggle_handle": "kaggle://keras/vgg/keras/vgg_16_imagenet/2",
+    },
+    "vgg_19_imagenet": {
+        "metadata": {
+            "description": (
+                "19-layer vgg model pre-trained on the ImageNet 1k dataset "
+                "at a 224x224 resolution."
+            ),
+            "params": 20024384,
+            "path": "vgg",
+        },
+        "kaggle_handle": "kaggle://keras/vgg/keras/vgg_19_imagenet/2",
+    },
+}
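
Each entry maps a human-readable preset name to a Kaggle model handle; once registered via `vgg/__init__.py`, the names resolve through `from_preset()`. A brief sketch (the `num_classes` override at load time is standard task `from_preset` usage, not something specific to this diff):

```python
import keras_hub

# Loads config and weights from the Kaggle handle behind the preset name.
backbone = keras_hub.models.VGGBackbone.from_preset("vgg_16_imagenet")

# Tasks accept head arguments at load time, e.g. to re-head the
# pretrained backbone for a two-class downstream dataset.
classifier = keras_hub.models.VGGImageClassifier.from_preset(
    "vgg_16_imagenet",
    num_classes=2,
)
```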

keras_hub/src/models/vit/__init__.py
@@ -0,0 +1,5 @@
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.models.vit.vit_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, ViTBackbone)

keras_hub/src/models/vit/vit_backbone.py
@@ -0,0 +1,152 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.vit.vit_layers import ViTEncoder
+from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+@keras_hub_export("keras_hub.models.ViTBackbone")
+class ViTBackbone(Backbone):
+    """Vision Transformer (ViT) backbone.
+
+    This backbone implements the Vision Transformer architecture as described in
+    [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).
+    It transforms the input image into a sequence of patches, embeds them, and
+    then processes them through a series of Transformer encoder layers.
+
+    Args:
+        image_shape: A tuple or list of 3 integers representing the shape of the
+            input image `(height, width, channels)`; `height` and `width` must
+            be equal.
+        patch_size: int. The size of each image patch; the input image will be
+            divided into patches of shape `(patch_size, patch_size)`.
+        num_layers: int. The number of transformer encoder layers.
+        num_heads: int. The number of attention heads in each
+            Transformer encoder layer.
+        hidden_dim: int. The dimensionality of the hidden representations.
+        mlp_dim: int. The dimensionality of the intermediate MLP layer in
+            each Transformer encoder layer.
+        dropout_rate: float. The dropout rate for the Transformer encoder
+            layers.
+        attention_dropout: float. The dropout rate for the attention mechanism
+            in each Transformer encoder layer.
+        layer_norm_epsilon: float. Value used for numerical stability in
+            layer normalization.
+        use_mha_bias: bool. Whether to use bias in the multi-head
+            attention layers.
+        use_mlp_bias: bool. Whether to use bias in the MLP layers.
+        data_format: str. `"channels_last"` or `"channels_first"`, specifying
+            the data format for the input image. If `None`, defaults to
+            `"channels_last"`.
+        dtype: The dtype of the layer weights. Defaults to None.
+        **kwargs: Additional keyword arguments to be passed to the parent
+            `Backbone` class.
+    """
+
+    def __init__(
+        self,
+        image_shape,
+        patch_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        mlp_dim,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        use_mha_bias=True,
+        use_mlp_bias=True,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        data_format = standardize_data_format(data_format)
+        h_axis, w_axis, channels_axis = (
+            (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
+        )
+        # Check that the input image is well specified.
+        if image_shape[h_axis] is None or image_shape[w_axis] is None:
+            raise ValueError(
+                f"Image shape must have defined height and width. Found `None` "
+                f"at index {h_axis} (height) or {w_axis} (width). "
+                f"Image shape: {image_shape}"
+            )
+        if image_shape[h_axis] != image_shape[w_axis]:
+            raise ValueError(
+                f"Image height and width must be equal. Found height: "
+                f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
+                f"indices {h_axis} and {w_axis} respectively. Image shape: "
+                f"{image_shape}"
+            )
+
+        num_channels = image_shape[channels_axis]
+
+        # === Functional Model ===
+        inputs = keras.layers.Input(shape=image_shape)
+
+        x = ViTPatchingAndEmbedding(
+            image_size=image_shape[h_axis],
+            patch_size=patch_size,
+            hidden_dim=hidden_dim,
+            num_channels=num_channels,
+            data_format=data_format,
+            dtype=dtype,
+            name="vit_patching_and_embedding",
+        )(inputs)
+
+        output = ViTEncoder(
+            num_layers=num_layers,
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            mlp_dim=mlp_dim,
+            dropout_rate=dropout_rate,
+            attention_dropout=attention_dropout,
+            layer_norm_epsilon=layer_norm_epsilon,
+            use_mha_bias=use_mha_bias,
+            use_mlp_bias=use_mlp_bias,
+            dtype=dtype,
+            name="vit_encoder",
+        )(x)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.image_shape = image_shape
+        self.patch_size = patch_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.mlp_dim = mlp_dim
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.use_mha_bias = use_mha_bias
+        self.use_mlp_bias = use_mlp_bias
+        self.data_format = data_format
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_shape": self.image_shape,
+                "patch_size": self.patch_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "mlp_dim": self.mlp_dim,
+                "dropout_rate": self.dropout_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "use_mha_bias": self.use_mha_bias,
+                "use_mlp_bias": self.use_mlp_bias,
+            }
+        )
+        return config
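
A small instantiation sketch of the new backbone; the configuration below is ViT-B/16-like, and the output shape comment assumes `ViTPatchingAndEmbedding` prepends a class token, which is not shown in this hunk:

```python
import numpy as np
import keras_hub

# Height and width must be equal, per the constructor checks above.
backbone = keras_hub.models.ViTBackbone(
    image_shape=(224, 224, 3),
    patch_size=16,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
)

# (224 / 16)^2 = 196 patches; with a prepended class token the sequence
# length would be 197, each position a hidden_dim=768 vector.
outputs = backbone(np.ones((1, 224, 224, 3), dtype="float32"))
print(outputs.shape)  # Expected: (1, 197, 768), assuming a class token
```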