keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/mit/mit_presets.py
@@ -0,0 +1,139 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MiT model preset configurations."""
+
+backbone_presets_with_weights = {
+    "mit_b0_ade20k_512": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 8 transformer blocks."
+            ),
+            "params": 3321962,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_ade20k_512/4",
+    },
+    "mit_b1_ade20k_512": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 8 transformer blocks."
+            ),
+            "params": 13156554,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_ade20k_512/4",
+    },
+    "mit_b2_ade20k_512": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 16 transformer blocks."
+            ),
+            "params": 24201418,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_ade20k_512/4",
+    },
+    "mit_b3_ade20k_512": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 28 transformer blocks."
+            ),
+            "params": 44077258,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_ade20k_512/3",
+    },
+    "mit_b4_ade20k_512": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 41 transformer blocks."
+            ),
+            "params": 60847818,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_ade20k_512/3",
+    },
+    "mit_b5_ade20k_640": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 52 transformer blocks."
+            ),
+            "params": 81448138,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_640/3",
+    },
+    "mit_b0_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 8 transformer blocks."
+            ),
+            "params": 3321962,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_cityscapes_1024/3",
+    },
+    "mit_b1_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 8 transformer blocks."
+            ),
+            "params": 13156554,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_cityscapes_1024/3",
+    },
+    "mit_b2_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 16 transformer blocks."
+            ),
+            "params": 24201418,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_cityscapes_1024/3",
+    },
+    "mit_b3_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 28 transformer blocks."
+            ),
+            "params": 44077258,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_cityscapes_1024/3",
+    },
+    "mit_b4_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 41 transformer blocks."
+            ),
+            "params": 60847818,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_cityscapes_1024/3",
+    },
+    "mit_b5_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 52 transformer blocks."
+            ),
+            "params": 81448138,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_cityscapes_1024/3",
+    },
+}
+
+backbone_presets = {
+    **backbone_presets_with_weights,
+}
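Note: each preset name above maps to a `kaggle_handle`, which `from_preset()` resolves when downloading config and weights. A minimal usage sketch, assuming the standard keras_hub preset flow (the preset name is taken from the dict above; everything else is illustrative):

```python
import numpy as np
import keras_hub

# "mit_b0_ade20k_512" resolves to the kaggle_handle in the preset dict.
backbone = keras_hub.models.Backbone.from_preset("mit_b0_ade20k_512")

# Forward pass on a dummy batch at the preset's training resolution.
images = np.ones((2, 512, 512, 3), dtype="float32")
features = backbone(images)
```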
keras_hub/src/models/mobilenet/mobilenet_backbone.py
@@ -47,11 +47,11 @@ class MobileNetBackbone(Backbone):
             of filters in each layer.
             - If `depth_multiplier` > 1.0, proportionally increases the number
             of filters in each layer.
-            - If `depth_multiplier` = 1, default number of filters from the paper
-            are used at each layer.
+            - If `depth_multiplier` = 1, default number of filters from the
+            paper are used at each layer.
         input_num_filters: number of filters in first convolution layer
-        output_num_filters: specifies whether to add conv and batch_norm in the end,
-        if set to None, it will not add these layers in the end.
+        output_num_filters: specifies whether to add conv and batch_norm in the
+        end, if set to None, it will not add these layers in the end.
             'None' for MobileNetV1
         input_activation: activation function to be used in the input layer
             'hard_swish' for MobileNetV3,
@@ -96,7 +96,7 @@ class MobileNetBackbone(Backbone):
         stackwise_activation,
         output_num_filters,
         inverted_res_block,
-        image_shape=(224, 224, 3),
+        image_shape=(None, None, 3),
         input_activation="hard_swish",
         output_activation="hard_swish",
         depth_multiplier=1.0,
@@ -365,7 +365,7 @@ def apply_depthwise_conv_block(
    batch normalization and relu6 activation.

    Args:
-        x: Input tensor of shape `(rows, cols, channels)
+        x: Input tensor of shape `(rows, cols, channels)`
        filters: Integer, the dimensionality of the output space
            (i.e. the number of output filters in the pointwise convolution).
        depth_multiplier: controls the width of the network.
@@ -383,8 +383,8 @@ def apply_depthwise_conv_block(
        block_id: Integer, a unique identification designating the block number.

    Input shape:
-        4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last"
-        4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first"
+        4D tensor with shape `(batch, rows, cols, channels)` in "channels_last"
+        4D tensor with shape `(batch, channels, rows, cols)` in "channels_first"
    Returns:
        Output tensor of block.
    """
keras_hub/src/models/mobilenet/mobilenet_image_classifier.py
@@ -1,5 +1,3 @@
-import keras
-
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
@@ -7,94 +5,4 @@ from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone

 @keras_hub_export("keras_hub.models.MobileNetImageClassifier")
 class MobileNetImageClassifier(ImageClassifier):
-    """MobileNetV3 image classifier task model.
-
-    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
-    where `x` is a tensor and `y` is a integer from `[0, num_classes)`.
-    All `ImageClassifier` tasks include a `from_preset()` constructor which can
-    be used to load a pre-trained config and weights.
-
-    Args:
-        backbone: A `keras_hub.models.MobileNetBackbone` instance.
-        num_classes: int. The number of classes to predict.
-        activation: `None`, str or callable. The activation function to use on
-            the `Dense` layer. Set `activation=None` to return the output
-            logits. Defaults to `"softmax"`.
-
-    Examples:
-
-    Call `predict()` to run inference.
-    ```python
-    # Load preset and train
-    images = np.ones((2, 224, 224, 3), dtype="float32")
-    classifier = keras_hub.models.MobileNetImageClassifier.from_preset(
-        "mobilenet_v3_small_imagenet")
-    classifier.predict(images)
-    ```
-
-    Custom backbone.
-    ```python
-    images = np.ones((2, 224, 224, 3), dtype="float32")
-    labels = [0, 3]
-    model = MobileNetBackbone(
-        stackwise_expansion = [1, 4, 6],
-        stackwise_filters = [4, 8, 16],
-        stackwise_kernel_size = [3, 3, 5],
-        stackwise_stride = [2, 2, 1],
-        stackwise_se_ratio = [ 0.25, None, 0.25],
-        stackwise_activation = ["relu", "relu", "hard_swish"],
-        output_filter=1280,
-        activation="hard_swish",
-        inverted_res_block=True,
-    )
-    classifier = keras_hub.models.MobileNetImageClassifier(
-        backbone=backbone,
-        num_classes=4,
-    )
-    classifier.fit(x=images, y=labels, batch_size=2)
-    ```
-    """
-
     backbone_cls = MobileNetBackbone
-
-    def __init__(
-        self,
-        backbone,
-        num_classes,
-        activation="softmax",
-        preprocessor=None,  # adding this dummy arg for saved model test
-        # TODO: once preprocessor flow is figured out, this needs to be updated
-        **kwargs,
-    ):
-        # === Layers ===
-        self.backbone = backbone
-        self.output_dense = keras.layers.Dense(
-            num_classes,
-            activation=activation,
-            name="predictions",
-        )
-
-        # === Functional Model ===
-        inputs = self.backbone.input
-        x = self.backbone(inputs)
-        outputs = self.output_dense(x)
-        super().__init__(
-            inputs=inputs,
-            outputs=outputs,
-            **kwargs,
-        )
-
-        # === Config ===
-        self.num_classes = num_classes
-        self.activation = activation
-
-    def get_config(self):
-        # Backbone serialized in `super`
-        config = super().get_config()
-        config.update(
-            {
-                "num_classes": self.num_classes,
-                "activation": self.activation,
-            }
-        )
-        return config
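Note: the bespoke `__init__` and `get_config` are removed because the head-building logic now lives in the shared `ImageClassifier` base task (see `keras_hub/src/models/image_classifier.py +147 -2` in the file list); the subclass reduces to `backbone_cls = MobileNetBackbone`. User-facing usage should still look roughly like the removed docstring example (preset name copied from that docstring; whether it ships in this release is not confirmed here):

```python
import numpy as np
import keras_hub

images = np.ones((2, 224, 224, 3), dtype="float32")
classifier = keras_hub.models.MobileNetImageClassifier.from_preset(
    "mobilenet_v3_small_imagenet"
)
classifier.predict(images)
```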
keras_hub/src/models/opt/opt_causal_lm.py
@@ -171,8 +171,8 @@ class OPTCausalLM(CausalLM):
         Args:
             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
             cache: a dense float Tensor, the cache of key and value.
-            cache_update_index: int, or int Tensor. The index of current inputs in the
-                whole sequence.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.

         Returns:
             A (logits, hidden_states, cache) tuple. Where `logits` is the
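To make the reworded `cache_update_index` docstring concrete: the index marks where the current tokens sit in the full sequence during cached decoding. A schematic helper, assuming only the documented `(logits, hidden_states, cache)` return shape (the greedy pick and loop bookkeeping are illustrative, not the library's code):

```python
import numpy as np

def decode_step(call_with_cache, token_ids, cache, cache_update_index):
    """One cached decoding step; shows how cache_update_index advances."""
    logits, hidden_states, cache = call_with_cache(
        token_ids, cache, cache_update_index
    )
    # Greedy pick for illustration.
    next_token = np.argmax(logits[:, -1, :], axis=-1)
    # The next call writes right after the tokens just processed, so the
    # index advances by the number of tokens in this call.
    return next_token, cache, cache_update_index + token_ids.shape[1]
```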
keras_hub/src/models/opt/opt_presets.py
@@ -9,11 +9,9 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 125237760,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/opt/keras/opt_125m_en/2",
+        "kaggle_handle": "kaggle://keras/opt/keras/opt_125m_en/3",
     },
     # We skip the 350m checkpoint because it does not match the structure of
     # other checkpoints.
@@ -24,11 +22,9 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 1315753984,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/opt/keras/opt_1.3b_en/2",
+        "kaggle_handle": "kaggle://keras/opt/keras/opt_1.3b_en/3",
     },
     "opt_2.7b_en": {
         "metadata": {
@@ -37,11 +33,9 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 2700000000,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/opt/keras/opt_2.7b_en/2",
+        "kaggle_handle": "kaggle://keras/opt/keras/opt_2.7b_en/3",
     },
     "opt_6.7b_en": {
         "metadata": {
@@ -50,10 +44,8 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 6700000000,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/opt/keras/opt_6.7b_en/2",
+        "kaggle_handle": "kaggle://keras/opt/keras/opt_6.7b_en/3",
     },
 }
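Note: the substantive change here is the handle bump from `/2` to `/3`, alongside dropping the `official_name` and `model_card` metadata fields (the same trim appears across the other `*_presets.py` files in the list above). Calling code keeps using the preset name; a sketch with the standard keras_hub API:

```python
import keras_hub

# "opt_125m_en" now resolves to kaggle://keras/opt/keras/opt_125m_en/3.
causal_lm = keras_hub.models.OPTCausalLM.from_preset("opt_125m_en")
print(causal_lm.generate("The quick brown fox", max_length=30))
```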
keras_hub/src/models/pali_gemma/pali_gemma_backbone.py
@@ -48,24 +48,40 @@ class PaliGemmaBackbone(Backbone):
             a two-layer feedforward network for each transformer decoder block.
         head_dim: int. The size of each attention head in the mixed decoder.
         vit_patch_size: int. The size of each square patch in the input image.
-        vit_num_heads: int. The number of attention heads for the vision(image)
+        vit_num_heads: int. The number of attention heads for the vision (image)
             transformer encoder.
         vit_hidden_dim: int. The size of the transformer hidden state at the end
             of each vision transformer layer.
         vit_num_layers: int. The number of vision transformer layers.
         vit_intermediate_dim: int. The output dimension of the first Dense layer
-            in a two-layer feedforward network for vision transformer.
-        vit_pooling: string. The encoded vision embeddings are pooled using the
-            specified polling setting. The accepted values are `"map"`, `"gap"`,
-            `"0"` or `"none"`. Defaults to `"none"`.
+            in a two-layer feedforward network for vision transformer. Defaults
+            to `4304`.
+        vit_pooling: `None` or string. The encoded vision embeddings are pooled
+            using the specified polling setting. The accepted values are
+            `"map"`, `"gap"`, `"0"` or `None`. Defaults to `None`.
         vit_classifier_activation: activation function. The activation that
             is used for final output classification in the vision transformer.
+            Defaults to `None`.
         vit_name: string. The name used for vision transformer layers.
-        include_rescaling: bool. If true, the image input will be rescaled from
-            the range `[0, 255]`, to the range `[0, 1]`.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the
+            attention block. Defaults to `False`.
+        attention_logit_soft_cap: `None` or int. Soft cap for the attention
+            logits. Defaults to `None`.
+        final_logit_soft_cap: `None` or int. Soft cap for the final logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
         layer_norm_epsilon: float. The epsilon value user for every layer norm
-            in all transformer blocks.
+            in all transformer blocks. Defaults to `1e-6`.
         dropout: float. Dropout probability for the Transformer decoder blocks.
+            Defaults to `0`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for the models computations and weights. Note that some
             computations, such as softmax and layer normalization will always
@@ -121,7 +137,13 @@ class PaliGemmaBackbone(Backbone):
         vit_pooling=None,
         vit_classifier_activation=None,
         vit_name=None,
-        include_rescaling=True,
+        query_head_dim_normalize=True,
+        use_post_ffw_norm=False,
+        use_post_attention_norm=False,
+        attention_logit_soft_cap=None,
+        final_logit_soft_cap=None,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
@@ -139,13 +161,13 @@ class PaliGemmaBackbone(Backbone):
                 seed=None,
             ),
             dtype=dtype,
+            logit_soft_cap=final_logit_soft_cap,
             name="token_embedding",
         )
         # TODO Remove this. Work around for previous serialization bug.
         vit_intermediate_dim = vit_intermediate_dim or 4304
         self.vit_encoder = PaliGemmaVit(
             image_size=image_size,
-            include_rescaling=include_rescaling,
             patch_size=vit_patch_size,
             num_heads=vit_num_heads,
             hidden_dim=vit_hidden_dim,
@@ -159,12 +181,19 @@ class PaliGemmaBackbone(Backbone):
         )
         self.transformer_layers = []
         for i in range(num_layers):
+            sliding_window = use_sliding_window_attention and (i % 2 == 0)
             layer = PaliGemmaDecoderBlock(
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
-                num_query_heads=num_query_heads,
                 head_dim=head_dim,
+                num_query_heads=num_query_heads,
                 num_key_value_heads=num_key_value_heads,
+                query_head_dim_normalize=query_head_dim_normalize,
+                use_post_ffw_norm=use_post_ffw_norm,
+                use_post_attention_norm=use_post_attention_norm,
+                logit_soft_cap=attention_logit_soft_cap,
+                use_sliding_window_attention=sliding_window,
+                sliding_window_size=sliding_window_size,
                 dropout=dropout,
                 dtype=dtype,
                 name=f"decoder_block_{i}",
@@ -177,7 +206,9 @@ class PaliGemmaBackbone(Backbone):
         )

         # === Functional Model ===
-        image_input = self.vit_encoder.inputs[0]
+        image_input = keras.Input(
+            shape=(image_size, image_size, 3), name="images"
+        )
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
         )
@@ -215,7 +246,6 @@ class PaliGemmaBackbone(Backbone):
         # === Config ===
         self.vocabulary_size = vocabulary_size
         self.image_size = image_size
-        self.include_rescaling = include_rescaling
         self.num_layers = num_layers
         self.num_query_heads = num_query_heads
         self.num_key_value_heads = num_key_value_heads
@@ -224,7 +254,15 @@ class PaliGemmaBackbone(Backbone):
         self.head_dim = head_dim
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
-        # VIT Params
+        # Gemma2 params
+        self.query_head_dim_normalize = query_head_dim_normalize
+        self.use_post_ffw_norm = use_post_ffw_norm
+        self.use_post_attention_norm = use_post_attention_norm
+        self.attention_logit_soft_cap = attention_logit_soft_cap
+        self.final_logit_soft_cap = final_logit_soft_cap
+        self.sliding_window_size = sliding_window_size
+        self.use_sliding_window_attention = use_sliding_window_attention
+        # ViT params
         self.vit_patch_size = vit_patch_size
         self.vit_num_heads = vit_num_heads
         self.vit_hidden_dim = vit_hidden_dim
@@ -242,15 +280,12 @@ class PaliGemmaBackbone(Backbone):
             {
                 "vocabulary_size": self.vocabulary_size,
                 "image_size": self.image_size,
-                "include_rescaling": self.include_rescaling,
                 "num_layers": self.num_layers,
                 "num_query_heads": self.num_query_heads,
                 "num_key_value_heads": self.num_key_value_heads,
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "head_dim": self.head_dim,
-                "layer_norm_epsilon": self.layer_norm_epsilon,
-                "dropout": self.dropout,
                 "vit_patch_size": self.vit_patch_size,
                 "vit_num_heads": self.vit_num_heads,
                 "vit_hidden_dim": self.vit_hidden_dim,
@@ -259,6 +294,17 @@ class PaliGemmaBackbone(Backbone):
                 "vit_pooling": self.vit_pooling,
                 "vit_classifier_activation": self.vit_classifier_activation,
                 "vit_name": self.vit_name,
+                "query_head_dim_normalize": self.query_head_dim_normalize,
+                "use_post_ffw_norm": self.use_post_ffw_norm,
+                "use_post_attention_norm": self.use_post_attention_norm,
+                "final_logit_soft_cap": self.final_logit_soft_cap,
+                "attention_logit_soft_cap": self.attention_logit_soft_cap,
+                "sliding_window_size": self.sliding_window_size,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
             }
         )
         return config
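Note: the new flags mirror Gemma2's decoder options, and `sliding_window = use_sliding_window_attention and (i % 2 == 0)` above means only every other block uses the sliding local window. A construction sketch with tiny made-up dimensions, purely to show where the new arguments go (not a real preset config):

```python
from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
    PaliGemmaBackbone,
)

backbone = PaliGemmaBackbone(
    vocabulary_size=1000,  # made-up toy sizes throughout
    image_size=28,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=1,
    hidden_dim=64,
    intermediate_dim=128,
    head_dim=16,
    vit_patch_size=14,
    vit_num_heads=2,
    vit_hidden_dim=32,
    vit_num_layers=2,
    # New Gemma2-style options from this diff:
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    attention_logit_soft_cap=50,
    final_logit_soft_cap=30,
    use_sliding_window_attention=True,
    sliding_window_size=4096,
)
```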
keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py
@@ -110,7 +110,9 @@ class PaliGemmaCausalLM(CausalLM):
         self.backbone = backbone

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_state = backbone(inputs=inputs)
         outputs = backbone.token_embedding(hidden_state, reverse=True)
         outputs = outputs[:, backbone.image_sequence_length :, :]
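The `backbone.input` vs `backbone.inputs` distinction the comment relies on is a Keras functional-model detail: `inputs` is always the flattened list, while `input` preserves the structure (here a dict keyed by input name) the model was built with. A standalone illustration:

```python
import keras

token_ids = keras.Input(shape=(4,), name="token_ids")
padding_mask = keras.Input(shape=(4,), name="padding_mask")
model = keras.Model(
    inputs={"token_ids": token_ids, "padding_mask": padding_mask},
    outputs=token_ids + padding_mask,
)

print(model.inputs)  # flat list of two KerasTensors
print(model.input)   # dict with keys "token_ids" and "padding_mask"
```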
keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py
@@ -31,33 +31,25 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
             the attention layer.
         num_key_value_heads: int. The number of heads for the key and value
             projections in the attention layer.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the
+            attention block. Defaults to `False`.
+        logit_soft_cap: `None` or int. Soft cap for the attention logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
         layer_norm_epsilon: float. The epsilon hyperparameter used for layer
-            normalization.
+            normalization. Defaults to `1e-6`.
         dropout: float. The dropout rate for the transformer attention layer.
+            Defaults to `0`.
     """

-    def __init__(
-        self,
-        hidden_dim,
-        intermediate_dim,
-        head_dim,
-        num_query_heads,
-        num_key_value_heads,
-        layer_norm_epsilon=1e-6,
-        dropout=0,
-        **kwargs,
-    ):
-        super().__init__(
-            hidden_dim=hidden_dim,
-            intermediate_dim=intermediate_dim,
-            head_dim=head_dim,
-            num_query_heads=num_query_heads,
-            num_key_value_heads=num_key_value_heads,
-            layer_norm_epsilon=layer_norm_epsilon,
-            dropout=dropout,
-            **kwargs,
-        )
-
     def call(
         self,
         x,
@@ -83,6 +75,9 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
             attention_mask=attention_mask,
         )

+        if self.use_post_attention_norm:
+            attention = self.post_attention_norm(attention)
+
         if self.dropout:
             attention = self.attention_dropout(attention)

@@ -94,6 +89,9 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
         x = keras.activations.gelu(x1, approximate=True) * x2
         x = self.ffw_linear(x)

+        if self.use_post_ffw_norm:
+            x = self.post_ffw_norm(x)
+
         x = x + attention_x

         if cache is not None:
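The placement matters: each optional norm is applied to the branch output before it is added back to the residual stream. A simplified sketch of that data flow (layer roles follow the diff; pre-attention/pre-ffw norms and dropout are elided):

```python
def decoder_block_sketch(
    x, attention_fn, ffw_fn, post_attention_norm=None, post_ffw_norm=None
):
    """Simplified data flow for the post-norm branches added above."""
    attention = attention_fn(x)
    if post_attention_norm is not None:
        attention = post_attention_norm(attention)  # norm before residual add
    attention_x = x + attention

    h = ffw_fn(attention_x)
    if post_ffw_norm is not None:
        h = post_ffw_norm(h)
    return attention_x + h
```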
keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py
@@ -1,12 +1,10 @@
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.resizing_image_converter import (
-    ResizingImageConverter,
-)
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
     PaliGemmaBackbone,
 )


 @keras_hub_export("keras_hub.layers.PaliGemmaImageConverter")
-class PaliGemmaImageConverter(ResizingImageConverter):
+class PaliGemmaImageConverter(ImageConverter):
     backbone_cls = PaliGemmaBackbone
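Note: `ResizingImageConverter` is removed entirely (`resizing_image_converter.py +0 -138` in the file list), with the reworked `ImageConverter` (`image_converter.py +170 -34`) taking over resizing and rescaling. A usage sketch, with argument names assumed from the new `ImageConverter` API and values purely illustrative:

```python
import numpy as np
import keras_hub

# Resize to the model's expected size and rescale [0, 255] -> [0, 1].
converter = keras_hub.layers.PaliGemmaImageConverter(
    image_size=(224, 224),
    scale=1.0 / 255.0,
)
images = np.random.uniform(0, 255, (2, 448, 448, 3)).astype("float32")
outputs = converter(images)  # shape (2, 224, 224, 3)
```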