keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,260 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import numpy as np
17
+
18
+ try:
19
+ import tensorflow as tf
20
+ except ImportError:
21
+ raise ImportError(
22
+ "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
23
+ "The TensorFlow package is required for data preprocessing with any backend."
24
+ )
25
+
26
+ from keras_hub.src.api_export import keras_hub_export
27
+ from keras_hub.src.layers.preprocessing.preprocessing_layer import (
28
+ PreprocessingLayer,
29
+ )
30
+
31
+
32
@keras_hub_export("keras_hub.models.WhisperAudioFeatureExtractor")
class WhisperAudioFeatureExtractor(PreprocessingLayer):
    """
    Whisper audio feature extractor layer.

    This layer takes in a batch of audio tensors, and computes the log-mel
    spectrogram features for each audio tensor.

    The input audio tensor can either be of shape `(length_of_audio,)` or
    `(batch_size, length_of_audio)`. The output is a tensor of shape
    `(batch_size, num_frames, num_mels)`, where `num_frames` is
    `(max_audio_length * sampling_rate) / stride`.

    Args:
        num_mels: int. The number of mel-frequency filters. Defaults to `80`.
        num_fft_bins: int. The size of the Fourier Transform in STFT.
            Defaults to `400`.
        stride: int. The distance between neighboring
            sliding window frames while computing STFT.
            Defaults to `160`.
        sampling_rate: int. The sample rate of the audio. Defaults to `16000`.
        max_audio_length: int. The length of each audio chunk in
            seconds. The input audio tensor will be padded/trimmed to
            `max_audio_length * sampling_rate`. Defaults to `30`.

    Examples:

    ```python
    audio_tensor = tf.ones((8000,), dtype="float32")

    # Compute the log-mel spectrogram.
    whisper_audio_feature_extractor = keras_hub.models.WhisperAudioFeatureExtractor()
    whisper_audio_feature_extractor(audio_tensor)

    # Compute the log-mel spectrogram for a batch of audio tensors.
    audio_tensor_1 = tf.ones((8000,), dtype="float32")
    audio_tensor_2 = tf.ones((10000,), dtype="float32")
    audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0)
    whisper_audio_feature_extractor(audio_tensor)
    ```
    """

    def __init__(
        self,
        num_mels=80,
        num_fft_bins=400,
        stride=160,
        sampling_rate=16000,
        max_audio_length=30,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # The layer accepts raw (possibly ragged) audio tensors, so bypass
        # Keras' automatic input conversion and build step.
        self._convert_input_args = False
        self._allow_non_tensor_positional_args = True
        self.built = True

        self.num_mels = num_mels
        self.num_fft_bins = num_fft_bins
        self.stride = stride
        self.sampling_rate = sampling_rate
        self.max_audio_length = max_audio_length
        # Fixed number of samples every input is padded/trimmed to.
        self.num_samples = self.sampling_rate * self.max_audio_length

        # After transposition, `self.mel_filters`'s shape is
        # `(num_fft_bins // 2 + 1, num_mels).`
        self.mel_filters = self._get_mel_filters()

    def _get_mel_filters(self):
        """
        Compute the Slaney-style mel filter bank as a constant tensor.

        Adapted from Hugging Face
        (https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/whisper/feature_extraction_whisper.py#L86)
        """

        # TODO: Convert to TensorFlow ops (if possible).

        dtype = np.float32
        # Initialize the weights
        weights = np.zeros(
            (self.num_mels, int(1 + self.num_fft_bins // 2)), dtype=dtype
        )

        # Center freqs of each FFT bin
        fftfreqs = np.fft.rfftfreq(
            n=self.num_fft_bins, d=1.0 / self.sampling_rate
        )

        # 'Center freqs' of mel bands - uniformly spaced between limits
        min_mel = 0.0
        max_mel = 45.245640471924965

        mels = np.linspace(min_mel, max_mel, self.num_mels + 2)

        # Fill in the linear scale
        f_min = 0.0
        f_sp = 200.0 / 3
        freqs = f_min + f_sp * mels

        # And now the nonlinear scale
        min_log_hz = 1000.0  # beginning of log region (Hz)
        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
        logstep = np.log(6.4) / 27.0  # step size for log region

        # If we have vector data, vectorize
        log_t = mels >= min_log_mel
        freqs[log_t] = min_log_hz * np.exp(
            logstep * (mels[log_t] - min_log_mel)
        )

        mel_f = freqs

        fdiff = np.diff(mel_f)
        ramps = np.subtract.outer(mel_f, fftfreqs)

        for i in range(self.num_mels):
            # lower and upper slopes for all bins
            lower = -ramps[i] / fdiff[i]
            upper = ramps[i + 2] / fdiff[i + 1]

            # .. then intersect them with each other and zero
            weights[i] = np.maximum(0, np.minimum(lower, upper))

        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (mel_f[2 : self.num_mels + 2] - mel_f[: self.num_mels])
        weights *= enorm[:, np.newaxis]

        weights = np.transpose(weights)
        return tf.constant(weights, dtype=self.compute_dtype)

    def _extract_audio_features(self, audio):
        """Compute normalized log-mel spectrograms for a batch of audio.

        `audio` is a dense tensor of shape `(batch_size, num_samples)`;
        returns a tensor of shape `(batch_size, num_frames, num_mels)`.
        """
        audio = tf.cast(audio, self.compute_dtype)
        # Use "reflection" padding - `tf.signal.stft` uses symmetric padding
        # internally.
        audio = tf.pad(
            audio,
            paddings=[[0, 0], [self.num_fft_bins // 2, self.num_fft_bins // 2]],
            mode="REFLECT",
        )

        # Compute the mel spectrogram.
        stft = tf.signal.stft(
            audio,
            frame_length=self.num_fft_bins,
            frame_step=self.stride,
            fft_length=self.num_fft_bins,
        )
        # Drop the last frame to match the reference implementation's frame
        # count, and take the power spectrum.
        magnitudes = tf.square(tf.abs(stft[:, :-1, :]))

        mel_spec = tf.matmul(
            magnitudes,
            self.mel_filters,
        )

        def tf_log10(x):
            """
            Computes log base 10 of input tensor using TensorFlow's natural log operator.
            """
            numerator = tf.math.log(x)
            denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
            return numerator / denominator

        # Clamp the values to a minimum value of 1e-10. This is done to avoid
        # taking the log of 0, i.e., for numerical stability.
        mel_spec = tf.maximum(mel_spec, 1e-10)

        # Calculate the log mel spectrogram.
        log_spec = tf_log10(mel_spec)
        # Dynamic range compression: clamp every value to within 8 of the
        # per-example maximum.
        log_spec_shape = tf.shape(log_spec)
        max_value_minus_eight = tf.math.subtract(
            tf.math.reduce_max(log_spec, axis=[1, 2]),
            tf.cast(8, dtype=log_spec.dtype),
        )
        # Broadcast the per-example floor back to the full spectrogram shape.
        max_value_minus_eight = tf.expand_dims(max_value_minus_eight, axis=1)
        max_value_minus_eight = tf.repeat(
            max_value_minus_eight,
            repeats=log_spec_shape[1] * log_spec_shape[2],
            axis=1,
        )
        max_value_minus_eight = tf.reshape(
            max_value_minus_eight, shape=log_spec_shape
        )
        log_spec = tf.maximum(log_spec, max_value_minus_eight)
        # Normalization.
        type_cast_four = tf.cast(4, dtype=log_spec.dtype)
        log_spec = tf.math.divide(
            tf.math.add(log_spec, type_cast_four),
            type_cast_four,
        )

        return log_spec

    def call(self, audio):
        """Pad/trim `audio` to `num_samples` and return its log-mel features.

        Accepts a rank-1 or rank-2 dense or ragged tensor (or anything
        convertible to a tensor); a rank-1 input yields a rank-2 output.
        """
        if not isinstance(audio, (tf.Tensor, tf.RaggedTensor)):
            audio = tf.convert_to_tensor(audio)

        rank_1_input = audio.shape.rank == 1
        if rank_1_input:
            audio = tf.expand_dims(audio, 0)

        # Convert the tensor to a Ragged Tensor.
        if isinstance(audio, tf.Tensor):
            audio = tf.RaggedTensor.from_tensor(audio)

        # Pad audio.
        audio_shape = audio.shape.as_list()
        audio_shape[-1] = self.num_samples
        audio = audio.to_tensor(shape=audio_shape)

        # Find the log mel spectrogram.
        log_spec = self._extract_audio_features(audio)
        if rank_1_input:
            log_spec = tf.squeeze(log_spec, 0)
        return log_spec

    def get_config(self):
        """Return the layer config for Keras serialization."""
        config = super().get_config()
        config.update(
            {
                "num_mels": self.num_mels,
                "num_fft_bins": self.num_fft_bins,
                "stride": self.stride,
                "sampling_rate": self.sampling_rate,
                "max_audio_length": self.max_audio_length,
            }
        )
        return config
@@ -0,0 +1,305 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import keras
17
+ from keras import ops
18
+
19
+ from keras_hub.src.api_export import keras_hub_export
20
+ from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
21
+ from keras_hub.src.layers.modeling.token_and_position_embedding import (
22
+ TokenAndPositionEmbedding,
23
+ )
24
+ from keras_hub.src.models.backbone import Backbone
25
+ from keras_hub.src.models.whisper.whisper_decoder import WhisperDecoder
26
+ from keras_hub.src.models.whisper.whisper_encoder import WhisperEncoder
27
+ from keras_hub.src.utils.tensor_utils import assert_tf_backend
28
+
29
+
30
def whisper_kernel_initializer(stddev=0.02):
    """Return the truncated-normal kernel initializer used across Whisper."""
    initializer = keras.initializers.TruncatedNormal(stddev=stddev)
    return initializer
32
+
33
+
34
class Padder(keras.layers.Layer):
    """Zero-pads the middle (time) axis of a rank-3 input by one step on
    each side, leaving the batch and feature axes untouched."""

    def call(self, x):
        paddings = [[0, 0], [1, 1], [0, 0]]
        return ops.pad(x, paddings)
37
+
38
+
39
@keras_hub_export("keras_hub.models.WhisperBackbone")
class WhisperBackbone(Backbone):
    """A Whisper encoder-decoder network for speech.

    This class implements a Transformer-based encoder-decoder model as
    described in
    ["Robust Speech Recognition via Large-Scale Weak Supervision"](https://arxiv.org/abs/2212.04356).
    It includes the embedding lookups and transformer layers, but not the head
    for predicting the next token.

    The default constructor gives a fully customizable, randomly initialized Whisper
    model with any number of layers, heads, and embedding dimensions. To load
    preset architectures and weights, use the `from_preset()` constructor.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by a
    third party and subject to a separate license, available
    [here](https://github.com/openai/whisper).

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer encoder layers and
            transformer decoder layers.
        num_heads: int. The number of attention heads for each transformer.
            The hidden size must be divisible by the number of attention heads.
        hidden_dim: int. The size of the transformer encoding and pooler layers.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        num_mels: int. The number of mel-frequency filters. Defaults to `80`.
        dropout: float. Dropout probability for the Transformer encoder.
        max_encoder_sequence_length: int. The maximum sequence length that the
            audio encoder can consume. Since the second convolutional layer in
            the encoder reduces the sequence length by half (stride of 2), we
            use `max_encoder_sequence_length // 2` as the sequence length for the
            positional embedding layer.
        max_decoder_sequence_length: int. The maximum sequence length that the
            text decoder can consume.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.

    Examples:

    ```python
    input_data = {
        "encoder_features": np.ones(shape=(1, 12, 80), dtype="int32"),
        "decoder_token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "decoder_padding_mask": np.array(
            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]
        ),
    }

    # Randomly initialized Whisper encoder-decoder model with a custom config.
    model = keras_hub.models.WhisperBackbone(
        vocabulary_size=51864,
        num_layers=4,
        num_heads=4,
        hidden_dim=256,
        intermediate_dim=512,
        max_encoder_sequence_length=128,
        max_decoder_sequence_length=128,
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        num_mels=80,
        dropout=0.0,
        max_encoder_sequence_length=3000,
        max_decoder_sequence_length=448,
        dtype=None,
        **kwargs,
    ):
        # Fail fast on non-TensorFlow backends; this model currently
        # requires the TF backend.
        assert_tf_backend(self.__class__.__name__)

        # === Layers ===
        # Two 1D convolutions embed the mel-spectrogram input; the second one
        # (stride 2) halves the sequence length.
        self.encoder_conv_layer_1 = keras.layers.Conv1D(
            filters=hidden_dim,
            kernel_size=3,
            strides=1,
            padding="same",
            dtype=dtype,
            name="encoder_token_embedding_conv_layer_1",
        )
        self.encoder_conv_layer_2 = keras.layers.Conv1D(
            filters=hidden_dim,
            kernel_size=3,
            strides=2,
            padding="valid",
            dtype=dtype,
            name="encoder_token_embedding_conv_layer_2",
        )
        self.encoder_padder = Padder(
            dtype=dtype,
            name="encoder_padder",
        )
        # Frozen position embedding (sequence length is halved to match the
        # stride-2 convolution above); see the note in the functional model
        # section below.
        self.encoder_position_embedding = PositionEmbedding(
            initializer=whisper_kernel_initializer(),
            sequence_length=max_encoder_sequence_length // 2,
            dtype=dtype,
            name="encoder_position_embedding",
            trainable=False,
        )
        self.encoder_embeddings_add = keras.layers.Add(
            dtype=dtype,
            name="encoder_embeddings_add",
        )
        self.encoder_embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="encoder_embeddings_dropout",
        )
        self.encoder_transformer_layers = []
        for i in range(num_layers):
            layer = WhisperEncoder(
                num_heads=num_heads,
                intermediate_dim=intermediate_dim,
                activation=keras.activations.gelu,
                layer_norm_epsilon=1e-5,
                dropout=dropout,
                kernel_initializer=whisper_kernel_initializer(),
                normalize_first=True,
                dtype=dtype,
                name=f"transformer_encoder_layer_{i}",
            )
            self.encoder_transformer_layers.append(layer)
        self.encoder_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            dtype=dtype,
            name="encoder_layer_norm",
        )
        self.decoder_embeddings = TokenAndPositionEmbedding(
            vocabulary_size=vocabulary_size,
            sequence_length=max_decoder_sequence_length,
            embedding_dim=hidden_dim,
            embeddings_initializer=whisper_kernel_initializer(),
            dtype=dtype,
            name="decoder_token_and_position_embedding",
        )
        # Expose the decoder token embedding under the standard backbone
        # attribute name so task models can access it directly.
        self.token_embedding = self.decoder_embeddings.token_embedding
        self.decoder_embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="decoder_embeddings_dropout",
        )
        self.decoder_transformer_layers = []
        for i in range(num_layers):
            layer = WhisperDecoder(
                intermediate_dim=intermediate_dim,
                num_heads=num_heads,
                dropout=dropout,
                activation=keras.activations.gelu,
                layer_norm_epsilon=1e-5,
                kernel_initializer=whisper_kernel_initializer(),
                normalize_first=True,
                dtype=dtype,
                name=f"transformer_decoder_layer_{i}",
            )
            self.decoder_transformer_layers.append(layer)
        self.decoder_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            dtype=dtype,
            name="decoder_layer_norm",
        )

        # === Functional Model ===
        # Note that the encoder does not have a padding mask:
        # https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L132.
        encoder_feature_input = keras.Input(
            shape=(None, num_mels), dtype="float32", name="encoder_features"
        )
        decoder_token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="decoder_token_ids"
        )
        decoder_padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="decoder_padding_mask"
        )
        # Encoder.
        # Embed the input features. This consists of two 1D convolutional
        # layers.
        # For the first layer, we use `padding="same"` since that corresponds to
        # a padding size of 1.
        embedded_features = keras.activations.gelu(
            self.encoder_conv_layer_1(encoder_feature_input),
            approximate=False,
        )
        # For the second conv. layer, we cannot use `padding="same"` since
        # that corresponds to a padding size of 1.5 (since stride is 2). Hence,
        # we will manually pad the input.
        embedded_features = self.encoder_padder(embedded_features)
        embedded_features = keras.activations.gelu(
            self.encoder_conv_layer_2(embedded_features),
            approximate=False,
        )
        # The position embedding layer for the encoder is a sinusoidal embedding
        # layer: https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L137.
        # Hence, we set it to be non-trainable.
        # TODO: We can use `keras_hub.layers.SinePositionEncoding` layer.
        positions = self.encoder_position_embedding(embedded_features)
        x = self.encoder_embeddings_add((embedded_features, positions))
        x = self.encoder_embeddings_dropout(x)
        for transformer_layer in self.encoder_transformer_layers:
            x = transformer_layer(x)
        x = self.encoder_layer_norm(x)
        encoder_output = x
        # Decoder.
        # Each decoder layer cross-attends to the final encoder output.
        x = self.decoder_embeddings(decoder_token_id_input)
        x = self.decoder_embeddings_dropout(x)
        for transformer_layer in self.decoder_transformer_layers:
            x = transformer_layer(
                decoder_sequence=x,
                encoder_sequence=encoder_output,
                decoder_padding_mask=decoder_padding_mask_input,
            )
        x = self.decoder_layer_norm(x)
        decoder_output = x
        super().__init__(
            inputs={
                "encoder_features": encoder_feature_input,
                "decoder_token_ids": decoder_token_id_input,
                "decoder_padding_mask": decoder_padding_mask_input,
            },
            outputs={
                "encoder_sequence_output": encoder_output,
                "decoder_sequence_output": decoder_output,
            },
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        # Stored verbatim so `get_config()` can round-trip the constructor.
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_mels = num_mels
        self.dropout = dropout
        self.max_encoder_sequence_length = max_encoder_sequence_length
        self.max_decoder_sequence_length = max_decoder_sequence_length

    def get_config(self):
        # Serialize all constructor arguments for model saving/loading.
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_mels": self.num_mels,
                "dropout": self.dropout,
                "max_encoder_sequence_length": self.max_encoder_sequence_length,
                "max_decoder_sequence_length": self.max_decoder_sequence_length,
            }
        )
        return config