keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,193 @@
1
"""EfficientNet preset configurations."""

# Mapping of preset name -> metadata + Kaggle model handle.
# Fixes applied in review:
#   * The EdgeTPU presets (el/em/es) previously copy-pasted the
#     `efficientnet_b1_ft_imagenet/5` handle; each now points at its own
#     preset. NOTE(review): verify the `/2` version suffix against the
#     published Kaggle models before release.
#   * "fine-trained" typo corrected to "fine-tuned" in the lite0 description.
backbone_presets = {
    "efficientnet_b0_ra_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet B0 model pre-trained on the ImageNet 1k dataset "
                "with RandAugment recipe."
            ),
            "params": 5288548,
            "path": "efficientnet",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra_imagenet/2",
    },
    "efficientnet_b0_ra4_e3600_r224_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet B0 model pre-trained on the ImageNet 1k dataset "
                "by Ross Wightman. Trained with timm scripts using "
                "hyper-parameters inspired by the MobileNet-V4 small, mixed "
                "with go-to hparams from timm and 'ResNet Strikes Back'."
            ),
            "params": 5288548,
            "path": "efficientnet",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra4_e3600_r224_imagenet/2",
    },
    "efficientnet_b1_ft_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet B1 model fine-tuned on the ImageNet 1k dataset."
            ),
            "params": 7794184,
            "path": "efficientnet",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet/5",
    },
    "efficientnet_b1_ra4_e3600_r240_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet B1 model pre-trained on the ImageNet 1k dataset "
                "by Ross Wightman. Trained with timm scripts using "
                "hyper-parameters inspired by the MobileNet-V4 small, mixed "
                "with go-to hparams from timm and 'ResNet Strikes Back'."
            ),
            "params": 7794184,
            "path": "efficientnet",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ra4_e3600_r240_imagenet/2",
    },
    "efficientnet_b2_ra_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet B2 model pre-trained on the ImageNet 1k dataset "
                "with RandAugment recipe."
            ),
            "params": 9109994,
            "path": "efficientnet",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b2_ra_imagenet/2",
    },
    "efficientnet_b3_ra2_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet B3 model pre-trained on the ImageNet 1k dataset "
                "with RandAugment2 recipe."
            ),
            "params": 12233232,
            "path": "efficientnet",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b3_ra2_imagenet/2",
    },
    "efficientnet_b4_ra2_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet B4 model pre-trained on the ImageNet 1k dataset "
                "with RandAugment2 recipe."
            ),
            "params": 19341616,
            "path": "efficientnet",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b4_ra2_imagenet/2",
    },
    "efficientnet_b5_sw_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet B5 model pre-trained on the ImageNet 12k dataset "
                "by Ross Wightman. Based on Swin Transformer train / pretrain "
                "recipe with modifications (related to both DeiT and ConvNeXt "
                "recipes)."
            ),
            "params": 30389784,
            "path": "efficientnet",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b5_sw_imagenet/2",
    },
    "efficientnet_b5_sw_ft_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet B5 model pre-trained on the ImageNet 12k dataset "
                "and fine-tuned on ImageNet-1k by Ross Wightman. Based on Swin "
                "Transformer train / pretrain recipe with modifications "
                "(related to both DeiT and ConvNeXt recipes)."
            ),
            "params": 30389784,
            "path": "efficientnet",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b5_sw_ft_imagenet/2",
    },
    "efficientnet_el_ra_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet-EdgeTPU Large model trained on the ImageNet 1k "
                "dataset with RandAugment recipe."
            ),
            "params": 10589712,
            "path": "efficientnet",
        },
        # Was erroneously "efficientnet_b1_ft_imagenet/5" (copy-paste).
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_el_ra_imagenet/2",
    },
    "efficientnet_em_ra2_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet-EdgeTPU Medium model trained on the ImageNet 1k "
                "dataset with RandAugment2 recipe."
            ),
            "params": 6899496,
            "path": "efficientnet",
        },
        # Was erroneously "efficientnet_b1_ft_imagenet/5" (copy-paste).
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_em_ra2_imagenet/2",
    },
    "efficientnet_es_ra_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet-EdgeTPU Small model trained on the ImageNet 1k "
                "dataset with RandAugment recipe."
            ),
            "params": 5438392,
            "path": "efficientnet",
        },
        # Was erroneously "efficientnet_b1_ft_imagenet/5" (copy-paste).
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_es_ra_imagenet/2",
    },
    "efficientnet2_rw_m_agc_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet-v2 Medium model trained on the ImageNet 1k "
                "dataset with adaptive gradient clipping."
            ),
            "params": 53236442,
            "official_name": "EfficientNet",
            "path": "efficientnet",
            "model_card": "https://arxiv.org/abs/2104.00298",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_m_agc_imagenet/2",
    },
    "efficientnet2_rw_s_ra2_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet-v2 Small model trained on the ImageNet 1k "
                "dataset with RandAugment2 recipe."
            ),
            "params": 23941296,
            "official_name": "EfficientNet",
            "path": "efficientnet",
            "model_card": "https://arxiv.org/abs/2104.00298",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_s_ra2_imagenet/2",
    },
    "efficientnet2_rw_t_ra2_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet-v2 Tiny model trained on the ImageNet 1k "
                "dataset with RandAugment2 recipe."
            ),
            "params": 13649388,
            "official_name": "EfficientNet",
            "path": "efficientnet",
            "model_card": "https://arxiv.org/abs/2104.00298",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_t_ra2_imagenet/2",
    },
    "efficientnet_lite0_ra_imagenet": {
        "metadata": {
            "description": (
                "EfficientNet-Lite model fine-tuned on the ImageNet 1k "
                "dataset with RandAugment recipe."
            ),
            "params": 4652008,
            "path": "efficientnet",
        },
        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_lite0_ra_imagenet/2",
    },
}
@@ -2,24 +2,13 @@ import keras
2
2
 
3
3
  BN_AXIS = 3
4
4
 
5
- CONV_KERNEL_INITIALIZER = {
6
- "class_name": "VarianceScaling",
7
- "config": {
8
- "scale": 2.0,
9
- "mode": "fan_out",
10
- "distribution": "truncated_normal",
11
- },
12
- }
13
-
14
5
 
15
6
  class FusedMBConvBlock(keras.layers.Layer):
16
7
  """Implementation of the FusedMBConv block
17
8
 
18
9
  Also known as a Fused Mobile Inverted Residual Bottleneck block from:
19
- [EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML]
20
- (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
21
- [EfficientNetV2: Smaller Models and Faster Training]
22
- (https://arxiv.org/abs/2104.00298v3).
10
+ [EfficientNet-EdgeTPU](https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
11
+ [EfficientNetV2: Smaller Models and Faster Training](https://arxiv.org/abs/2104.00298v3).
23
12
 
24
13
  FusedMBConv blocks are based on MBConv blocks, and replace the depthwise and
25
14
  1x1 output convolution blocks with a single 3x3 convolution block, fusing
@@ -44,13 +33,24 @@ class FusedMBConvBlock(keras.layers.Layer):
44
33
  convolutions
45
34
  strides: default 1, the strides to apply to the expansion phase
46
35
  convolutions
36
+ data_format: str, channels_last (default) or channels_first, expects
37
+ tensors to be of shape (N, H, W, C) or (N, C, H, W) respectively
47
38
  se_ratio: default 0.0, The filters used in the Squeeze-Excitation phase,
48
39
  and are chosen as the maximum between 1 and input_filters*se_ratio
49
40
  batch_norm_momentum: default 0.9, the BatchNormalization momentum
41
+ batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
42
+ calcualtions. Used in denominator for calculations to prevent divide
43
+ by 0 errors.
50
44
  activation: default "swish", the activation function used between
51
45
  convolution operations
46
+ projection_activation: default None, the activation function to use
47
+ after the output projection convoultion
52
48
  dropout: float, the optional dropout rate to apply before the output
53
49
  convolution, defaults to 0.2
50
+ nores: bool, default False, forces no residual connection if True,
51
+ otherwise allows it if False.
52
+ projection_kernel_size: default 1, the kernel_size to apply to the
53
+ output projection phase convolution
54
54
 
55
55
  Returns:
56
56
  A tensor representing a feature map, passed through the FusedMBConv
@@ -67,11 +67,16 @@ class FusedMBConvBlock(keras.layers.Layer):
67
67
  expand_ratio=1,
68
68
  kernel_size=3,
69
69
  strides=1,
70
+ data_format="channels_last",
70
71
  se_ratio=0.0,
71
72
  batch_norm_momentum=0.9,
73
+ batch_norm_epsilon=1e-3,
72
74
  activation="swish",
75
+ projection_activation=None,
73
76
  dropout=0.2,
74
- **kwargs
77
+ nores=False,
78
+ projection_kernel_size=1,
79
+ **kwargs,
75
80
  ):
76
81
  super().__init__(**kwargs)
77
82
  self.input_filters = input_filters
@@ -79,44 +84,50 @@ class FusedMBConvBlock(keras.layers.Layer):
79
84
  self.expand_ratio = expand_ratio
80
85
  self.kernel_size = kernel_size
81
86
  self.strides = strides
87
+ self.data_format = data_format
82
88
  self.se_ratio = se_ratio
83
89
  self.batch_norm_momentum = batch_norm_momentum
90
+ self.batch_norm_epsilon = batch_norm_epsilon
84
91
  self.activation = activation
92
+ self.projection_activation = projection_activation
85
93
  self.dropout = dropout
94
+ self.nores = nores
95
+ self.projection_kernel_size = projection_kernel_size
86
96
  self.filters = self.input_filters * self.expand_ratio
87
97
  self.filters_se = max(1, int(input_filters * se_ratio))
88
98
 
99
+ padding_pixels = kernel_size // 2
100
+ self.conv1_pad = keras.layers.ZeroPadding2D(
101
+ padding=(padding_pixels, padding_pixels),
102
+ name=self.name + "expand_conv_pad",
103
+ )
89
104
  self.conv1 = keras.layers.Conv2D(
90
105
  filters=self.filters,
91
106
  kernel_size=kernel_size,
92
107
  strides=strides,
93
- kernel_initializer=CONV_KERNEL_INITIALIZER,
94
- padding="same",
95
- data_format="channels_last",
108
+ kernel_initializer=self._conv_kernel_initializer(),
109
+ padding="valid",
110
+ data_format=data_format,
96
111
  use_bias=False,
97
112
  name=self.name + "expand_conv",
98
113
  )
99
114
  self.bn1 = keras.layers.BatchNormalization(
100
115
  axis=BN_AXIS,
101
116
  momentum=self.batch_norm_momentum,
117
+ epsilon=self.batch_norm_epsilon,
102
118
  name=self.name + "expand_bn",
103
119
  )
104
120
  self.act = keras.layers.Activation(
105
121
  self.activation, name=self.name + "expand_activation"
106
122
  )
107
123
 
108
- self.bn2 = keras.layers.BatchNormalization(
109
- axis=BN_AXIS,
110
- momentum=self.batch_norm_momentum,
111
- name=self.name + "bn",
112
- )
113
-
114
124
  self.se_conv1 = keras.layers.Conv2D(
115
125
  self.filters_se,
116
126
  1,
117
127
  padding="same",
128
+ data_format=data_format,
118
129
  activation=self.activation,
119
- kernel_initializer=CONV_KERNEL_INITIALIZER,
130
+ kernel_initializer=self._conv_kernel_initializer(),
120
131
  name=self.name + "se_reduce",
121
132
  )
122
133
 
@@ -124,28 +135,40 @@ class FusedMBConvBlock(keras.layers.Layer):
124
135
  self.filters,
125
136
  1,
126
137
  padding="same",
138
+ data_format=data_format,
127
139
  activation="sigmoid",
128
- kernel_initializer=CONV_KERNEL_INITIALIZER,
140
+ kernel_initializer=self._conv_kernel_initializer(),
129
141
  name=self.name + "se_expand",
130
142
  )
131
143
 
144
+ padding_pixels = projection_kernel_size // 2
145
+ self.output_conv_pad = keras.layers.ZeroPadding2D(
146
+ padding=(padding_pixels, padding_pixels),
147
+ name=self.name + "project_conv_pad",
148
+ )
132
149
  self.output_conv = keras.layers.Conv2D(
133
150
  filters=self.output_filters,
134
- kernel_size=1 if expand_ratio != 1 else kernel_size,
151
+ kernel_size=projection_kernel_size,
135
152
  strides=1,
136
- kernel_initializer=CONV_KERNEL_INITIALIZER,
137
- padding="same",
138
- data_format="channels_last",
153
+ kernel_initializer=self._conv_kernel_initializer(),
154
+ padding="valid",
155
+ data_format=data_format,
139
156
  use_bias=False,
140
157
  name=self.name + "project_conv",
141
158
  )
142
159
 
143
- self.bn3 = keras.layers.BatchNormalization(
160
+ self.bn2 = keras.layers.BatchNormalization(
144
161
  axis=BN_AXIS,
145
162
  momentum=self.batch_norm_momentum,
163
+ epsilon=self.batch_norm_epsilon,
146
164
  name=self.name + "project_bn",
147
165
  )
148
166
 
167
+ if self.projection_activation:
168
+ self.projection_act = keras.layers.Activation(
169
+ self.projection_activation, name=self.name + "projection_act"
170
+ )
171
+
149
172
  if self.dropout:
150
173
  self.dropout_layer = keras.layers.Dropout(
151
174
  self.dropout,
@@ -153,23 +176,33 @@ class FusedMBConvBlock(keras.layers.Layer):
153
176
  name=self.name + "drop",
154
177
  )
155
178
 
179
+ def _conv_kernel_initializer(
180
+ self,
181
+ scale=2.0,
182
+ mode="fan_out",
183
+ distribution="truncated_normal",
184
+ seed=None,
185
+ ):
186
+ return keras.initializers.VarianceScaling(
187
+ scale=scale, mode=mode, distribution=distribution, seed=seed
188
+ )
189
+
156
190
  def build(self, input_shape):
157
191
  if self.name is None:
158
192
  self.name = keras.backend.get_uid("block0")
159
193
 
160
194
  def call(self, inputs):
161
195
  # Expansion phase
162
- if self.expand_ratio != 1:
163
- x = self.conv1(inputs)
164
- x = self.bn1(x)
165
- x = self.act(x)
166
- else:
167
- x = inputs
196
+ x = self.conv1_pad(inputs)
197
+ x = self.conv1(x)
198
+ x = self.bn1(x)
199
+ x = self.act(x)
168
200
 
169
201
  # Squeeze and excite
170
202
  if 0 < self.se_ratio <= 1:
171
203
  se = keras.layers.GlobalAveragePooling2D(
172
- name=self.name + "se_squeeze"
204
+ name=self.name + "se_squeeze",
205
+ data_format=self.data_format,
173
206
  )(x)
174
207
  if BN_AXIS == 1:
175
208
  se_shape = (self.filters, 1, 1)
@@ -186,13 +219,18 @@ class FusedMBConvBlock(keras.layers.Layer):
186
219
  x = keras.layers.multiply([x, se], name=self.name + "se_excite")
187
220
 
188
221
  # Output phase:
222
+ x = self.output_conv_pad(x)
189
223
  x = self.output_conv(x)
190
- x = self.bn3(x)
191
- if self.expand_ratio == 1:
192
- x = self.act(x)
224
+ x = self.bn2(x)
225
+ if self.expand_ratio == 1 and self.projection_activation:
226
+ x = self.projection_act(x)
193
227
 
194
228
  # Residual:
195
- if self.strides == 1 and self.input_filters == self.output_filters:
229
+ if (
230
+ self.strides == 1
231
+ and self.input_filters == self.output_filters
232
+ and not self.nores
233
+ ):
196
234
  if self.dropout:
197
235
  x = self.dropout_layer(x)
198
236
  x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -205,10 +243,15 @@ class FusedMBConvBlock(keras.layers.Layer):
205
243
  "expand_ratio": self.expand_ratio,
206
244
  "kernel_size": self.kernel_size,
207
245
  "strides": self.strides,
246
+ "data_format": self.data_format,
208
247
  "se_ratio": self.se_ratio,
209
248
  "batch_norm_momentum": self.batch_norm_momentum,
249
+ "batch_norm_epsilon": self.batch_norm_epsilon,
210
250
  "activation": self.activation,
251
+ "projection_activation": self.projection_activation,
211
252
  "dropout": self.dropout,
253
+ "nores": self.nores,
254
+ "projection_kernel_size": self.projection_kernel_size,
212
255
  }
213
256
 
214
257
  base_config = super().get_config()
@@ -2,15 +2,6 @@ import keras
2
2
 
3
3
  BN_AXIS = 3
4
4
 
5
- CONV_KERNEL_INITIALIZER = {
6
- "class_name": "VarianceScaling",
7
- "config": {
8
- "scale": 2.0,
9
- "mode": "fan_out",
10
- "distribution": "truncated_normal",
11
- },
12
- }
13
-
14
5
 
15
6
  class MBConvBlock(keras.layers.Layer):
16
7
  def __init__(
@@ -20,11 +11,14 @@ class MBConvBlock(keras.layers.Layer):
20
11
  expand_ratio=1,
21
12
  kernel_size=3,
22
13
  strides=1,
14
+ data_format="channels_last",
23
15
  se_ratio=0.0,
24
16
  batch_norm_momentum=0.9,
17
+ batch_norm_epsilon=1e-3,
25
18
  activation="swish",
26
19
  dropout=0.2,
27
- **kwargs
20
+ nores=False,
21
+ **kwargs,
28
22
  ):
29
23
  """Implementation of the MBConv block
30
24
 
@@ -59,6 +53,9 @@ class MBConvBlock(keras.layers.Layer):
59
53
  is above 0. The filters used in this phase are chosen as the
60
54
  maximum between 1 and input_filters*se_ratio
61
55
  batch_norm_momentum: default 0.9, the BatchNormalization momentum
56
+ batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
57
+ calcualtions. Used in denominator for calculations to prevent
58
+ divide by 0 errors.
62
59
  activation: default "swish", the activation function used between
63
60
  convolution operations
64
61
  dropout: float, the optional dropout rate to apply before the output
@@ -79,10 +76,13 @@ class MBConvBlock(keras.layers.Layer):
79
76
  self.expand_ratio = expand_ratio
80
77
  self.kernel_size = kernel_size
81
78
  self.strides = strides
79
+ self.data_format = data_format
82
80
  self.se_ratio = se_ratio
83
81
  self.batch_norm_momentum = batch_norm_momentum
82
+ self.batch_norm_epsilon = batch_norm_epsilon
84
83
  self.activation = activation
85
84
  self.dropout = dropout
85
+ self.nores = nores
86
86
  self.filters = self.input_filters * self.expand_ratio
87
87
  self.filters_se = max(1, int(input_filters * se_ratio))
88
88
 
@@ -90,15 +90,16 @@ class MBConvBlock(keras.layers.Layer):
90
90
  filters=self.filters,
91
91
  kernel_size=1,
92
92
  strides=1,
93
- kernel_initializer=CONV_KERNEL_INITIALIZER,
93
+ kernel_initializer=self._conv_kernel_initializer(),
94
94
  padding="same",
95
- data_format="channels_last",
95
+ data_format=data_format,
96
96
  use_bias=False,
97
97
  name=self.name + "expand_conv",
98
98
  )
99
99
  self.bn1 = keras.layers.BatchNormalization(
100
100
  axis=BN_AXIS,
101
101
  momentum=self.batch_norm_momentum,
102
+ epsilon=self.batch_norm_epsilon,
102
103
  name=self.name + "expand_bn",
103
104
  )
104
105
  self.act = keras.layers.Activation(
@@ -107,9 +108,9 @@ class MBConvBlock(keras.layers.Layer):
107
108
  self.depthwise = keras.layers.DepthwiseConv2D(
108
109
  kernel_size=self.kernel_size,
109
110
  strides=self.strides,
110
- depthwise_initializer=CONV_KERNEL_INITIALIZER,
111
+ depthwise_initializer=self._conv_kernel_initializer(),
111
112
  padding="same",
112
- data_format="channels_last",
113
+ data_format=data_format,
113
114
  use_bias=False,
114
115
  name=self.name + "dwconv2",
115
116
  )
@@ -117,6 +118,7 @@ class MBConvBlock(keras.layers.Layer):
117
118
  self.bn2 = keras.layers.BatchNormalization(
118
119
  axis=BN_AXIS,
119
120
  momentum=self.batch_norm_momentum,
121
+ epsilon=self.batch_norm_epsilon,
120
122
  name=self.name + "bn",
121
123
  )
122
124
 
@@ -124,8 +126,9 @@ class MBConvBlock(keras.layers.Layer):
124
126
  self.filters_se,
125
127
  1,
126
128
  padding="same",
129
+ data_format=data_format,
127
130
  activation=self.activation,
128
- kernel_initializer=CONV_KERNEL_INITIALIZER,
131
+ kernel_initializer=self._conv_kernel_initializer(),
129
132
  name=self.name + "se_reduce",
130
133
  )
131
134
 
@@ -133,18 +136,25 @@ class MBConvBlock(keras.layers.Layer):
133
136
  self.filters,
134
137
  1,
135
138
  padding="same",
139
+ data_format=data_format,
136
140
  activation="sigmoid",
137
- kernel_initializer=CONV_KERNEL_INITIALIZER,
141
+ kernel_initializer=self._conv_kernel_initializer(),
138
142
  name=self.name + "se_expand",
139
143
  )
140
144
 
145
+ projection_kernel_size = 1 if expand_ratio != 1 else kernel_size
146
+ padding_pixels = projection_kernel_size // 2
147
+ self.output_conv_pad = keras.layers.ZeroPadding2D(
148
+ padding=(padding_pixels, padding_pixels),
149
+ name=self.name + "project_conv_pad",
150
+ )
141
151
  self.output_conv = keras.layers.Conv2D(
142
152
  filters=self.output_filters,
143
- kernel_size=1 if expand_ratio != 1 else kernel_size,
153
+ kernel_size=projection_kernel_size,
144
154
  strides=1,
145
- kernel_initializer=CONV_KERNEL_INITIALIZER,
146
- padding="same",
147
- data_format="channels_last",
155
+ kernel_initializer=self._conv_kernel_initializer(),
156
+ padding="valid",
157
+ data_format=data_format,
148
158
  use_bias=False,
149
159
  name=self.name + "project_conv",
150
160
  )
@@ -152,6 +162,7 @@ class MBConvBlock(keras.layers.Layer):
152
162
  self.bn3 = keras.layers.BatchNormalization(
153
163
  axis=BN_AXIS,
154
164
  momentum=self.batch_norm_momentum,
165
+ epsilon=self.batch_norm_epsilon,
155
166
  name=self.name + "project_bn",
156
167
  )
157
168
 
@@ -162,6 +173,17 @@ class MBConvBlock(keras.layers.Layer):
162
173
  name=self.name + "drop",
163
174
  )
164
175
 
176
+ def _conv_kernel_initializer(
177
+ self,
178
+ scale=2.0,
179
+ mode="fan_out",
180
+ distribution="truncated_normal",
181
+ seed=None,
182
+ ):
183
+ return keras.initializers.VarianceScaling(
184
+ scale=scale, mode=mode, distribution=distribution, seed=seed
185
+ )
186
+
165
187
  def build(self, input_shape):
166
188
  if self.name is None:
167
189
  self.name = keras.backend.get_uid("block0")
@@ -183,7 +205,8 @@ class MBConvBlock(keras.layers.Layer):
183
205
  # Squeeze and excite
184
206
  if 0 < self.se_ratio <= 1:
185
207
  se = keras.layers.GlobalAveragePooling2D(
186
- name=self.name + "se_squeeze"
208
+ name=self.name + "se_squeeze",
209
+ data_format=self.data_format,
187
210
  )(x)
188
211
  if BN_AXIS == 1:
189
212
  se_shape = (self.filters, 1, 1)
@@ -199,10 +222,15 @@ class MBConvBlock(keras.layers.Layer):
199
222
  x = keras.layers.multiply([x, se], name=self.name + "se_excite")
200
223
 
201
224
  # Output phase
225
+ x = self.output_conv_pad(x)
202
226
  x = self.output_conv(x)
203
227
  x = self.bn3(x)
204
228
 
205
- if self.strides == 1 and self.input_filters == self.output_filters:
229
+ if (
230
+ self.strides == 1
231
+ and self.input_filters == self.output_filters
232
+ and not self.nores
233
+ ):
206
234
  if self.dropout:
207
235
  x = self.dropout_layer(x)
208
236
  x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -215,10 +243,13 @@ class MBConvBlock(keras.layers.Layer):
215
243
  "expand_ratio": self.expand_ratio,
216
244
  "kernel_size": self.kernel_size,
217
245
  "strides": self.strides,
246
+ "data_format": self.data_format,
218
247
  "se_ratio": self.se_ratio,
219
248
  "batch_norm_momentum": self.batch_norm_momentum,
249
+ "batch_norm_epsilon": self.batch_norm_epsilon,
220
250
  "activation": self.activation,
221
251
  "dropout": self.dropout,
252
+ "nores": self.nores,
222
253
  }
223
254
  base_config = super().get_config()
224
255
  return dict(list(base_config.items()) + list(config.items()))
@@ -186,8 +186,8 @@ class ElectraBackbone(Backbone):
186
186
  # Index of classification token in the vocabulary
187
187
  cls_token_index = 0
188
188
  sequence_output = x
189
- # Construct the two ELECTRA outputs. The pooled output is a dense layer on
190
- # top of the [CLS] token.
189
+ # Construct the two ELECTRA outputs. The pooled output is a dense layer
190
+ # on top of the [CLS] token.
191
191
  pooled_output = self.pooled_dense(x[:, cls_token_index, :])
192
192
  super().__init__(
193
193
  inputs={