keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/resnet/resnet_backbone.py
@@ -0,0 +1,612 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import keras
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ @keras_hub_export("keras_hub.models.ResNetBackbone")
+ class ResNetBackbone(FeaturePyramidBackbone):
+     """ResNet and ResNetV2 core network with hyperparameters.
+
+     This class implements a ResNet backbone as described in [Deep Residual
+     Learning for Image Recognition](https://arxiv.org/abs/1512.03385)(
+     CVPR 2016), [Identity Mappings in Deep Residual Networks](
+     https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An
+     improved training procedure in timm](https://arxiv.org/abs/2110.00476)(
+     NeurIPS 2021 Workshop).
+
+     The difference between ResNet and ResNetV2 lies in the structure of their
+     individual building blocks. In ResNetV2, the batch normalization and
+     ReLU activation precede the convolution layers, as opposed to ResNet where
+     the batch normalization and ReLU activation are applied after the
+     convolution layers.
+
+     Note that `ResNetBackbone` expects the inputs to be images with a value
+     range of `[0, 255]` when `include_rescaling=True`.
+
+     Args:
+         stackwise_num_filters: list of ints. The number of filters for each
+             stack.
+         stackwise_num_blocks: list of ints. The number of blocks for each stack.
+         stackwise_num_strides: list of ints. The number of strides for each
+             stack.
+         block_type: str. The block type to stack. One of `"basic_block"` or
+             `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
+             Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
+         use_pre_activation: boolean. Whether to use pre-activation or not.
+             `True` for ResNetV2, `False` for ResNet.
+         include_rescaling: boolean. If `True`, rescale the input using
+             `Rescaling` and `Normalization` layers. If `False`, do nothing.
+             Defaults to `True`.
+         image_shape: tuple. The input shape without the batch size.
+             Defaults to `(None, None, 3)`.
+         pooling: `None` or str. Pooling mode for feature extraction. Defaults
+             to `"avg"`.
+             - `None` means that the output of the model will be the 4D tensor
+                 from the last convolutional block.
+             - `avg` means that global average pooling will be applied to the
+                 output of the last convolutional block, resulting in a 2D
+                 tensor.
+             - `max` means that global max pooling will be applied to the
+                 output of the last convolutional block, resulting in a 2D
+                 tensor.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+             to use for the model's computations and weights.
+
+     Examples:
+     ```python
+     input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3))
+
+     # Pretrained ResNet backbone.
+     model = keras_hub.models.ResNetBackbone.from_preset("resnet50")
+     model(input_data)
+
+     # Randomly initialized ResNetV2 backbone with a custom config.
+     model = keras_hub.models.ResNetBackbone(
+         stackwise_num_filters=[64, 64, 64],
+         stackwise_num_blocks=[2, 2, 2],
+         stackwise_num_strides=[1, 2, 2],
+         block_type="basic_block",
+         use_pre_activation=True,
+         pooling="avg",
+     )
+     model(input_data)
+     ```
+     """
+
+     def __init__(
+         self,
+         stackwise_num_filters,
+         stackwise_num_blocks,
+         stackwise_num_strides,
+         block_type,
+         use_pre_activation=False,
+         include_rescaling=True,
+         image_shape=(None, None, 3),
+         pooling="avg",
+         data_format=None,
+         dtype=None,
+         **kwargs,
+     ):
+         if len(stackwise_num_filters) != len(stackwise_num_blocks) or len(
+             stackwise_num_filters
+         ) != len(stackwise_num_strides):
+             raise ValueError(
+                 "The length of `stackwise_num_filters`, `stackwise_num_blocks` "
+                 "and `stackwise_num_strides` must be the same. Received: "
+                 f"stackwise_num_filters={stackwise_num_filters}, "
+                 f"stackwise_num_blocks={stackwise_num_blocks}, "
+                 f"stackwise_num_strides={stackwise_num_strides}"
+             )
+         if stackwise_num_filters[0] != 64:
+             raise ValueError(
+                 "The first element of `stackwise_num_filters` must be 64. "
+                 f"Received: stackwise_num_filters={stackwise_num_filters}"
+             )
+         if block_type not in ("basic_block", "bottleneck_block"):
+             raise ValueError(
+                 '`block_type` must be either `"basic_block"` or '
+                 f'`"bottleneck_block"`. Received block_type={block_type}.'
+             )
+         version = "v1" if not use_pre_activation else "v2"
+         data_format = standardize_data_format(data_format)
+         bn_axis = -1 if data_format == "channels_last" else 1
+         num_stacks = len(stackwise_num_filters)
+
+         # === Functional Model ===
+         image_input = layers.Input(shape=image_shape)
+         if include_rescaling:
+             x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input)
+             x = layers.Normalization(
+                 axis=bn_axis,
+                 mean=(0.485, 0.456, 0.406),
+                 variance=(0.229**2, 0.224**2, 0.225**2),
+                 dtype=dtype,
+                 name="normalization",
+             )(x)
+         else:
+             x = image_input
+
+         # The padding between torch and tensorflow/jax differs when `strides>1`.
+         # Therefore, we need to manually pad the tensor.
+         x = layers.ZeroPadding2D(
+             3,
+             data_format=data_format,
+             dtype=dtype,
+             name="conv1_pad",
+         )(x)
+         x = layers.Conv2D(
+             64,
+             7,
+             strides=2,
+             data_format=data_format,
+             use_bias=False,
+             dtype=dtype,
+             name="conv1_conv",
+         )(x)
+         if not use_pre_activation:
+             x = layers.BatchNormalization(
+                 axis=bn_axis,
+                 epsilon=1e-5,
+                 momentum=0.9,
+                 dtype=dtype,
+                 name="conv1_bn",
+             )(x)
+             x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x)
+
+         if use_pre_activation:
+             # A workaround for ResNetV2: we need -inf padding to prevent zeros
+             # from being the max values in the following `MaxPooling2D`.
+             pad_width = [[1, 1], [1, 1]]
+             if data_format == "channels_last":
+                 pad_width += [[0, 0]]
+             else:
+                 pad_width = [[0, 0]] + pad_width
+             pad_width = [[0, 0]] + pad_width
+             x = ops.pad(x, pad_width=pad_width, constant_values=float("-inf"))
+         else:
+             x = layers.ZeroPadding2D(
+                 1, data_format=data_format, dtype=dtype, name="pool1_pad"
+             )(x)
+         x = layers.MaxPooling2D(
+             3,
+             strides=2,
+             data_format=data_format,
+             dtype=dtype,
+             name="pool1_pool",
+         )(x)
+
+         pyramid_outputs = {}
+         for stack_index in range(num_stacks):
+             x = apply_stack(
+                 x,
+                 filters=stackwise_num_filters[stack_index],
+                 blocks=stackwise_num_blocks[stack_index],
+                 stride=stackwise_num_strides[stack_index],
+                 block_type=block_type,
+                 use_pre_activation=use_pre_activation,
+                 first_shortcut=(
+                     block_type == "bottleneck_block" or stack_index > 0
+                 ),
+                 data_format=data_format,
+                 dtype=dtype,
+                 name=f"{version}_stack{stack_index}",
+             )
+             pyramid_outputs[f"P{stack_index + 2}"] = x
+
+         if use_pre_activation:
+             x = layers.BatchNormalization(
+                 axis=bn_axis,
+                 epsilon=1e-5,
+                 momentum=0.9,
+                 dtype=dtype,
+                 name="post_bn",
+             )(x)
+             x = layers.Activation("relu", dtype=dtype, name="post_relu")(x)
+
+         if pooling == "avg":
+             feature_map_output = layers.GlobalAveragePooling2D(
+                 data_format=data_format, dtype=dtype
+             )(x)
+         elif pooling == "max":
+             feature_map_output = layers.GlobalMaxPooling2D(
+                 data_format=data_format, dtype=dtype
+             )(x)
+         else:
+             feature_map_output = x
+
+         super().__init__(
+             inputs=image_input,
+             outputs=feature_map_output,
+             dtype=dtype,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.stackwise_num_filters = stackwise_num_filters
+         self.stackwise_num_blocks = stackwise_num_blocks
+         self.stackwise_num_strides = stackwise_num_strides
+         self.block_type = block_type
+         self.use_pre_activation = use_pre_activation
+         self.include_rescaling = include_rescaling
+         self.image_shape = image_shape
+         self.pooling = pooling
+         self.pyramid_outputs = pyramid_outputs
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "stackwise_num_filters": self.stackwise_num_filters,
+                 "stackwise_num_blocks": self.stackwise_num_blocks,
+                 "stackwise_num_strides": self.stackwise_num_strides,
+                 "block_type": self.block_type,
+                 "use_pre_activation": self.use_pre_activation,
+                 "include_rescaling": self.include_rescaling,
+                 "image_shape": self.image_shape,
+                 "pooling": self.pooling,
+             }
+         )
+         return config
+
+
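Each stack's output recorded in `pyramid_outputs` above (keyed `P2`, `P3`, ...) is what makes this class usable as a `FeaturePyramidBackbone` feature extractor. For reference, a minimal usage sketch (editor's illustration, not part of the diffed file; it assumes the randomly initialized config from the class docstring and a `channels_last` Keras setup):

```python
import numpy as np
import keras
import keras_hub

# Small, randomly initialized ResNet-style backbone with three stacks.
backbone = keras_hub.models.ResNetBackbone(
    stackwise_num_filters=[64, 64, 64],
    stackwise_num_blocks=[2, 2, 2],
    stackwise_num_strides=[1, 2, 2],
    block_type="basic_block",
    use_pre_activation=False,
)

# Re-wire the functional graph to expose the pyramid levels P2..P4.
feature_extractor = keras.Model(
    inputs=backbone.inputs,
    outputs=backbone.pyramid_outputs,
)

images = np.random.uniform(0, 255, size=(2, 224, 224, 3))
features = feature_extractor(images)
for level, tensor in features.items():
    # Expect roughly: P2 -> (2, 56, 56, 64), P3 -> (2, 28, 28, 64), P4 -> (2, 14, 14, 64)
    print(level, tensor.shape)
```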
+ def apply_basic_block(
+     x,
+     filters,
+     kernel_size=3,
+     stride=1,
+     conv_shortcut=False,
+     use_pre_activation=False,
+     data_format=None,
+     dtype=None,
+     name=None,
+ ):
+     """Applies a basic residual block.
+
+     Args:
+         x: Tensor. The input tensor to pass through the block.
+         filters: int. The number of filters in the block.
+         kernel_size: int. The kernel size of the convolution layers. Defaults
+             to `3`.
+         stride: int. The stride length of the first layer. Defaults to `1`.
+         conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
+             use an identity or pooling shortcut based on the stride. Defaults to
+             `False`.
+         use_pre_activation: boolean. Whether to use pre-activation or not.
+             `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
+         data_format: `None` or str. The ordering of the dimensions in the
+             inputs. Can be `"channels_last"`
+             (`(batch_size, height, width, channels)`) or `"channels_first"`
+             (`(batch_size, channels, height, width)`).
+         dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+             to use for the model's computations and weights.
+         name: str. A prefix for the layer names used in the block.
+
+     Returns:
+         The output tensor for the basic residual block.
+     """
+     data_format = data_format or keras.config.image_data_format()
+     bn_axis = -1 if data_format == "channels_last" else 1
+
+     x_preact = None
+     if use_pre_activation:
+         x_preact = layers.BatchNormalization(
+             axis=bn_axis,
+             epsilon=1e-5,
+             momentum=0.9,
+             dtype=dtype,
+             name=f"{name}_pre_activation_bn",
+         )(x)
+         x_preact = layers.Activation(
+             "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
+         )(x_preact)
+
+     if conv_shortcut:
+         x = x_preact if x_preact is not None else x
+         shortcut = layers.Conv2D(
+             filters,
+             1,
+             strides=stride,
+             data_format=data_format,
+             use_bias=False,
+             dtype=dtype,
+             name=f"{name}_0_conv",
+         )(x)
+         if not use_pre_activation:
+             shortcut = layers.BatchNormalization(
+                 axis=bn_axis,
+                 epsilon=1e-5,
+                 momentum=0.9,
+                 dtype=dtype,
+                 name=f"{name}_0_bn",
+             )(shortcut)
+     else:
+         shortcut = x
+
+     x = x_preact if x_preact is not None else x
+     if stride > 1:
+         x = layers.ZeroPadding2D(
+             (kernel_size - 1) // 2,
+             data_format=data_format,
+             dtype=dtype,
+             name=f"{name}_1_pad",
+         )(x)
+     x = layers.Conv2D(
+         filters,
+         kernel_size,
+         strides=stride,
+         padding="valid" if stride > 1 else "same",
+         data_format=data_format,
+         use_bias=False,
+         dtype=dtype,
+         name=f"{name}_1_conv",
+     )(x)
+     x = layers.BatchNormalization(
+         axis=bn_axis,
+         epsilon=1e-5,
+         momentum=0.9,
+         dtype=dtype,
+         name=f"{name}_1_bn",
+     )(x)
+     x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
+
+     x = layers.Conv2D(
+         filters,
+         kernel_size,
+         strides=1,
+         padding="same",
+         data_format=data_format,
+         use_bias=False,
+         dtype=dtype,
+         name=f"{name}_2_conv",
+     )(x)
+     if not use_pre_activation:
+         x = layers.BatchNormalization(
+             axis=bn_axis,
+             epsilon=1e-5,
+             momentum=0.9,
+             dtype=dtype,
+             name=f"{name}_2_bn",
+         )(x)
+         x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
+         x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
+     else:
+         x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
+     return x
+
+
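For orientation: with `stride=2` and `conv_shortcut=True`, the block above halves the spatial resolution and projects the shortcut with a strided 1x1 convolution so the residual addition still lines up. A short sketch (editor's illustration; it imports from the private `keras_hub.src` module shown in this diff, so treat it as a demonstration rather than public API):

```python
import keras
from keras_hub.src.models.resnet.resnet_backbone import apply_basic_block

# Build the block symbolically on a channels_last feature map.
inputs = keras.Input(shape=(56, 56, 64))
outputs = apply_basic_block(
    inputs,
    filters=128,
    stride=2,
    conv_shortcut=True,
    name="demo_basic",
)
print(outputs.shape)  # Expected: (None, 28, 28, 128)
```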
+ def apply_bottleneck_block(
+     x,
+     filters,
+     kernel_size=3,
+     stride=1,
+     conv_shortcut=False,
+     use_pre_activation=False,
+     data_format=None,
+     dtype=None,
+     name=None,
+ ):
+     """Applies a bottleneck residual block.
+
+     Args:
+         x: Tensor. The input tensor to pass through the block.
+         filters: int. The number of filters in the block.
+         kernel_size: int. The kernel size of the bottleneck layer. Defaults to
+             `3`.
+         stride: int. The stride length of the first layer. Defaults to `1`.
+         conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
+             use an identity or pooling shortcut based on the stride. Defaults to
+             `False`.
+         use_pre_activation: boolean. Whether to use pre-activation or not.
+             `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
+         data_format: `None` or str. The ordering of the dimensions in the
+             inputs. Can be `"channels_last"`
+             (`(batch_size, height, width, channels)`) or `"channels_first"`
+             (`(batch_size, channels, height, width)`).
+         dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+             to use for the model's computations and weights.
+         name: str. A prefix for the layer names used in the block.
+
+     Returns:
+         The output tensor for the residual block.
+     """
+     data_format = data_format or keras.config.image_data_format()
+     bn_axis = -1 if data_format == "channels_last" else 1
+
+     x_preact = None
+     if use_pre_activation:
+         x_preact = layers.BatchNormalization(
+             axis=bn_axis,
+             epsilon=1e-5,
+             momentum=0.9,
+             dtype=dtype,
+             name=f"{name}_pre_activation_bn",
+         )(x)
+         x_preact = layers.Activation(
+             "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
+         )(x_preact)
+
+     if conv_shortcut:
+         x = x_preact if x_preact is not None else x
+         shortcut = layers.Conv2D(
+             4 * filters,
+             1,
+             strides=stride,
+             data_format=data_format,
+             use_bias=False,
+             dtype=dtype,
+             name=f"{name}_0_conv",
+         )(x)
+         if not use_pre_activation:
+             shortcut = layers.BatchNormalization(
+                 axis=bn_axis,
+                 epsilon=1e-5,
+                 momentum=0.9,
+                 dtype=dtype,
+                 name=f"{name}_0_bn",
+             )(shortcut)
+     else:
+         shortcut = x
+
+     x = x_preact if x_preact is not None else x
+     x = layers.Conv2D(
+         filters,
+         1,
+         strides=1,
+         data_format=data_format,
+         use_bias=False,
+         dtype=dtype,
+         name=f"{name}_1_conv",
+     )(x)
+     x = layers.BatchNormalization(
+         axis=bn_axis,
+         epsilon=1e-5,
+         momentum=0.9,
+         dtype=dtype,
+         name=f"{name}_1_bn",
+     )(x)
+     x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
+
+     if stride > 1:
+         x = layers.ZeroPadding2D(
+             (kernel_size - 1) // 2,
+             data_format=data_format,
+             dtype=dtype,
+             name=f"{name}_2_pad",
+         )(x)
+     x = layers.Conv2D(
+         filters,
+         kernel_size,
+         strides=stride,
+         padding="valid" if stride > 1 else "same",
+         data_format=data_format,
+         use_bias=False,
+         dtype=dtype,
+         name=f"{name}_2_conv",
+     )(x)
+     x = layers.BatchNormalization(
+         axis=bn_axis,
+         epsilon=1e-5,
+         momentum=0.9,
+         dtype=dtype,
+         name=f"{name}_2_bn",
+     )(x)
+     x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x)
+
+     x = layers.Conv2D(
+         4 * filters,
+         1,
+         data_format=data_format,
+         use_bias=False,
+         dtype=dtype,
+         name=f"{name}_3_conv",
+     )(x)
+     if not use_pre_activation:
+         x = layers.BatchNormalization(
+             axis=bn_axis,
+             epsilon=1e-5,
+             momentum=0.9,
+             dtype=dtype,
+             name=f"{name}_3_bn",
+         )(x)
+         x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
+         x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
+     else:
+         x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
+     return x
+
+
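The bottleneck variant differs from the basic block in that its output width is `4 * filters` (1x1 reduce, `kernel_size` spatial conv, 1x1 expand). A hedged sketch of that channel expansion (editor's illustration, again via the private module path, channels_last assumed):

```python
import keras
from keras_hub.src.models.resnet.resnet_backbone import apply_bottleneck_block

inputs = keras.Input(shape=(56, 56, 256))
outputs = apply_bottleneck_block(
    inputs,
    filters=128,
    stride=2,
    conv_shortcut=True,
    name="demo_bottleneck",
)
print(outputs.shape)  # Expected: (None, 28, 28, 512), i.e. 4 * filters channels
```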
+ def apply_stack(
+     x,
+     filters,
+     blocks,
+     stride,
+     block_type,
+     use_pre_activation,
+     first_shortcut=True,
+     data_format=None,
+     dtype=None,
+     name=None,
+ ):
+     """Applies a set of stacked residual blocks.
+
+     Args:
+         x: Tensor. The input tensor to pass through the stack.
+         filters: int. The number of filters in a block.
+         blocks: int. The number of blocks in the stack.
+         stride: int. The stride length of the first layer in the first block.
+         block_type: str. The block type to stack. One of `"basic_block"` or
+             `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
+             Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
+         use_pre_activation: boolean. Whether to use pre-activation or not.
+             `True` for ResNetV2, `False` for ResNet and ResNeXt.
+         first_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
+             use an identity or pooling shortcut based on the stride. Defaults to
+             `True`.
+         data_format: `None` or str. The ordering of the dimensions in the
+             inputs. Can be `"channels_last"`
+             (`(batch_size, height, width, channels)`) or `"channels_first"`
+             (`(batch_size, channels, height, width)`).
+         dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+             to use for the model's computations and weights.
+         name: str. A prefix for the layer names used in the stack.
+
+     Returns:
+         Output tensor for the stacked blocks.
+     """
+     if name is None:
+         version = "v1" if not use_pre_activation else "v2"
+         name = f"{version}_stack"
+
+     if block_type == "basic_block":
+         block_fn = apply_basic_block
+     elif block_type == "bottleneck_block":
+         block_fn = apply_bottleneck_block
+     else:
+         raise ValueError(
+             '`block_type` must be either `"basic_block"` or '
+             f'`"bottleneck_block"`. Received block_type={block_type}.'
+         )
+     for i in range(blocks):
+         if i == 0:
+             stride = stride
+             conv_shortcut = first_shortcut
+         else:
+             stride = 1
+             conv_shortcut = False
+         x = block_fn(
+             x,
+             filters,
+             stride=stride,
+             conv_shortcut=conv_shortcut,
+             use_pre_activation=use_pre_activation,
+             data_format=data_format,
+             dtype=dtype,
+             name=f"{name}_block{str(i)}",
+         )
+     return x
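Note that only the first block in a stack receives the stack's `stride` and the `first_shortcut` projection; every later block runs with `stride=1` and an identity shortcut. A minimal sketch of calling `apply_stack` directly (editor's illustration; private module path, illustrative only):

```python
import keras
from keras_hub.src.models.resnet.resnet_backbone import apply_stack

# One ResNet50-style stage: 3 bottleneck blocks, downsampling once at the start.
inputs = keras.Input(shape=(56, 56, 256))
outputs = apply_stack(
    inputs,
    filters=128,
    blocks=3,
    stride=2,
    block_type="bottleneck_block",
    use_pre_activation=False,
    first_shortcut=True,
    name="demo_stack",
)
print(outputs.shape)  # Expected: (None, 28, 28, 512)
```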