keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,181 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+ import numpy as np
16
+ from keras import ops
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
20
+ from keras_hub.src.models.mix_transformer.mix_transformer_layers import (
21
+ HierarchicalTransformerEncoder,
22
+ )
23
+ from keras_hub.src.models.mix_transformer.mix_transformer_layers import (
24
+ OverlappingPatchingAndEmbedding,
25
+ )
26
+
27
+
28
@keras_hub_export("keras_hub.models.MiTBackbone")
class MiTBackbone(FeaturePyramidBackbone):
    """A Backbone implementing the MixTransformer.

    This architecture is intended to be used as a backbone for the SegFormer
    architecture [SegFormer: Simple and Efficient Design for Semantic
    Segmentation with Transformers](https://arxiv.org/abs/2105.15203)
    [Based on the TensorFlow implementation from DeepVision](
    https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)

    The network is a stack of `num_layers` hierarchical levels. Each level
    applies an `OverlappingPatchingAndEmbedding` layer, `depths[i]`
    `HierarchicalTransformerEncoder` blocks and a `LayerNormalization`, and
    then reshapes the token sequence back into a spatial feature map. The
    feature map of every level is stored in `pyramid_outputs` under the keys
    `"P1"`, `"P2"`, ...; the model output is the feature map of the last
    level.

    Args:
        depths: list of integers, the number of transformer encoders to be
            used for each layer in the network.
        num_layers: int. The number of Transformer layers.
        blockwise_num_heads: list of integers, the number of heads to use
            in the attention computation for each layer.
        blockwise_sr_ratios: list of integers, the sequence reduction
            ratio to perform for each layer on the sequence before key and
            value projections. If set to > 1, a `Conv2D` layer is used to
            reduce the length of the sequence.
        end_value: float. The last value of the linearly spaced sequence
            (starting at `0.0`, with `sum(depths)` entries) from which the
            `drop_prob` of each transformer encoder is taken.
        patch_sizes: list of integers, the patch_size to apply for each
            layer.
        strides: list of integers, stride to apply for each layer.
        include_rescaling: bool, whether to rescale the inputs. If set
            to `True`, inputs will be passed through a `Rescaling(1/255.0)`
            layer. Defaults to `True`.
        image_shape: optional shape tuple, defaults to (224, 224, 3).
        hidden_dims: list of integers, the embedding dims per hierarchical
            layer, used as the levels of the feature pyramid. Required,
            even though it is declared with a `None` default.

    Raises:
        ValueError: if `hidden_dims` is not provided.

    Examples:

    ```python
    images = np.ones(shape=(1, 96, 96, 3))
    backbone = keras_hub.models.MiTBackbone(
        depths=[2, 2, 2, 2],
        num_layers=4,
        blockwise_num_heads=[1, 2, 5, 8],
        blockwise_sr_ratios=[8, 4, 2, 1],
        end_value=0.1,
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        image_shape=(96, 96, 3),
        hidden_dims=[32, 64, 160, 256],
    )

    # Extract the feature map of the last level.
    features = backbone(images)
    ```
    """

    def __init__(
        self,
        depths,
        num_layers,
        blockwise_num_heads,
        blockwise_sr_ratios,
        end_value,
        patch_sizes,
        strides,
        include_rescaling=True,
        image_shape=(224, 224, 3),
        hidden_dims=None,
        **kwargs,
    ):
        # `hidden_dims` is indexed per layer below; fail early with a clear
        # message instead of an opaque "NoneType is not subscriptable".
        if hidden_dims is None:
            raise ValueError(
                "`hidden_dims` must be a list with one embedding dimension "
                f"per layer. Received: hidden_dims={hidden_dims}"
            )

        # One drop probability per transformer encoder, increasing linearly
        # from 0.0 to `end_value` across the whole network.
        dpr = list(np.linspace(0.0, end_value, sum(depths)))

        # === Layers ===
        cur = 0  # Index into `dpr` of the first encoder of the current layer.
        patch_embedding_layers = []
        transformer_blocks = []
        layer_norms = []

        for i in range(num_layers):
            patch_embed_layer = OverlappingPatchingAndEmbedding(
                project_dim=hidden_dims[i],
                patch_size=patch_sizes[i],
                stride=strides[i],
                name=f"patch_and_embed_{i}",
            )
            patch_embedding_layers.append(patch_embed_layer)

            transformer_block = [
                HierarchicalTransformerEncoder(
                    project_dim=hidden_dims[i],
                    num_heads=blockwise_num_heads[i],
                    sr_ratio=blockwise_sr_ratios[i],
                    drop_prob=dpr[cur + k],
                    name=f"hierarchical_encoder_{i}_{k}",
                )
                for k in range(depths[i])
            ]
            transformer_blocks.append(transformer_block)
            cur += depths[i]
            layer_norms.append(keras.layers.LayerNormalization())

        # === Functional Model ===
        image_input = keras.layers.Input(shape=image_shape)
        x = image_input

        if include_rescaling:
            x = keras.layers.Rescaling(scale=1 / 255)(x)

        pyramid_outputs = {}
        for i in range(num_layers):
            # Compute new height/width after the `proj`
            # call in `OverlappingPatchingAndEmbedding`. That projection is
            # a strided `Conv2D` with `padding="same"`, whose output size is
            # ceil(size / stride), so use ceiling division here; flooring
            # would give a wrong reshape target whenever the size is not a
            # multiple of the stride.
            stride = strides[i]
            height, width = ops.shape(x)[1], ops.shape(x)[2]
            new_height = int((height + stride - 1) // stride)
            new_width = int((width + stride - 1) // stride)

            x = patch_embedding_layers[i](x)
            for blk in transformer_blocks[i]:
                x = blk(x)
            x = layer_norms[i](x)
            # Turn the token sequence back into a spatial feature map.
            x = keras.layers.Reshape(
                (new_height, new_width, -1), name=f"output_level_{i}"
            )(x)
            pyramid_outputs[f"P{i + 1}"] = x

        super().__init__(inputs=image_input, outputs=x, **kwargs)

        # === Config ===
        self.depths = depths
        self.include_rescaling = include_rescaling
        self.image_shape = image_shape
        self.hidden_dims = hidden_dims
        self.pyramid_outputs = pyramid_outputs
        self.num_layers = num_layers
        self.blockwise_num_heads = blockwise_num_heads
        self.blockwise_sr_ratios = blockwise_sr_ratios
        self.end_value = end_value
        self.patch_sizes = patch_sizes
        self.strides = strides

    def get_config(self):
        """Return the constructor arguments needed to rebuild this backbone."""
        config = super().get_config()
        config.update(
            {
                "depths": self.depths,
                "include_rescaling": self.include_rescaling,
                "hidden_dims": self.hidden_dims,
                "image_shape": self.image_shape,
                "num_layers": self.num_layers,
                "blockwise_num_heads": self.blockwise_num_heads,
                "blockwise_sr_ratios": self.blockwise_sr_ratios,
                "end_value": self.end_value,
                "patch_sizes": self.patch_sizes,
                "strides": self.strides,
            }
        )
        return config
@@ -0,0 +1,133 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+
16
+ from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.models.image_classifier import ImageClassifier
18
+ from keras_hub.src.models.mix_transformer.mix_transformer_backbone import (
19
+ MiTBackbone,
20
+ )
21
+
22
+
23
@keras_hub_export("keras_hub.models.MiTImageClassifier")
class MiTImageClassifier(ImageClassifier):
    """MiTImageClassifier image classifier model.

    The classifier runs the backbone, averages the resulting feature map over
    its spatial dimensions and feeds the pooled features to a `Dense`
    prediction layer, producing one prediction vector per image.

    Args:
        backbone: A `keras_hub.models.MiTBackbone` instance.
        num_classes: int. The number of classes to predict.
        activation: `None`, str or callable. The activation function to use on
            the `Dense` layer. Set `activation=None` to return the output
            logits. Defaults to `"softmax"`.
        preprocessor: Currently unused placeholder argument; it is accepted
            but ignored.

    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
    where `x` is a tensor and `y` is an integer from `[0, num_classes)`.
    All `ImageClassifier` tasks include a `from_preset()` constructor which can
    be used to load a pre-trained config and weights.

    Examples:

    Call `predict()` to run inference.
    ```python
    # Load preset and run inference
    images = np.ones((2, 224, 224, 3), dtype="float32")
    classifier = keras_hub.models.MiTImageClassifier.from_preset(
        "mit_b0_imagenet")
    classifier.predict(images)
    ```

    Call `fit()` on a single batch.
    ```python
    # Load preset and train
    images = np.ones((2, 224, 224, 3), dtype="float32")
    labels = [0, 3]
    classifier = keras_hub.models.MiTImageClassifier.from_preset(
        "mit_b0_imagenet")
    classifier.fit(x=images, y=labels, batch_size=2)
    ```

    Call `fit()` with custom loss, optimizer and backbone.
    ```python
    images = np.ones((2, 224, 224, 3), dtype="float32")
    labels = [0, 3]
    classifier = keras_hub.models.MiTImageClassifier.from_preset(
        "mit_b0_imagenet")
    classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(5e-5),
    )
    classifier.backbone.trainable = False
    classifier.fit(x=images, y=labels, batch_size=2)
    ```

    Custom backbone.
    ```python
    images = np.ones((2, 224, 224, 3), dtype="float32")
    labels = [0, 3]
    backbone = keras_hub.models.MiTBackbone(
        depths=[2, 2, 2, 2],
        num_layers=4,
        blockwise_num_heads=[1, 2, 5, 8],
        blockwise_sr_ratios=[8, 4, 2, 1],
        end_value=0.1,
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        include_rescaling=False,
        image_shape=(224, 224, 3),
        hidden_dims=[32, 64, 160, 256],
    )
    classifier = keras_hub.models.MiTImageClassifier(
        backbone=backbone,
        num_classes=4,
    )
    classifier.fit(x=images, y=labels, batch_size=2)
    ```
    """

    backbone_cls = MiTBackbone

    def __init__(
        self,
        backbone,
        num_classes,
        activation="softmax",
        preprocessor=None,  # adding this dummy arg for saved model test
        # TODO: once preprocessor flow is figured out, this needs to be updated
        **kwargs,
    ):
        # === Layers ===
        self.backbone = backbone
        # The backbone emits a (batch, height, width, channels) feature map.
        # Without pooling, `Dense` would act on the last axis only and yield
        # one prediction per spatial location instead of one per image.
        self.pooler = keras.layers.GlobalAveragePooling2D(
            data_format="channels_last",
            name="pooler",
        )
        self.output_dense = keras.layers.Dense(
            num_classes,
            activation=activation,
            name="predictions",
        )

        # === Functional Model ===
        inputs = self.backbone.input
        x = self.backbone(inputs)
        x = self.pooler(x)
        outputs = self.output_dense(x)
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

        # === Config ===
        self.num_classes = num_classes
        self.activation = activation

    def get_config(self):
        """Return the config of the task, extended with the head settings."""
        # Backbone serialized in `super`
        config = super().get_config()
        config.update(
            {
                "num_classes": self.num_classes,
                "activation": self.activation,
            }
        )
        return config
@@ -0,0 +1,300 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import keras
17
+ from keras import ops
18
+ from keras import random
19
+
20
+
21
class OverlappingPatchingAndEmbedding(keras.layers.Layer):
    """Overlapping Patching and Embedding layer.

    Differs from `PatchingAndEmbedding` in that the patch size does not
    affect the sequence length. It's fully derived from the `stride`
    parameter. Additionally, no positional embedding is done
    as part of the layer - only a projection using a `Conv2D` layer
    (with `"same"` padding), followed by a `LayerNormalization`.

    Args:
        project_dim: integer, the dimensionality of the projection.
            Defaults to `32`.
        patch_size: integer, the size of the patches to encode.
            Defaults to `7`.
        stride: integer, the stride to use for the patching before
            projection. Defaults to `4`.
        **kwargs: forwarded to `keras.layers.Layer`.
    """

    def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs):
        super().__init__(**kwargs)

        self.project_dim = project_dim
        self.patch_size = patch_size
        self.stride = stride

        self.proj = keras.layers.Conv2D(
            filters=project_dim,
            kernel_size=patch_size,
            strides=stride,
            padding="same",
        )
        self.norm = keras.layers.LayerNormalization()

    def call(self, x):
        """Project `x` with the convolution and flatten it to a sequence.

        The height and width of the projected tensor are read from its
        static shape, so they must be known when this method runs.
        """
        x = self.proj(x)
        # B, H, W, C
        shape = x.shape
        # Collapse the two spatial axes into a single sequence axis.
        x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3]))
        return self.norm(x)

    def get_config(self):
        """Return the layer config including the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "patch_size": self.patch_size,
                "stride": self.stride,
            }
        )
        return config
70
+
71
+
72
class HierarchicalTransformerEncoder(keras.layers.Layer):
    """Hierarchical transformer encoder block implementation as a Keras Layer.

    The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention`
    alternative for computational efficiency, and is meant to be used
    within the SegFormer architecture.

    Args:
        project_dim: integer, the dimensionality of the projection of the
            encoder, and output of the `SegFormerMultiheadAttention` layer.
            Due to the residual addition the input dimensionality has to be
            equal to the output dimensionality.
        num_heads: integer, the number of heads for the
            `SegFormerMultiheadAttention` layer.
        sr_ratio: integer, the ratio to use within
            `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D`
            layer is used to reduce the length of the sequence.
            Defaults to `1`.
        drop_prob: float, the probability of dropping a random
            sample using the `DropPath` layer. Defaults to `0.0`.
        layer_norm_epsilon: float, the epsilon for
            `LayerNormalization` layers. Defaults to `1e-06`.
        **kwargs: forwarded to `keras.layers.Layer`.
    """

    def __init__(
        self,
        project_dim,
        num_heads,
        sr_ratio=1,
        drop_prob=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.project_dim = project_dim
        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        self.drop_prob = drop_prob
        # Misspelled attribute name used by earlier revisions; kept as an
        # alias so existing readers of `drop_prop` keep working.
        self.drop_prop = drop_prob
        self.layer_norm_epsilon = layer_norm_epsilon

        self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.attn = SegFormerMultiheadAttention(
            project_dim, num_heads, sr_ratio
        )
        self.drop_path = DropPath(drop_prob)
        self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.mlp = MixFFN(
            channels=project_dim,
            mid_channels=int(project_dim * 4),
        )

    def build(self, input_shape):
        super().build(input_shape)
        # NOTE(review): `H` and `W` are not read anywhere in this class, and
        # `W` is derived from `input_shape[2]` rather than `input_shape[1]`.
        # Left untouched in case code outside this file relies on them.
        self.H = ops.sqrt(ops.cast(input_shape[1], "float32"))
        self.W = ops.sqrt(ops.cast(input_shape[2], "float32"))

    def call(self, x):
        """Apply the attention and feed-forward residual branches to `x`."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def get_config(self):
        """Return a config whose keys match the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "num_heads": self.num_heads,
                "sr_ratio": self.sr_ratio,
                "drop_prob": self.drop_prob,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer, tolerating configs from earlier revisions.

        Earlier revisions stored a serialized `"mlp"` sub-layer and the
        drop probability under the misspelled key `"drop_prop"`; neither is
        a constructor argument, so both are translated or discarded here.
        """
        config = dict(config)
        # The feed-forward block is always recreated by the constructor.
        config.pop("mlp", None)
        legacy_drop_prob = config.pop("drop_prop", None)
        if legacy_drop_prob is not None:
            config.setdefault("drop_prob", legacy_drop_prob)
        return super().from_config(config)
141
+
142
+
143
class MixFFN(keras.layers.Layer):
    """Feed-forward block with a depthwise convolution between two denses.

    The input sequence is expanded to `mid_channels` with a `Dense` layer,
    reshaped to a square grid, passed through a 3x3 `DepthwiseConv2D`,
    flattened back to a sequence, activated with GELU and projected to
    `channels` with a second `Dense` layer.

    The sequence length (axis 1 of the input) has to be a perfect square,
    since the grid height and width are both taken as its square root.

    Args:
        channels: integer, the number of units of the output projection.
        mid_channels: integer, the number of units of the first projection.
        **kwargs: forwarded to `keras.layers.Layer`.
    """

    def __init__(self, channels, mid_channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.mid_channels = mid_channels

        self.fc1 = keras.layers.Dense(mid_channels)
        self.dwconv = keras.layers.DepthwiseConv2D(
            kernel_size=3,
            strides=1,
            padding="same",
        )
        self.fc2 = keras.layers.Dense(channels)

    def call(self, x):
        """Run the dense / depthwise-conv / GELU / dense pipeline on `x`."""
        x = self.fc1(x)
        shape = ops.shape(x)
        # The sequence is treated as a flattened square grid.
        H = W = int(math.sqrt(shape[1]))
        B, C = shape[0], shape[2]
        x = ops.reshape(x, (B, H, W, C))
        x = self.dwconv(x)
        x = ops.reshape(x, (B, -1, C))
        x = ops.nn.gelu(x)
        x = self.fc2(x)
        return x

    def get_config(self):
        """Return the layer config including the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "mid_channels": self.mid_channels,
            }
        )
        return config
165
+
166
+
167
class SegFormerMultiheadAttention(keras.layers.Layer):
    """Efficient MultiHeadAttention implementation as a Keras layer.

    A huge bottleneck in scaling transformers is the self-attention layer
    with an O(n^2) complexity.

    SegFormerMultiheadAttention performs a sequence reduction (SR) operation
    with a given ratio, to reduce the sequence length before performing key
    and value projections, reducing the O(n^2) complexity to O(n^2/R) where
    R is the sequence reduction ratio.

    The sequence length (axis 1 of the input) has to be a perfect square,
    since the grid height and width are both taken as its square root.

    Args:
        project_dim: integer, the dimensionality of the projection
            of the `SegFormerMultiheadAttention` layer. Has to be
            divisible by `num_heads`.
        num_heads: integer, the number of heads to use in the
            attention computation.
        sr_ratio: integer, the sequence reduction ratio to perform
            on the sequence before key and value projections.
        **kwargs: forwarded to `keras.layers.Layer`.

    Raises:
        ValueError: if `project_dim` is not divisible by `num_heads`.
    """

    def __init__(self, project_dim, num_heads, sr_ratio, **kwargs):
        super().__init__(**kwargs)
        # The per-head reshape in `call` only works for an even split, so
        # report a bad configuration at construction time.
        if project_dim % num_heads != 0:
            raise ValueError(
                "`project_dim` must be divisible by `num_heads`. "
                f"Received: project_dim={project_dim}, num_heads={num_heads}"
            )
        self.project_dim = project_dim
        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        self.scale = (project_dim // num_heads) ** -0.5
        self.q = keras.layers.Dense(project_dim)
        self.k = keras.layers.Dense(project_dim)
        self.v = keras.layers.Dense(project_dim)
        self.proj = keras.layers.Dense(project_dim)

        if sr_ratio > 1:
            self.sr = keras.layers.Conv2D(
                filters=project_dim,
                kernel_size=sr_ratio,
                strides=sr_ratio,
                padding="same",
            )
            self.norm = keras.layers.LayerNormalization()

    def call(self, x):
        """Compute self-attention over `x`, reducing keys/values if set."""
        input_shape = ops.shape(x)
        B, N, C = input_shape[0], input_shape[1], input_shape[2]
        # The sequence is treated as a flattened square grid.
        H = W = int(math.sqrt(N))
        head_dim = C // self.num_heads

        # Queries always attend from the full-length sequence.
        q = ops.reshape(self.q(x), (B, N, self.num_heads, head_dim))
        q = ops.transpose(q, [0, 2, 1, 3])

        if self.sr_ratio > 1:
            # NOTE(review): the sequence is transposed to (B, C, N) before
            # being reshaped to (B, H, W, C), and reshaped to (B, C, -1)
            # after the convolution. This looks like a channels-first
            # layout fed to the Conv2D - confirm against the reference
            # implementation. Kept as-is because existing weights may
            # depend on this exact data layout.
            x = ops.reshape(
                ops.transpose(x, [0, 2, 1]),
                (B, H, W, C),
            )
            x = self.sr(x)
            x = ops.reshape(x, [B, C, -1])
            x = ops.transpose(x, [0, 2, 1])
            x = self.norm(x)

        k = self.k(x)
        v = self.v(x)

        # Split keys and values into heads: (B, heads, length, head_dim).
        k = ops.transpose(
            ops.reshape(k, [B, -1, self.num_heads, head_dim]),
            [0, 2, 1, 3],
        )
        v = ops.transpose(
            ops.reshape(v, [B, -1, self.num_heads, head_dim]),
            [0, 2, 1, 3],
        )

        attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale
        attn = ops.nn.softmax(attn, axis=-1)

        # Merge the heads back into a (B, N, C) sequence.
        attn = attn @ v
        attn = ops.reshape(
            ops.transpose(attn, [0, 2, 1, 3]),
            [B, N, C],
        )

        return self.proj(attn)

    def get_config(self):
        """Return the layer config including the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "num_heads": self.num_heads,
                "sr_ratio": self.sr_ratio,
            }
        )
        return config
262
+
263
+
264
class DropPath(keras.layers.Layer):
    """Implements the DropPath layer.

    DropPath randomly drops samples during
    training with a probability of `rate`. Note that this layer drops individual
    samples within a batch and not the entire batch, whereas StochasticDepth
    randomly drops the entire batch.

    Samples that are kept are scaled by `1 / (1 - rate)`. Outside of
    training, or when `rate` is `0.0`, the input is returned unchanged.

    Args:
        rate: float, the probability of the residual branch being dropped.
            Defaults to `0.5`.
        seed: (Optional) integer. Used to create a random seed.
        **kwargs: forwarded to `keras.layers.Layer`.
    """

    def __init__(self, rate=0.5, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate
        # The raw seed is kept separately so `get_config` can return it.
        self._seed_val = seed
        self.seed = random.SeedGenerator(seed=seed)

    def call(self, x, training=None):
        """Drop whole samples of `x` with probability `rate` in training."""
        if self.rate == 0.0 or not training:
            return x
        if self.rate >= 1.0:
            # Every sample is dropped. Return zeros directly, since the
            # rescaling below would divide by zero and yield NaN/inf.
            return ops.zeros_like(x)

        batch_size = x.shape[0] or ops.shape(x)[0]
        # One random draw per sample, broadcast over the remaining axes.
        drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1)
        drop_map = ops.cast(
            random.uniform(drop_map_shape, seed=self.seed) > self.rate,
            x.dtype,
        )
        x = x / (1.0 - self.rate)
        x = x * drop_map
        return x

    def get_config(self):
        """Return the layer config including `rate` and the original seed."""
        config = super().get_config()
        config.update({"rate": self.rate, "seed": self._seed_val})
        return config
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.models.opt.opt_backbone import OPTBackbone
16
+ from keras_hub.src.models.opt.opt_presets import backbone_presets
17
+ from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer
18
+ from keras_hub.src.utils.preset_utils import register_presets
19
+
20
# Register `backbone_presets` for both the OPT backbone and tokenizer classes.
+ register_presets(backbone_presets, (OPTBackbone, OPTTokenizer))