keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,565 @@
1
+ # Copyright 2024 The KerasCV Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import keras
17
+ from keras import ops
18
+
19
+
20
class MLP(keras.layers.Layer):
    """Two-layer perceptron block: `Dense -> Activation -> Dense`.

    The input is projected to `intermediate_dim` units, passed through
    `activation`, and then projected to `hidden_dim` units. The code has been
    adapted from [Segment Anything paper](https://arxiv.org/abs/2304.02643),
    [Segment Anything GitHub](
    https://github.com/facebookresearch/segment-anything) and [Detectron2](
    https://github.com/facebookresearch/detectron2).

    Args:
        intermediate_dim (int): Number of units in the hidden layer.
        hidden_dim (int): Number of units in the output layer.
        activation (str): Activation applied after the hidden layer.
            Defaults to `"relu"`.
    """

    def __init__(
        self, intermediate_dim, hidden_dim, activation="relu", **kwargs
    ):
        super().__init__(**kwargs)
        self.intermediate_dim = intermediate_dim
        self.hidden_dim = hidden_dim
        self.activation = activation
        # The sub-layers are first gathered on the attribute as a list and
        # then wrapped into one `Sequential` stored under the same name.
        self.dense_net = []
        for units in (intermediate_dim,):
            self.dense_net.append(keras.layers.Dense(units))
            self.dense_net.append(keras.layers.Activation(activation))
        self.dense_net.append(keras.layers.Dense(hidden_dim))
        self.dense_net = keras.models.Sequential(self.dense_net)

    def build(self, input_shape):
        """Build the wrapped `Sequential` for `input_shape`."""
        self.dense_net.build(input_shape)
        self.built = True

    def call(self, x):
        """Run `x` through the dense stack and return the result."""
        outputs = self.dense_net(x)
        return outputs

    def get_config(self):
        """Return the base layer config extended with this block's args."""
        config = super().get_config()
        config["intermediate_dim"] = self.intermediate_dim
        config["hidden_dim"] = self.hidden_dim
        config["activation"] = self.activation
        return config
68
+
69
+
70
class AddRelativePositionalEmbedding(keras.layers.Layer):
    """Adds decomposed relative positional embeddings to an attention map.

    Two zero-initialized tables are created, one for the height axis and one
    for the width axis. Each table has `2 * size - 1` rows (one per possible
    relative offset along that axis) and `key_dim` columns.

    Args:
        input_size (tuple[int, int]): `(height, width)` used to size the two
            relative position tables.
        key_dim (int): Number of columns of each table. It has to match the
            channel dimension of the queries passed to `call`.
    """

    def __init__(self, input_size, key_dim, **kwargs):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.key_dim = key_dim
        self.rel_pos_h = self.add_weight(
            name="rel_pos_h",
            shape=(2 * self.input_size[0] - 1, self.key_dim),
            initializer="zeros",
        )
        self.rel_pos_w = self.add_weight(
            name="rel_pos_w",
            shape=(2 * self.input_size[1] - 1, self.key_dim),
            initializer="zeros",
        )
        # All weights are created above, so there is nothing left to build.
        self.built = True

    def _get_rel_pos(self, query_size, key_size, rel_pos):
        """Get relative positional embeddings.

        Get relative positional embeddings according to the relative positions
        of query and key sizes.

        Args:
            query_size (int): The size of the queries along one spatial axis.
            key_size (int): The size of the keys along the same spatial axis.
            rel_pos (tensor): Relative positional embedding table with shape
                `(length, dim)`.

        Returns:
            tensor: Embeddings gathered from the (possibly resized) table for
                every query/key pair, with shape
                `(query_size, key_size, dim)`.
        """
        max_rel_dist = 2 * max(query_size, key_size) - 1
        if ops.shape(rel_pos)[0] != max_rel_dist:
            # The table length does not cover every relative distance, so it
            # is resampled along its first axis. The table is viewed as a
            # one-channel image of shape `(1, length, dim, 1)`; the layout is
            # pinned to channels-last so the global image data format cannot
            # reinterpret the axes. The tensor is passed positionally because
            # the name of the first argument differs across Keras 3 releases.
            rel_pos_resized = ops.image.resize(
                ops.reshape(
                    rel_pos,
                    (1, ops.shape(rel_pos)[0], ops.shape(rel_pos)[1], 1),
                ),
                size=(max_rel_dist, ops.shape(rel_pos)[1]),
                interpolation="bilinear",
                data_format="channels_last",
            )
            # NOTE: the resized table must still be indexed by the relative
            # coordinates below; returning it here would hand back a 2D table
            # instead of the `(query_size, key_size, dim)` result.
            rel_pos_resized = ops.squeeze(rel_pos_resized, axis=(0, -1))
        else:
            rel_pos_resized = rel_pos
        # Query coordinates, scaled when the keys are longer than the queries.
        query_coordinates = ops.cast(
            ops.arange(query_size), dtype=self.compute_dtype
        )[:, None] * (max(key_size / query_size, 1.0))
        # Key coordinates, scaled when the queries are longer than the keys.
        key_coordinates = ops.cast(
            ops.arange(key_size), dtype=self.compute_dtype
        )[None, :] * (max(query_size / key_size, 1.0))
        # Relative coordinates, shifted so the smallest offset maps to row 0.
        relative_coordinates = (query_coordinates - key_coordinates) + (
            key_size - 1
        ) * max(query_size / key_size, 1.0)
        relative_coordinates = ops.cast(relative_coordinates, dtype="int32")
        return ops.take(rel_pos_resized, relative_coordinates, 0)

    def call(self, attention_map, queries, query_size, key_size):
        """Calculate decomposed Relative Positional Embeddings

        The code has been adapted based on
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa: E501

        Args:
            attention_map (tensor): Attention map.
            queries (tensor): Queries in the attention layer with shape
                `(batch, query_height * query_width, channels)`.
            query_size (tuple[int, int]): Spatial sequence size of queries with
                `(query_height, query_width)`.
            key_size (tuple[int, int]): Spatial sequence size of keys with
                `(key_height, key_width)`.

        Returns:
            tensor: attention map with added relative positional embeddings,
                with shape
                `(batch, query_height * query_width, key_height * key_width)`.
        """
        query_height, query_width = query_size[0], query_size[1]
        key_height, key_width = key_size[0], key_size[1]
        rel_heights = self._get_rel_pos(
            query_height, key_height, self.rel_pos_h
        )
        rel_widths = self._get_rel_pos(query_width, key_width, self.rel_pos_w)
        shape = ops.shape(queries)
        batch, channels = shape[0], shape[2]
        rel_queries = ops.reshape(
            queries, (batch, query_height, query_width, channels)
        )
        # Project the queries onto the per-axis embeddings.
        rel_heights = ops.einsum("bhwc,hkc->bhwk", rel_queries, rel_heights)
        rel_widths = ops.einsum("bhwc,wkc->bhwk", rel_queries, rel_widths)
        attention_map = ops.reshape(
            attention_map,
            (batch, query_height, query_width, key_height, key_width),
        )
        # Broadcast the height term over key widths and the width term over
        # key heights.
        attention_map = attention_map + rel_heights[..., :, None]
        attention_map = attention_map + rel_widths[..., None, :]
        attention_map = ops.reshape(
            attention_map,
            (batch, query_height * query_width, key_height * key_width),
        )
        return attention_map

    def get_config(self):
        """Return the base layer config extended with this layer's args."""
        config = super().get_config()
        config.update({"input_size": self.input_size, "key_dim": self.key_dim})
        return config
178
+
179
+
180
class MultiHeadAttentionWithRelativePE(keras.layers.Layer):
    """Multi-head self-attention over a 2D grid with optional relative PE.

    The code has been adapted from [Segment Anything paper](
    https://arxiv.org/abs/2304.02643), [Segment Anything GitHub](
    https://github.com/facebookresearch/segment-anything) and [Detectron2](
    https://github.com/facebookresearch/detectron2).

    Args:
        num_heads (int): Number of attention heads.
        key_dim (int): Size of each attention head for query, key, and
            value.
        use_bias (bool, optional): Whether the joint query/key/value
            projection uses a bias. Defaults to `True`.
        use_rel_pos (bool, optional): Whether to add relative positional
            embeddings to the attention map. Defaults to `False`.
        input_size (tuple[int, int], optional): Size of the input image.
            Must be provided when using relative positional embeddings.
            Defaults to `None`.

    Raises:
        ValueError: When `input_size = None` with `use_rel_pos = True`.
    """

    def __init__(
        self,
        num_heads,
        key_dim,
        use_bias=True,
        use_rel_pos=False,
        input_size=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.scale = self.key_dim**-0.5
        self.use_bias = use_bias
        self.input_size = input_size
        self.use_rel_pos = use_rel_pos

        model_dim = key_dim * self.num_heads
        # One dense layer emits queries, keys and values side by side.
        self.qkv = keras.layers.Dense(model_dim * 3, use_bias=self.use_bias)
        self.projection = keras.layers.Dense(model_dim)
        if self.use_rel_pos:
            if input_size is None:
                raise ValueError(
                    "Input size must be provided if using relative "
                    "positional encoding."
                )
            self.add_decomposed_reative_pe = AddRelativePositionalEmbedding(
                self.input_size, self.key_dim
            )

    def build(self, input_shape=None):
        """Build both dense layers for `key_dim * num_heads` input features."""
        feature_shape = [self.key_dim * self.num_heads]
        self.qkv.build(feature_shape)
        self.projection.build(feature_shape)
        self.built = True

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def call(self, x):
        """Apply self-attention to `x` of shape `(batch, height, width, channels)`."""
        batch, height, width, channels = ops.shape(x)
        num_tokens = height * width

        # Split the joint projection into
        # `(tokens, 3, heads, key_dim)` and move the q/k/v axis to the front.
        projected = ops.reshape(
            self.qkv(x),
            (batch, num_tokens, 3, self.num_heads, self.key_dim),
        )
        projected = ops.transpose(projected, axes=(2, 0, 3, 1, 4))
        # Fold the heads into the batch axis.
        projected = ops.reshape(
            projected,
            (3, batch * self.num_heads, num_tokens, self.key_dim),
        )
        queries, keys, values = ops.unstack(projected, axis=0)

        keys_transposed = ops.transpose(keys, axes=(0, 2, 1))
        attention_map = (queries * self.scale) @ keys_transposed
        if self.use_rel_pos:
            attention_map = self.add_decomposed_reative_pe(
                attention_map,
                queries=queries,
                query_size=(height, width),
                key_size=(height, width),
            )
        attention_map = ops.softmax(attention_map, axis=-1)

        # Unfold the heads, move them next to `key_dim`, and merge them back
        # into the channel axis.
        attended = attention_map @ values
        attended = ops.reshape(
            attended,
            (batch, self.num_heads, height, width, self.key_dim),
        )
        attended = ops.transpose(attended, axes=(0, 2, 3, 1, 4))
        attended = ops.reshape(attended, (batch, height, width, channels))
        return self.projection(attended)

    def get_config(self):
        """Return the base layer config extended with this layer's args."""
        config = super().get_config()
        config["num_heads"] = self.num_heads
        config["key_dim"] = self.key_dim
        config["use_bias"] = self.use_bias
        config["use_rel_pos"] = self.use_rel_pos
        config["input_size"] = self.input_size
        return config
288
+
289
+
290
class WindowPartitioning(keras.layers.Layer):
    """Splits feature maps into square windows and merges them back.

    Args:
        window_size (int): Side length of each square window.
    """

    def __init__(self, window_size, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        # The layer has no weights, so it is marked as built right away.
        self.built = True

    def partition(self, x):
        """Pad `x` to a multiple of the window size and cut it into windows.

        Args:
            x (tensor): Input of shape `(batch, height, width, channels)`.

        Returns:
            tuple: The windows with shape
                `(-1, window_size, window_size, channels)` and the padded
                `(height, width)` of the input.
        """
        size = self.window_size
        batch, height, width, channels = ops.shape(x)
        # Amount needed to reach the next multiple of `size` (0 if aligned).
        pad_height = (size - height % size) % size
        pad_width = (size - width % size) % size
        if pad_height > 0 or pad_width > 0:
            paddings = ((0, 0), (0, pad_height), (0, pad_width), (0, 0))
            x = ops.pad(x, paddings)
        height_padded = height + pad_height
        width_padded = width + pad_width

        grid_shape = (
            batch,
            height_padded // size,
            size,
            width_padded // size,
            size,
            channels,
        )
        grid = ops.reshape(x, grid_shape)
        # Bring the two window-index axes together before flattening them
        # into the leading axis.
        grid = ops.transpose(grid, axes=(0, 1, 3, 2, 4, 5))
        windows = ops.reshape(grid, (-1, size, size, channels))
        return windows, (height_padded, width_padded)

    def unpartition(self, windows, height_width_padded, height_width):
        """Merge windows back into feature maps and crop away the padding.

        Args:
            windows (tensor): Windows as produced by `partition`.
            height_width_padded (tuple[int, int]): Padded `(height, width)`
                returned by `partition`.
            height_width (tuple[int, int]): `(height, width)` before padding.

        Returns:
            tensor: Feature maps of shape `(batch, height, width, channels)`.
        """
        size = self.window_size
        height_padded, width_padded = height_width_padded
        height, width = height_width
        rows = height_padded // size
        cols = width_padded // size
        batch = ops.shape(windows)[0] // (rows * cols)

        grid = ops.reshape(windows, (batch, rows, cols, size, size, -1))
        # Interleave window rows/columns with their in-window offsets.
        grid = ops.transpose(grid, axes=(0, 1, 3, 2, 4, 5))
        merged = ops.reshape(grid, (batch, height_padded, width_padded, -1))
        return merged[:, :height, :width, :]

    def get_config(self):
        """Return the base layer config extended with `window_size`."""
        config = super().get_config()
        config["window_size"] = self.window_size
        return config
352
+
353
+
354
+ class WindowedTransformerEncoder(keras.layers.Layer):
355
+ """Implements windowed transformer encoder.
356
+
357
+ Transformer blocks with support of window attention and residual
358
+ propagation blocks. The code has been adapted from [Segment Anything paper](
359
+ https://arxiv.org/abs/2304.02643), [Segment Anything GitHub](
360
+ https://github.com/facebookresearch/segment-anything) and [Detectron2](
361
+ https://github.com/facebookresearch/detectron2).
362
+
363
+ Args:
364
+ project_dim (int): the dimensionality of the projection of the
365
+ encoder, and output of the `MultiHeadAttention`.
366
+ intermediate_dim (int): the intermediate dimensionality of the MLP head
367
+ before projecting to `project_dim`.
368
+ num_heads (int): the number of heads for the `MultiHeadAttention`
369
+ layer.
370
+ use_bias (bool, optional): Whether to use bias to project the keys,
371
+ queries, and values in the attention layer. Defaults to `True`.
372
+ use_rel_pos (bool, optional): Whether to use relative positional
373
+ encodings in the attention layer. Defaults to `False`.
374
+ window_size (int, optional): Window size for windowed attention.
375
+ Defaults to `0`.
376
+ input_size (tuple[int, int], optional): Height and width of the input
377
+ image as a tuple of integers. Must be provided when using relative
378
+ positional embeddings. Defaults to `None`.
379
+ activation (str, optional): the activation function to apply in the
380
+ MLP head - should be a function. Defaults to `"gelu"`.
381
+ layer_norm_epsilon (float, optional): The epsilon to use in the layer
382
+ normalization layers. Defaults to `1e-6`.
383
+ """
384
+
385
+ def __init__(
386
+ self,
387
+ project_dim,
388
+ intermediate_dim,
389
+ num_heads,
390
+ use_bias=True,
391
+ use_rel_pos=False,
392
+ window_size=0,
393
+ input_size=None,
394
+ activation="gelu",
395
+ layer_norm_epsilon=1e-6,
396
+ **kwargs
397
+ ):
398
+ super().__init__(**kwargs)
399
+ self.project_dim = project_dim
400
+ self.intermediate_dim = intermediate_dim
401
+ self.num_heads = num_heads
402
+ self.use_bias = use_bias
403
+ self.input_size = input_size
404
+ self.activation = activation
405
+ self.layer_norm_epsilon = layer_norm_epsilon
406
+ self.window_size = window_size
407
+ self.use_rel_pos = use_rel_pos
408
+
409
+ self.layer_norm1 = keras.layers.LayerNormalization(
410
+ epsilon=self.layer_norm_epsilon
411
+ )
412
+ self.layer_norm2 = keras.layers.LayerNormalization(
413
+ epsilon=self.layer_norm_epsilon
414
+ )
415
+ self.attention = MultiHeadAttentionWithRelativePE(
416
+ num_heads=self.num_heads,
417
+ key_dim=self.project_dim // self.num_heads,
418
+ use_bias=use_bias,
419
+ use_rel_pos=use_rel_pos,
420
+ input_size=(
421
+ input_size if window_size == 0 else (window_size, window_size)
422
+ ),
423
+ )
424
+ self.mlp_block = MLP(
425
+ intermediate_dim,
426
+ project_dim,
427
+ activation="gelu",
428
+ )
429
+ self.window_partitioning = WindowPartitioning(window_size)
430
+
431
+ def build(self, input_shape=None):
432
+ self.layer_norm1.build([None, None, None, self.project_dim])
433
+ self.layer_norm2.build([None, None, None, self.project_dim])
434
+ self.attention.build()
435
+ self.mlp_block.build([None, None, None, self.project_dim])
436
+ self.built = True
437
+
438
+ def compute_output_shape(self, input_shape):
439
+ return input_shape
440
+
441
+ def call(self, x):
442
+ shortcut = x
443
+ x = self.layer_norm1(x)
444
+ # Window Partition
445
+ if self.window_size > 0:
446
+ height, width = ops.shape(x)[1], ops.shape(x)[2]
447
+ x, height_width_padded = self.window_partitioning.partition(x)
448
+
449
+ x = self.attention(x)
450
+ # Reverse Window Partition
451
+ if self.window_size > 0:
452
+ x = self.window_partitioning.unpartition(
453
+ x,
454
+ height_width_padded=height_width_padded,
455
+ height_width=(height, width),
456
+ )
457
+ x = shortcut + x
458
+ x = x + self.mlp_block(self.layer_norm2(x))
459
+ return x
460
+
461
+ def get_config(self):
462
+ config = super().get_config()
463
+ config.update(
464
+ {
465
+ "project_dim": self.project_dim,
466
+ "intermediate_dim": self.intermediate_dim,
467
+ "num_heads": self.num_heads,
468
+ "use_bias": self.use_bias,
469
+ "use_rel_pos": self.use_rel_pos,
470
+ "window_size": self.window_size,
471
+ "input_size": self.input_size,
472
+ "activation": self.activation,
473
+ "layer_norm_epsilon": self.layer_norm_epsilon,
474
+ }
475
+ )
476
+ return config
477
+
478
+
479
class ViTDetPatchingAndEmbedding(keras.layers.Layer):
    """Image patching and embedding layer.

    Turns an image into patch embeddings with a single convolution and no
    layer normalization. The code has been adapted from [Segment Anything
    paper](https://arxiv.org/abs/2304.02643), [Segment Anything GitHub](
    https://github.com/facebookresearch/segment-anything) and [Detectron2](
    https://github.com/facebookresearch/detectron2).

    Args:
        kernel_size (tuple[int, int], optional): Kernel size of the
            projection layer. Defaults to `(16, 16)`.
        strides (tuple, optional): Strides of the projection layer.
            Defaults to `(16, 16)`.
        embed_dim (int, optional): Number of filters to use in the
            projection layer i.e. projection size. Defaults to `768`.
    """

    def __init__(
        self, kernel_size=(16, 16), strides=(16, 16), embed_dim=768, **kwargs
    ):
        super().__init__(**kwargs)

        # Keep the raw arguments so `get_config` can report them.
        self.kernel_size = kernel_size
        self.strides = strides
        self.embed_dim = embed_dim

        self.projection = keras.layers.Conv2D(
            self.embed_dim,
            kernel_size=self.kernel_size,
            strides=self.strides,
        )

    def build(self, input_shape):
        """Build the wrapped convolution for `input_shape`."""
        self.projection.build(input_shape)
        self.built = True

    def compute_output_shape(self, input_shape):
        """Delegate the shape computation to the wrapped convolution."""
        return self.projection.compute_output_shape(input_shape)

    def call(self, x):
        """Project `x` through the convolution."""
        return self.projection(x)

    def get_config(self):
        """Return the parent config plus the three constructor arguments."""
        config = super().get_config()
        config["kernel_size"] = self.kernel_size
        config["strides"] = self.strides
        config["embed_dim"] = self.embed_dim
        return config
531
+
532
+
533
class AddPositionalEmbedding(keras.layers.Layer):
    """Adds a zero-initialized positional embedding weight to its input.

    The weight has shape
    `(1, img_size // patch_size, img_size // patch_size, embed_dim)` and is
    added to the input with ordinary broadcasting.

    Args:
        img_size (int): Side length of the image; divided by `patch_size`
            to obtain both spatial dimensions of the embedding.
        patch_size (int): Side length of one patch.
        embed_dim (int): Size of the last axis of the embedding.
    """

    def __init__(self, img_size, patch_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.pos_embed = self.add_weight(
            name="pos_embed",
            shape=(
                1,
                img_size // patch_size,
                img_size // patch_size,
                embed_dim,
            ),
            initializer="zeros",
        )

    def compute_output_shape(self, input_shape):
        """The layer returns a tensor of the same shape as its input."""
        return input_shape

    def call(self, x):
        """Return `x` with the positional embedding added."""
        return x + self.pos_embed

    def get_config(self):
        """Return the parent config plus the three constructor arguments.

        This method was previously misspelled `get_confg`, so it never
        overrode `Layer.get_config` and the required constructor arguments
        were missing from the serialized config.
        """
        config = super().get_config()
        config.update(
            {
                "img_size": self.img_size,
                "patch_size": self.patch_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

    # Backward-compatible alias for the old, misspelled method name.
    get_confg = get_config
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
16
+ from keras_hub.src.models.whisper.whisper_presets import backbone_presets
17
+ from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
18
+ from keras_hub.src.utils.preset_utils import register_presets
19
+
20
# Import-time side effect: pass the whisper backbone presets to
# `register_presets` together with the backbone and tokenizer classes.
+ register_presets(backbone_presets, (WhisperBackbone, WhisperTokenizer))