keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/llama/llama_causal_lm.py
@@ -42,7 +42,9 @@ class LlamaCausalLM(CausalLM):
         self.preprocessor = preprocessor

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_states = backbone(inputs)
         outputs = backbone.token_embedding(hidden_states, reverse=True)
         super().__init__(
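The comment added here reflects a general property of Keras functional models: `model.inputs` is always a flattened list, while `model.input` preserves the structure the model was built with. A minimal sketch, independent of keras-hub, illustrating the distinction (the toy model below is illustrative only):

    import keras

    # Toy functional model built from a dict of named inputs, mirroring how
    # keras-hub backbones are built from "token_ids" and "padding_mask".
    token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
    padding_mask = keras.Input(shape=(None,), dtype="int32", name="padding_mask")
    x = keras.layers.Embedding(100, 8)(token_ids)
    x = x * keras.ops.cast(padding_mask[..., None], "float32")
    model = keras.Model(
        inputs={"token_ids": token_ids, "padding_mask": padding_mask}, outputs=x
    )

    print(type(model.inputs))  # <class 'list'>: flattened, structure lost
    print(type(model.input))   # <class 'dict'>: original structure preserved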
keras_hub/src/models/llama/llama_presets.py
@@ -6,11 +6,9 @@ backbone_presets = {
         "metadata": {
             "description": "7 billion parameter, 32-layer, base LLaMA 2 model.",
             "params": 6738415616,
-            "official_name": "LLaMA 2",
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
-        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/1",
+        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/2",
     },
     "llama2_7b_en_int8": {
         "metadata": {
@@ -19,11 +17,9 @@ backbone_presets = {
                 "activation and weights quantized to int8."
             ),
             "params": 6739839488,
-            "official_name": "LLaMA 2",
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
-        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en_int8/1",
+        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en_int8/2",
     },
     "llama2_instruct_7b_en": {
         "metadata": {
@@ -32,11 +28,9 @@ backbone_presets = {
                 "model."
             ),
             "params": 6738415616,
-            "official_name": "LLaMA 2",
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
-        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/1",
+        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/2",
     },
     "llama2_instruct_7b_en_int8": {
         "metadata": {
@@ -45,11 +39,9 @@ backbone_presets = {
                 "model with activation and weights quantized to int8."
             ),
             "params": 6739839488,
-            "official_name": "LLaMA 2",
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
-        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en_int8/1",
+        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en_int8/2",
     },
     "vicuna_1.5_7b_en": {
         "metadata": {
@@ -58,10 +50,8 @@ backbone_presets = {
                 "model."
             ),
             "params": 6738415616,
-            "official_name": "Vicuna",
-            "path": "vicuna",
-            "model_card": "https://github.com/lm-sys/FastChat",
+            "path": "llama",
         },
-        "kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/1",
+        "kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/2",
     },
 }
keras_hub/src/models/llama3/llama3_backbone.py
@@ -24,17 +24,18 @@ class Llama3Backbone(LlamaBackbone):
         num_layers (int): The number of transformer layers.
         num_query_heads (int): The number of query attention heads for
             each transformer.
-        hidden_dim (int): The size of the transformer encoding and pooling layers.
-        intermediate_dim (int): The output dimension of the first Dense layer in a
-            three-layer feedforward network for each transformer.
-        num_key_value_heads (int): The number of key and value attention heads for
-            each transformer.
-        rope_max_wavelength (int, optional): The maximum angular wavelength of the
-            sine/cosine curves, for rotary embeddings. Defaults to `10000`.
-        rope_scaling_factor (float, optional): The scaling factor for calculation
-            of roatary embedding. Defaults to `1.0`.
-        layer_norm_epsilon (float, optional): Epsilon for the layer normalization
-            layers in the transformer decoder. Defaults to `1e-6`.
+        hidden_dim (int): The size of the transformer encoding and pooling
+            layers.
+        intermediate_dim (int): The output dimension of the first Dense layer in
+            a three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads
+            fo each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of
+            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for
+            calculation of roatary embedding. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer
+            normalization layers in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
             such as softmax and layer normalization, will always be done at
keras_hub/src/models/llama3/llama3_causal_lm.py
@@ -1,9 +1,9 @@
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM
 from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
 from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
     Llama3CausalLMPreprocessor,
 )
-from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM


 @keras_hub_export("keras_hub.models.Llama3CausalLM")
keras_hub/src/models/llama3/llama3_presets.py
@@ -6,11 +6,9 @@ backbone_presets = {
         "metadata": {
             "description": "8 billion parameter, 32-layer, base LLaMA 3 model.",
             "params": 8030261248,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
-        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/3",
+        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/4",
     },
     "llama3_8b_en_int8": {
         "metadata": {
@@ -19,11 +17,9 @@ backbone_presets = {
                 "activation and weights quantized to int8."
             ),
             "params": 8031894016,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
-        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en_int8/1",
+        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en_int8/2",
     },
     "llama3_instruct_8b_en": {
         "metadata": {
@@ -32,11 +28,9 @@ backbone_presets = {
                 "model."
             ),
             "params": 8030261248,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
-        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/3",
+        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/4",
     },
     "llama3_instruct_8b_en_int8": {
         "metadata": {
@@ -45,12 +39,10 @@ backbone_presets = {
                 "model with activation and weights quantized to int8."
             ),
             "params": 8031894016,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
         "kaggle_handle": (
-            "kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/1"
+            "kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/2"
         ),
     },
 }
keras_hub/src/models/llama3/llama3_tokenizer.py
@@ -16,10 +16,33 @@ class Llama3Tokenizer(BytePairTokenizer):
         self,
         vocabulary=None,
         merges=None,
+        bos_token="<|begin_of_text|>",
+        eos_token="<|end_of_text|>",
+        misc_special_tokens={"<|start_header_id|>", "<|end_header_id|>"},
         **kwargs,
     ):
-        self._add_special_token("<|begin_of_text|>", "start_token")
-        self._add_special_token("<|end_of_text|>", "end_token")
+        # Note: all special tokens must also appear in "vocabulary"
+
+        self._add_special_token(bos_token, "start_token")
+        misc_special_tokens -= {bos_token}
+        self._add_special_token(eos_token, "end_token")
+        misc_special_tokens -= {eos_token}
+        for i, token in enumerate(misc_special_tokens):
+            self._add_special_token(token, f"special_token_{i:03d}")
+
+        # Hack:
+        # Llama models use the <|end_of_text|> or the <|eot_id|> as the stop
+        # token. This info can be read from config when loading a Hugging Face
+        # checkpoint but no such config exists for Keras checkpoints.
+        # Setting both probable end tokens when no config is availble will
+        # make text generation work in all cases as it will stop
+        # on both end tokens. However, the packer will always use
+        # "<|end_of_text|>" , which will be the wrong eos_token for "instruct"
+        # variants of Llama3.
+        # TODO: load this correctly from a Keras tokenizer config.
+        if eos_token == "<|end_of_text|>":
+            self._add_special_token("<|eot_id|>", "end_token2")
+
         self.pad_token_id = 0
         super().__init__(
             vocabulary=vocabulary,
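With the defaults, both `<|end_of_text|>` and `<|eot_id|>` end up registered as end tokens, so generation halts on either. A construction sketch; the toy vocabulary and merges below are illustrative assumptions (real checkpoints ship their own vocabulary and merges files), and the public `keras_hub.tokenizers` export is assumed:

    import keras_hub

    # Toy vocabulary; per the note above, every special token must also
    # appear in the vocabulary.
    vocab = {
        "<|begin_of_text|>": 0,
        "<|end_of_text|>": 1,
        "<|eot_id|>": 2,
        "<|start_header_id|>": 3,
        "<|end_header_id|>": 4,
        "a": 5,
        "b": 6,
        "ab": 7,
    }
    merges = ["a b"]

    # Default eos_token: "<|eot_id|>" is also registered via the hack above.
    tokenizer = keras_hub.tokenizers.Llama3Tokenizer(
        vocabulary=vocab, merges=merges
    )

    # For an instruct-style checkpoint, the true stop token can be passed
    # explicitly; the "end_token2" branch is then skipped.
    tokenizer = keras_hub.tokenizers.Llama3Tokenizer(
        vocabulary=vocab, merges=merges, eos_token="<|eot_id|>"
    )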
keras_hub/src/models/mistral/mistral_backbone.py
@@ -38,22 +38,23 @@ class MistralBackbone(Backbone):
         num_layers (int): The number of transformer layers.
         num_query_heads (int): The number of query attention heads for
             each transformer.
-        hidden_dim (int): The size of the transformer encoding and pooling layers.
-        intermediate_dim (int): The output dimension of the first Dense layer in a
-            three-layer feedforward network for each transformer.
-        num_key_value_heads (int): The number of key and value attention heads for
-            each transformer.
-        rope_max_wavelength (int, optional): The maximum angular wavelength of the
-            sine/cosine curves, for rotary embeddings. Defaults to `10000`.
-        rope_scaling_factor (float, optional): The scaling factor for calculation
-            of roatary embedding. Defaults to `1.0`.
-        layer_norm_epsilon (float, optional): Epsilon for the layer normalization
-            layers in the transformer decoder. Defaults to `1e-6`.
+        hidden_dim (int): The size of the transformer encoding and pooling
+            layers.
+        intermediate_dim (int): The output dimension of the first Dense layer
+            in a three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads
+            for each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of
+            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for
+            calculation of roatary embedding. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer
+            normalization layers in the transformer decoder. Defaults to `1e-6`.
         sliding_window (int, optional): The sliding window for the mistral
-            attention layers. This controls the maximum cache size for the attention
-            layers in each transformer decoder. Only `sliding_window` number of tokens
-            are saved in the cache and used to generate the next token.
-            Defaults to `512`.
+            attention layers. This controls the maximum cache size for the
+            attention layers in each transformer decoder. Only `sliding_window`
+            number of tokens are saved in the cache and used to generate the
+            next token. Defaults to `512`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
             such as softmax and layer normalization, will always be done at
keras_hub/src/models/mistral/mistral_causal_lm.py
@@ -28,9 +28,9 @@ class MistralCausalLM(CausalLM):

     Args:
         backbone: A `keras_hub.models.MistralBackbone` instance.
-        preprocessor: A `keras_hub.models.MistralCausalLMPreprocessor` or `None`.
-            If `None`, this model will not apply preprocessing, and inputs
-            should be preprocessed before calling the model.
+        preprocessor: A `keras_hub.models.MistralCausalLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
     """

     backbone_cls = MistralBackbone
@@ -42,7 +42,9 @@ class MistralCausalLM(CausalLM):
         self.preprocessor = preprocessor

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_states = backbone(inputs)
         outputs = backbone.token_embedding(hidden_states, reverse=True)
         super().__init__(
keras_hub/src/models/mistral/mistral_presets.py
@@ -6,30 +6,24 @@ backbone_presets = {
         "metadata": {
             "description": "Mistral 7B base model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/6",
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/7",
     },
     "mistral_instruct_7b_en": {
         "metadata": {
             "description": "Mistral 7B instruct model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/6",
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/7",
     },
     "mistral_0.2_instruct_7b_en": {
         "metadata": {
             "description": "Mistral 7B instruct Version 0.2 model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/1",
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/2",
     },
 }
keras_hub/src/models/mistral/mistral_transformer_decoder.py
@@ -215,7 +215,8 @@ class MistralTransformerDecoder(keras.layers.Layer):
         # Mistral uses a banded attention mask if sliding window is not None
         if self.sliding_window is not None:
             # Below is a workaround for `ops.triu` for Keras 2.
-            # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed.
+            # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is
+            # removed.
             # causal_mask = ops.triu(causal_mask, k=-self.sliding_window)
             i = ops.arange(output_length)[:, None] + cache_update_index
             j = ops.arange(input_length)[None, :]
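The index arithmetic starting at these two lines rebuilds the banded mask that the commented-out `ops.triu` call would produce. A small numpy sketch of the same computation (window and length values are illustrative):

    import numpy as np

    # Banded causal mask without triu: a query at position i may attend to
    # key positions j with i - window <= j <= i, matching
    # `ops.triu(causal_mask, k=-self.sliding_window)` applied to a causal mask.
    window, length = 3, 6
    i = np.arange(length)[:, None]  # query positions (rows)
    j = np.arange(length)[None, :]  # key positions (columns)
    banded = (j <= i) & (j >= i - window)

    assert np.array_equal(
        banded,
        np.triu(np.tril(np.ones((length, length), dtype=bool)), k=-window),
    )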
keras_hub/src/models/mit/__init__.py
@@ -0,0 +1,6 @@
+from keras_hub.src.models.mit.mit_backbone import MiTBackbone
+from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier
+from keras_hub.src.models.mit.mit_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, MiTBackbone)
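Calling `register_presets` at import time is what makes the new MiT checkpoints addressable by name. A usage sketch, assuming network access for the weight download; the preset name is the one referenced in the updated docstring further below:

    import numpy as np
    import keras_hub

    # Load the registered preset by name and run a forward pass.
    backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
    features = backbone.predict(np.ones((1, 512, 512, 3)))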
keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py}
@@ -1,28 +1,35 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import keras
 import numpy as np
 from keras import ops

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
-from keras_hub.src.models.mix_transformer.mix_transformer_layers import (
-    HierarchicalTransformerEncoder,
-)
-from keras_hub.src.models.mix_transformer.mix_transformer_layers import (
-    OverlappingPatchingAndEmbedding,
-)
+from keras_hub.src.models.mit.mit_layers import HierarchicalTransformerEncoder
+from keras_hub.src.models.mit.mit_layers import OverlappingPatchingAndEmbedding


 @keras_hub_export("keras_hub.models.MiTBackbone")
 class MiTBackbone(FeaturePyramidBackbone):
     def __init__(
         self,
-        depths,
+        layerwise_depths,
         num_layers,
-        blockwise_num_heads,
-        blockwise_sr_ratios,
+        layerwise_num_heads,
+        layerwise_sr_ratios,
         max_drop_path_rate,
-        patch_sizes,
-        strides,
+        layerwise_patch_sizes,
+        layerwise_strides,
         image_shape=(None, None, 3),
         hidden_dims=None,
         **kwargs,
@@ -36,12 +43,12 @@ class MiTBackbone(FeaturePyramidBackbone):
        https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)

        Args:
-            depths: The number of transformer encoders to be used per layer in the
-                network.
+            layerwise_depths: The number of transformer encoders to be used per
+                layer in the network.
            num_layers: int. The number of Transformer layers.
-            blockwise_num_heads: list of integers, the number of heads to use
+            layerwise_num_heads: list of integers, the number of heads to use
                in the attention computation for each layer.
-            blockwise_sr_ratios: list of integers, the sequence reduction
+            layerwise_sr_ratios: list of integers, the sequence reduction
                ratio to perform for each layer on the sequence before key and
                value projections. If set to > 1, a `Conv2D` layer is used to
                reduce the length of the sequence.
@@ -51,7 +58,8 @@ class MiTBackbone(FeaturePyramidBackbone):
            image_shape: optional shape tuple, defaults to (None, None, 3).
            hidden_dims: the embedding dims per hierarchical layer, used as
                the levels of the feature pyramid.
-            patch_sizes: list of integers, the patch_size to apply for each layer.
+            patch_sizes: list of integers, the patch_size to apply for each
+                layer.
            strides: list of integers, stride to apply for each layer.
@@ -61,7 +69,7 @@ class MiTBackbone(FeaturePyramidBackbone):
        ```python
        images = np.ones(shape=(1, 96, 96, 3))
        labels = np.zeros(shape=(1, 96, 96, 1))
-        backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_imagenet")
+        backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")

        # Evaluate model
        model(images)
@@ -75,7 +83,10 @@ class MiTBackbone(FeaturePyramidBackbone):
        model.fit(images, labels, epochs=3)
        ```
        """
-        dpr = [x for x in np.linspace(0.0, max_drop_path_rate, sum(depths))]
+        dpr = [
+            x
+            for x in np.linspace(0.0, max_drop_path_rate, sum(layerwise_depths))
+        ]

        # === Layers ===
        cur = 0
@@ -86,8 +97,8 @@ class MiTBackbone(FeaturePyramidBackbone):
         for i in range(num_layers):
             patch_embed_layer = OverlappingPatchingAndEmbedding(
                 project_dim=hidden_dims[i],
-                patch_size=patch_sizes[i],
-                stride=strides[i],
+                patch_size=layerwise_patch_sizes[i],
+                stride=layerwise_strides[i],
                 name=f"patch_and_embed_{i}",
             )
             patch_embedding_layers.append(patch_embed_layer)
@@ -95,16 +106,16 @@ class MiTBackbone(FeaturePyramidBackbone):
             transformer_block = [
                 HierarchicalTransformerEncoder(
                     project_dim=hidden_dims[i],
-                    num_heads=blockwise_num_heads[i],
-                    sr_ratio=blockwise_sr_ratios[i],
+                    num_heads=layerwise_num_heads[i],
+                    sr_ratio=layerwise_sr_ratios[i],
                     drop_prob=dpr[cur + k],
                     name=f"hierarchical_encoder_{i}_{k}",
                 )
-                for k in range(depths[i])
+                for k in range(layerwise_depths[i])
             ]
             transformer_blocks.append(transformer_block)
-            cur += depths[i]
-            layer_norms.append(keras.layers.LayerNormalization())
+            cur += layerwise_depths[i]
+            layer_norms.append(keras.layers.LayerNormalization(epsilon=1e-5))

         # === Functional Model ===
         image_input = keras.layers.Input(shape=image_shape)
@@ -113,7 +124,7 @@ class MiTBackbone(FeaturePyramidBackbone):
         for i in range(num_layers):
             # Compute new height/width after the `proj`
             # call in `OverlappingPatchingAndEmbedding`
-            stride = strides[i]
+            stride = layerwise_strides[i]
             new_height, new_width = (
                 int(ops.shape(x)[1] / stride),
                 int(ops.shape(x)[2] / stride),
@@ -131,30 +142,30 @@ class MiTBackbone(FeaturePyramidBackbone):
         super().__init__(inputs=image_input, outputs=x, **kwargs)

         # === Config ===
-        self.depths = depths
+        self.layerwise_depths = layerwise_depths
         self.image_shape = image_shape
         self.hidden_dims = hidden_dims
         self.pyramid_outputs = pyramid_outputs
         self.num_layers = num_layers
-        self.blockwise_num_heads = blockwise_num_heads
-        self.blockwise_sr_ratios = blockwise_sr_ratios
+        self.layerwise_num_heads = layerwise_num_heads
+        self.layerwise_sr_ratios = layerwise_sr_ratios
         self.max_drop_path_rate = max_drop_path_rate
-        self.patch_sizes = patch_sizes
-        self.strides = strides
+        self.layerwise_patch_sizes = layerwise_patch_sizes
+        self.layerwise_strides = layerwise_strides

     def get_config(self):
         config = super().get_config()
         config.update(
             {
-                "depths": self.depths,
+                "layerwise_depths": self.layerwise_depths,
                 "hidden_dims": self.hidden_dims,
                 "image_shape": self.image_shape,
                 "num_layers": self.num_layers,
-                "blockwise_num_heads": self.blockwise_num_heads,
-                "blockwise_sr_ratios": self.blockwise_sr_ratios,
+                "layerwise_num_heads": self.layerwise_num_heads,
+                "layerwise_sr_ratios": self.layerwise_sr_ratios,
                 "max_drop_path_rate": self.max_drop_path_rate,
-                "patch_sizes": self.patch_sizes,
-                "strides": self.strides,
+                "layerwise_patch_sizes": self.layerwise_patch_sizes,
+                "layerwise_strides": self.layerwise_strides,
             }
         )
         return config
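For reference, constructing the backbone directly with the renamed `layerwise_*` arguments looks like this; the values below approximate a MiT-B0 configuration and are illustrative assumptions, not taken from the diff:

    import numpy as np
    import keras_hub

    # Direct construction using the new argument names.
    backbone = keras_hub.models.MiTBackbone(
        layerwise_depths=[2, 2, 2, 2],
        num_layers=4,
        layerwise_num_heads=[1, 2, 5, 8],
        layerwise_sr_ratios=[8, 4, 2, 1],
        max_drop_path_rate=0.1,
        layerwise_patch_sizes=[7, 3, 3, 3],
        layerwise_strides=[4, 2, 2, 2],
        hidden_dims=[32, 64, 160, 256],
        image_shape=(224, 224, 3),
    )
    features = backbone(np.ones((1, 224, 224, 3)))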
keras_hub/src/models/mit/mit_image_classifier.py
@@ -0,0 +1,12 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.models.mit.mit_backbone import MiTBackbone
+from keras_hub.src.models.mit.mit_image_classifier_preprocessor import (
+    MiTImageClassifierPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.MiTImageClassifier")
+class MiTImageClassifier(ImageClassifier):
+    backbone_cls = MiTBackbone
+    preprocessor_cls = MiTImageClassifierPreprocessor
keras_hub/src/models/mit/mit_image_classifier_preprocessor.py
@@ -0,0 +1,12 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier_preprocessor import (
+    ImageClassifierPreprocessor,
+)
+from keras_hub.src.models.mit.mit_backbone import MiTBackbone
+from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter
+
+
+@keras_hub_export("keras_hub.models.MiTImageClassifierPreprocessor")
+class MiTImageClassifierPreprocessor(ImageClassifierPreprocessor):
+    backbone_cls = MiTBackbone
+    image_converter_cls = MiTImageConverter
keras_hub/src/models/mit/mit_image_converter.py
@@ -0,0 +1,8 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.mit import MiTBackbone
+
+
+@keras_hub_export("keras_hub.layers.MiTImageConverter")
+class MiTImageConverter(ImageConverter):
+    backbone_cls = MiTBackbone
keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py}
@@ -28,19 +28,23 @@ class OverlappingPatchingAndEmbedding(keras.layers.Layer):
         self.patch_size = patch_size
         self.stride = stride

+        padding_size = self.patch_size // 2
+
+        self.padding = keras.layers.ZeroPadding2D(
+            padding=(padding_size, padding_size)
+        )
         self.proj = keras.layers.Conv2D(
             filters=project_dim,
             kernel_size=patch_size,
             strides=stride,
-            padding="same",
+            padding="valid",
         )
-        self.norm = keras.layers.LayerNormalization()
+        self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

     def call(self, x):
+        x = self.padding(x)
         x = self.proj(x)
-        # B, H, W, C
-        shape = x.shape
-        x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3]))
+        x = ops.reshape(x, (-1, x.shape[1] * x.shape[2], x.shape[3]))
         x = self.norm(x)
         return x

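The switch from `padding="same"` to explicit symmetric zero-padding of `patch_size // 2` plus a "valid" convolution matches PyTorch-style `Conv2d(..., padding=patch_size // 2)`; Keras "same" can pad asymmetrically for strided convolutions, which shifts features relative to ported checkpoints even when output sizes agree. A quick sketch checking the output size under illustrative numbers:

    import numpy as np
    import keras

    # Symmetric padding of patch_size // 2 plus a "valid" conv reproduces
    # the PyTorch output size floor((H + 2 * pad - patch_size) / stride) + 1.
    patch_size, stride, H = 7, 4, 224
    pad = patch_size // 2

    x = np.ones((1, H, H, 3), "float32")
    padded = keras.layers.ZeroPadding2D(padding=(pad, pad))(x)
    out = keras.layers.Conv2D(8, patch_size, strides=stride, padding="valid")(padded)

    expected = (H + 2 * pad - patch_size) // stride + 1
    assert out.shape[1] == expected == 56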
@@ -76,7 +80,8 @@ class HierarchicalTransformerEncoder(keras.layers.Layer):
             `LayerNormalization` layers. Defaults to `1e-06`
         sr_ratio: integer, the ratio to use within
             `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D`
-            layer is used to reduce the length of the sequence. Defaults to `1`.
+            layer is used to reduce the length of the sequence.
+            Defaults to `1`.
     """

     def __init__(
@@ -179,20 +184,21 @@ class SegFormerMultiheadAttention(keras.layers.Layer):
         self.k = keras.layers.Dense(project_dim)
         self.v = keras.layers.Dense(project_dim)
         self.proj = keras.layers.Dense(project_dim)
+        self.dropout = keras.layers.Dropout(0.1)
+        self.proj_drop = keras.layers.Dropout(0.1)

         if sr_ratio > 1:
             self.sr = keras.layers.Conv2D(
                 filters=project_dim,
                 kernel_size=sr_ratio,
                 strides=sr_ratio,
-                padding="same",
             )
-            self.norm = keras.layers.LayerNormalization()
+            self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

     def call(self, x):
         input_shape = ops.shape(x)
         H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1]))
-        B, C = input_shape[0], input_shape[2]
+        B, N, C = input_shape[0], input_shape[1], input_shape[2]

         q = self.q(x)
         q = ops.reshape(
@@ -208,12 +214,11 @@ class SegFormerMultiheadAttention(keras.layers.Layer):

         if self.sr_ratio > 1:
             x = ops.reshape(
-                ops.transpose(x, [0, 2, 1]),
+                x,
                 (B, H, W, C),
             )
             x = self.sr(x)
-            x = ops.reshape(x, [input_shape[0], input_shape[2], -1])
-            x = ops.transpose(x, [0, 2, 1])
+            x = ops.reshape(x, [B, -1, C])
             x = self.norm(x)

         k = self.k(x)
@@ -237,14 +242,16 @@ class SegFormerMultiheadAttention(keras.layers.Layer):

         attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale
         attn = ops.nn.softmax(attn, axis=-1)
+        attn = self.dropout(attn)

         attn = attn @ v
         attn = ops.reshape(
             ops.transpose(attn, [0, 2, 1, 3]),
-            [input_shape[0], input_shape[1], input_shape[2]],
+            [B, N, C],
         )

         x = self.proj(attn)
+        x = self.proj_drop(x)
         return x

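Because Keras is channels-last, the `(B, N, C)` sequence can be reshaped directly to `(B, H, W, C)` for the reduction convolution; the removed transposes were leftovers from a channels-first port. A shape walk-through under illustrative sizes:

    import numpy as np
    from keras import layers, ops

    # Shape flow of the corrected sequence-reduction path:
    # (B, N, C) -> (B, H, W, C) -> strided conv -> (B, N / sr_ratio**2, C).
    B, H, W, C, sr_ratio = 2, 8, 8, 16, 4
    x = ops.convert_to_tensor(np.ones((B, H * W, C), "float32"))

    x = ops.reshape(x, (B, H, W, C))                      # back to spatial layout
    x = layers.Conv2D(C, sr_ratio, strides=sr_ratio)(x)   # reduce H and W by sr_ratio
    x = ops.reshape(x, (B, -1, C))                        # flatten to the shorter sequence

    assert tuple(x.shape) == (B, (H // sr_ratio) * (W // sr_ratio), C)  # (2, 4, 16)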