keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,326 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import keras
17
+ from absl import logging
18
+
19
+ from keras_hub.src.api_export import keras_hub_export
20
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
21
+ from keras_hub.src.models.preprocessor import Preprocessor
22
+ from keras_hub.src.models.whisper.whisper_audio_feature_extractor import (
23
+ WhisperAudioFeatureExtractor,
24
+ )
25
+ from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
26
+ from keras_hub.src.utils.keras_utils import (
27
+ convert_inputs_to_list_of_tensor_segments,
28
+ )
29
+
30
+
31
@keras_hub_export("keras_hub.models.WhisperPreprocessor")
class WhisperPreprocessor(Preprocessor):
    """A Whisper preprocessing layer which handles audio and text input.

    This preprocessing layer will do four things:

     1. Compute the log-mel spectrogram of the audio tensor inputs using
        `audio_feature_extractor`.
     2. Tokenize decoder inputs using the `tokenizer`.
     3. Add the appropriate special tokens - `"<|startoftranscript|>"`, task
        token, language token, `"<|endoftext|>"`, etc.
     4. Construct a dictionary with keys `"encoder_features"`,
        `"decoder_token_ids"`, `"decoder_padding_mask"` that can be passed
        directly to a Whisper model.

    Args:
        tokenizer: A `keras_hub.models.WhisperTokenizer` instance.
        audio_feature_extractor: A
            `keras_hub.models.WhisperAudioFeatureExtractor` instance or `None`.
            If `None` a feature extractor with default parameters will be
            created.
        decoder_sequence_length: The length of the packed decoder inputs.
        language: string, language token. Should only be passed if your
            tokenizer is multilingual.
        task: string, task name. One of `"transcribe"`, `"translate"`. Should
            only be passed if your tokenizer is multilingual.
        no_timestamps: bool. If True, `"<|notimestamps|>"` will be added as a
            special token to your input.

    Call arguments:
        x: A dictionary with `"encoder_audio"` and `"decoder_text"` as its keys
            (in any order). `"encoder_audio"` should correspond to the input
            audio tensor. `"decoder_text"` should be a tensor of single string
            sequences. Inputs may be batched or unbatched. Raw python inputs
            will be converted to tensors.
        y: Any label data. Will be passed through unaltered.
        sample_weight: Any label weight data. Will be passed through unaltered.
        decoder_sequence_length: Optional per-call override for the packed
            decoder length. When not set (or falsy), the layer's
            `decoder_sequence_length` is used.

    Examples:

    Directly calling the layer on data.
    ```python
    preprocessor = keras_hub.models.WhisperPreprocessor.from_preset(
        "whisper_tiny_en",
    )

    # Preprocess unbatched inputs.
    input_data = {
        "encoder_audio": tf.ones((200,)),
        "decoder_text": "The quick brown fox jumped.",
    }
    preprocessor(input_data)

    # Preprocess batched inputs.
    input_data = {
        "encoder_audio": tf.ones((2, 200)),
        "decoder_text": ["The quick brown fox jumped.", "Call me Ishmael."],
    }
    preprocessor(input_data)

    # Custom audio feature extractor and vocabulary.
    audio_feature_extractor = keras_hub.models.WhisperAudioFeatureExtractor(
        num_mels=80,
        num_fft_bins=400,
        stride=100,
        sampling_rate=100,
        max_audio_length=5,
    )

    features = ["a quick fox.", "a fox quick."]
    vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6}
    merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
    merges += ["Ġ f", "o x", "Ġf ox"]
    special_tokens = {
        "<|startoftranscript|>": 9,
        "<|endoftext|>": 10,
        "<|notimestamps|>": 11,
        "<|transcribe|>": 12,
        "<|translate|>": 13,
    }

    tokenizer = keras_hub.models.WhisperTokenizer(
        vocabulary=vocab,
        merges=merges,
        special_tokens=special_tokens,
    )
    preprocessor = keras_hub.models.WhisperPreprocessor(
        audio_feature_extractor=audio_feature_extractor,
        tokenizer=tokenizer,
    )

    input_data = {
        "encoder_audio": tf.ones((200,)),
        "decoder_text": "The quick brown fox jumped.",
    }
    preprocessor(input_data)
    ```

    Mapping with `tf.data.Dataset`.
    ```python
    preprocessor = keras_hub.models.WhisperPreprocessor.from_preset(
        "whisper_tiny_en")

    # Map labeled single sentences.
    features = {
        "encoder_audio": tf.ones((2, 200)),
        "decoder_text": ["The quick brown fox jumped.", "Call me Ishmael."],
    }
    labels = tf.constant(["True", "False"])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map unlabeled single sentences.
    features = {
        "encoder_audio": tf.ones((2, 200)),
        "decoder_text": ["The quick brown fox jumped.", "Call me Ishmael."],
    }
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
    """

    tokenizer_cls = WhisperTokenizer

    def __init__(
        self,
        tokenizer,
        audio_feature_extractor=None,
        decoder_sequence_length=448,
        language=None,
        task=None,
        no_timestamps=True,
        **kwargs,
    ):
        """Store the configuration; the decoder packer is created in `build()`.

        See the class docstring for a description of each argument.
        """
        super().__init__(**kwargs)
        if audio_feature_extractor is None:
            audio_feature_extractor = WhisperAudioFeatureExtractor()
        self.audio_feature_extractor = audio_feature_extractor
        self.tokenizer = tokenizer
        # `decoder_packer` must exist before `decoder_sequence_length` is
        # assigned, because the property setter below inspects it.
        self.decoder_packer = None
        self.decoder_sequence_length = decoder_sequence_length
        self.language = language
        self.task = task
        self.no_timestamps = no_timestamps

    def build(self, input_shape):
        """Validate `language`/`task` and create the decoder packer.

        Args:
            input_shape: Unused.

        Raises:
            ValueError: If the tokenizer is multilingual (its
                `language_tokens` is not `None`) and `language` is `None` or
                not one of the tokenizer's language tokens, or `task` is not
                `"transcribe"` or `"translate"`.
        """
        # Defer packer creation to `build()` so that we can be sure tokenizer
        # assets have loaded when restoring a saved model.

        # Create list of tokens to be prepended to decoder inputs.
        bos_tokens = [self.tokenizer.bos_token_id]
        if self.tokenizer.language_tokens is not None:
            if (
                self.language is None
                or self.language not in self.tokenizer.language_tokens
            ):
                raise ValueError(
                    "You must pass a non-None value for `language` when using "
                    "a multilingual tokenizer. The value must be one of "
                    f'{",".join(self.tokenizer.language_tokens.keys())}. '
                    f"Received: language={self.language}."
                )
            if self.task is None or self.task not in [
                "transcribe",
                "translate",
            ]:
                raise ValueError(
                    "You must pass a non-None value for `task` when using "
                    "a multilingual tokenizer. The value must be one of "
                    '`"transcribe"`, `"translate"`. '
                    f"Received: task={self.task}."
                )

            bos_tokens += [self.tokenizer.language_tokens[self.language]]

            if self.task == "transcribe":
                bos_tokens += [self.tokenizer.special_tokens["<|transcribe|>"]]
            elif self.task == "translate":
                bos_tokens += [self.tokenizer.special_tokens["<|translate|>"]]
        else:
            # Monolingual tokenizer: `language` and `task` do not apply, so
            # they are reset (with an info-level log) rather than rejected.
            if self.language is not None:
                logging.info(
                    "`tokenizer` is monolingual, and `language` has a "
                    "non-`None` value. Setting `language` to `None`."
                )
                self.language = None
            if self.task is not None:
                logging.info(
                    "`tokenizer` is monolingual, and `task` has a "
                    "non-`None` value. Setting `task` to `None`."
                )
                self.task = None

        if self.no_timestamps:
            bos_tokens += [self.tokenizer.no_timestamps_token_id]

        # TODO: Use `MultiSegmentPacker` instead of `StartEndPacker` once we
        # want to move to multi-segment packing and have improved
        # `MultiSegmentPacker`'s performance.
        self.decoder_packer = StartEndPacker(
            start_value=bos_tokens,
            end_value=self.tokenizer.eos_token_id,
            pad_value=self.tokenizer.pad_token_id,
            sequence_length=self.decoder_sequence_length,
            return_padding_mask=True,
        )

    def call(self, x, y=None, sample_weight=None, decoder_sequence_length=None):
        """Convert raw audio and text into Whisper model inputs.

        Args:
            x: A dictionary whose keys are exactly `"encoder_audio"` and
                `"decoder_text"`, in any order.
            y: Any label data. Passed through unaltered.
            sample_weight: Any label weight data. Passed through unaltered.
            decoder_sequence_length: Optional override for the packed decoder
                length for this call. Falls back to
                `self.decoder_sequence_length` when `None` (or falsy).

        Returns:
            The result of `keras.utils.pack_x_y_sample_weight` applied to a
            dictionary with keys `"encoder_features"`, `"decoder_token_ids"`
            and `"decoder_padding_mask"`, together with `y` and
            `sample_weight`.

        Raises:
            ValueError: If `x` is not a dictionary with exactly the expected
                keys, or if either input contains more than one segment.
        """
        # Compare keys as a set: a dict built with the keys in the opposite
        # order is just as valid, so insertion order must not matter.
        expected_keys = {"encoder_audio", "decoder_text"}
        if not isinstance(x, dict) or set(x.keys()) != expected_keys:
            raise ValueError(
                '`x` must be a dictionary, containing the keys `"encoder_audio"`'
                f' and `"decoder_text"`. Received x={x}.'
            )

        encoder_audio = x["encoder_audio"]
        decoder_text = x["decoder_text"]

        encoder_audio = convert_inputs_to_list_of_tensor_segments(encoder_audio)
        decoder_text = convert_inputs_to_list_of_tensor_segments(decoder_text)

        if len(encoder_audio) > 1 or len(decoder_text) > 1:
            raise ValueError(
                '`WhisperPreprocessor` requires both `"encoder_audio"` and '
                f'`"decoder_text"` to contain only one segment, but received '
                f"{len(encoder_audio)} and {len(decoder_text)}, respectively."
            )

        encoder_features = self.audio_feature_extractor(encoder_audio[0])
        decoder_sequence_length = (
            decoder_sequence_length or self.decoder_sequence_length
        )
        decoder_inputs = self.tokenizer(decoder_text[0])
        decoder_token_ids, decoder_padding_mask = self.decoder_packer(
            decoder_inputs,
            sequence_length=decoder_sequence_length,
        )

        x = {
            "encoder_features": encoder_features,
            "decoder_token_ids": decoder_token_ids,
            "decoder_padding_mask": decoder_padding_mask,
        }

        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

    def get_config(self):
        """Return the base config extended with this layer's own settings.

        The audio feature extractor is stored in serialized form via
        `keras.layers.serialize`.
        """
        config = super().get_config()
        config.update(
            {
                "audio_feature_extractor": keras.layers.serialize(
                    self.audio_feature_extractor
                ),
                "decoder_sequence_length": self.decoder_sequence_length,
                "language": self.language,
                "task": self.task,
                "no_timestamps": self.no_timestamps,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Create a preprocessor from a config dictionary.

        `"tokenizer"` and `"audio_feature_extractor"` entries that are still
        in serialized (dict) form are deserialized first. The `config`
        argument itself is left unmodified.
        """
        # Work on a shallow copy so the caller's dictionary is not mutated
        # when serialized entries are swapped for live layer objects.
        config = dict(config)

        if "tokenizer" in config and isinstance(config["tokenizer"], dict):
            config["tokenizer"] = keras.layers.deserialize(config["tokenizer"])

        if "audio_feature_extractor" in config and isinstance(
            config["audio_feature_extractor"], dict
        ):
            config["audio_feature_extractor"] = keras.layers.deserialize(
                config["audio_feature_extractor"]
            )

        return cls(**config)

    @property
    def decoder_sequence_length(self):
        """The padded length of decoder input sequences."""
        return self._decoder_sequence_length

    @decoder_sequence_length.setter
    def decoder_sequence_length(self, value):
        self._decoder_sequence_length = value
        # Keep an already-built packer in sync with the new length.
        if self.decoder_packer is not None:
            self.decoder_packer.sequence_length = value

    @property
    def sequence_length(self):
        """Alias for `decoder_sequence_length`."""
        return self.decoder_sequence_length

    @sequence_length.setter
    def sequence_length(self, value):
        self.decoder_sequence_length = value
@@ -0,0 +1,148 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
# Metadata for loading pretrained model weights.
# Each entry maps a preset name to descriptive metadata and the Kaggle handle
# of the preset.
backbone_presets = {
    "whisper_tiny_en": {
        "metadata": {
            "description": (
                "4-layer Whisper model. Trained on 438,000 hours of labelled "
                "English speech data."
            ),
            "params": 37184256,
            "official_name": "Whisper",
            "path": "whisper",
            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
        },
        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/2",
    },
    "whisper_base_en": {
        "metadata": {
            "description": (
                "6-layer Whisper model. Trained on 438,000 hours of labelled "
                "English speech data."
            ),
            "params": 124439808,
            "official_name": "Whisper",
            "path": "whisper",
            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
        },
        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/2",
    },
    "whisper_small_en": {
        "metadata": {
            "description": (
                "12-layer Whisper model. Trained on 438,000 hours of labelled "
                "English speech data."
            ),
            "params": 241734144,
            "official_name": "Whisper",
            "path": "whisper",
            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
        },
        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/2",
    },
    "whisper_medium_en": {
        "metadata": {
            "description": (
                "24-layer Whisper model. Trained on 438,000 hours of labelled "
                "English speech data."
            ),
            "params": 763856896,
            "official_name": "Whisper",
            "path": "whisper",
            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
        },
        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/2",
    },
    "whisper_tiny_multi": {
        "metadata": {
            "description": (
                "4-layer Whisper model. Trained on 680,000 hours of labelled "
                "multilingual speech data."
            ),
            "params": 37760640,
            "official_name": "Whisper",
            "path": "whisper",
            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
        },
        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/2",
    },
    "whisper_base_multi": {
        "metadata": {
            "description": (
                "6-layer Whisper model. Trained on 680,000 hours of labelled "
                "multilingual speech data."
            ),
            "params": 72593920,
            "official_name": "Whisper",
            "path": "whisper",
            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
        },
        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/2",
    },
    "whisper_small_multi": {
        "metadata": {
            "description": (
                "12-layer Whisper model. Trained on 680,000 hours of labelled "
                "multilingual speech data."
            ),
            "params": 241734912,
            "official_name": "Whisper",
            "path": "whisper",
            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
        },
        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/2",
    },
    "whisper_medium_multi": {
        "metadata": {
            "description": (
                "24-layer Whisper model. Trained on 680,000 hours of labelled "
                "multilingual speech data."
            ),
            "params": 763857920,
            "official_name": "Whisper",
            "path": "whisper",
            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
        },
        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/2",
    },
    "whisper_large_multi": {
        "metadata": {
            "description": (
                "32-layer Whisper model. Trained on 680,000 hours of labelled "
                "multilingual speech data."
            ),
            "params": 1543304960,
            "official_name": "Whisper",
            "path": "whisper",
            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
        },
        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/2",
    },
    "whisper_large_multi_v2": {
        "metadata": {
            "description": (
                "32-layer Whisper model. Trained for 2.5 epochs on 680,000 "
                "hours of labelled multilingual speech data. An improved "
                "version of `whisper_large_multi`."
            ),
            "params": 1543304960,
            "official_name": "Whisper",
            "path": "whisper",
            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
        },
        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/2",
    },
}
@@ -0,0 +1,163 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+
17
+ from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
19
+
20
+
21
+ def _load_dict(dict_or_path):
22
+ if isinstance(dict_or_path, str):
23
+ with open(dict_or_path, "r", encoding="utf-8") as f:
24
+ dict_or_path = json.load(f)
25
+ return dict_or_path
26
+
27
+
28
@keras_hub_export("keras_hub.models.WhisperTokenizer")
class WhisperTokenizer(BytePairTokenizer):
    """Whisper text tokenizer using Byte-Pair Encoding subword segmentation.

    This tokenizer class will tokenize raw strings into integer sequences and
    is based on `keras_hub.tokenizers.BytePairTokenizer`.
    This tokenizer does not provide truncation or padding of inputs.

    Args:
        vocabulary: string or dict, maps token to integer ids. If it is a
            string, it should be the file path to a json file.
        merges: string or list, contains the merge rule. If it is a string,
            it should be the file path to merge rules. The merge rule file
            should have one merge rule per line. Every merge rule contains
            merge entities separated by a space.
        special_tokens: string or dict, maps special tokens to integer IDs. If
            it is a string, it should be the path to a JSON file. Required;
            it must contain the start-of-transcript, end-of-text,
            no-timestamps, translate and transcribe tokens.
        language_tokens: string or dict, maps language tokens to integer IDs. If
            not None, the tokenizer will be assumed to be a multilingual
            tokenizer.

    Raises:
        ValueError: if `special_tokens` is not provided or is missing one of
            the required special tokens.
    """

    def __init__(
        self,
        vocabulary=None,
        merges=None,
        special_tokens=None,
        language_tokens=None,
        **kwargs,
    ):
        special_tokens = _load_dict(special_tokens)
        if special_tokens is None:
            # Fail with a clear message instead of a `TypeError` from the
            # membership test below.
            raise ValueError(
                "`special_tokens` must be provided, either as a dict mapping "
                "special tokens to integer IDs or as a path to a JSON file."
            )
        if language_tokens is not None:
            language_tokens = _load_dict(language_tokens)

        # Necessary special tokens.
        self.bos_token = "<|startoftranscript|>"
        self.eos_token = "<|endoftext|>"
        # TODO: The pad token for the multilingual tokenizer is actually
        # "", but it errors out (OOM). After BPE is fixed, we can update
        # this to "". For now, we will use `"<|endoftext|>"`.
        self.pad_token = "<|endoftext|>"

        self.no_timestamps_token = "<|notimestamps|>"
        # Task special tokens.
        self.translate_token = "<|translate|>"
        self.transcribe_token = "<|transcribe|>"

        for token in self._required_special_tokens():
            if token not in special_tokens:
                raise ValueError(
                    f"Cannot find token `'{token}'` in the provided "
                    f"`special_tokens`. Please provide `'{token}'` in your "
                    "`special_tokens`."
                )

        self.bos_token_id = special_tokens[self.bos_token]
        self.eos_token_id = special_tokens[self.eos_token]
        self.pad_token_id = special_tokens[self.pad_token]
        self.no_timestamps_token_id = special_tokens[self.no_timestamps_token]
        self.translate_token_id = special_tokens[self.translate_token]
        self.transcribe_token_id = special_tokens[self.transcribe_token]

        self.special_tokens = special_tokens
        self.language_tokens = language_tokens

        # TODO: Add language tokens to `unsplittable_tokens` once we figure
        # out the performance issue with a large list.
        unsplittable_tokens = list(special_tokens.keys())

        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            unsplittable_tokens=unsplittable_tokens,
            **kwargs,
        )

    def _required_special_tokens(self):
        """Return the special tokens that `special_tokens` must define."""
        return [
            self.bos_token,
            self.eos_token,
            self.pad_token,
            self.no_timestamps_token,
            self.translate_token,
            self.transcribe_token,
        ]

    def save_assets(self, dir_path):
        """Save tokenizer assets to `dir_path` using the unmutated vocabulary.

        Args:
            dir_path: directory passed through to the parent `save_assets`.
        """
        # TODO: whisper is currently mutating its vocabulary before passing
        # it to the super class, so we need to restore the unmutated vocabulary
        # before saving our assets. We should find a more robust (and memory
        # efficient) way to do this.
        vocabulary = self.vocabulary
        self.vocabulary = self._initial_vocabulary
        try:
            super().save_assets(dir_path)
        finally:
            # Always put the working vocabulary back, even if saving fails.
            self.vocabulary = vocabulary

    def set_vocabulary_and_merges(self, vocabulary, merges):
        """Set the vocabulary and merge rules, adding Whisper's extra tokens.

        The vocabulary as provided is kept in `_initial_vocabulary`. The
        vocabulary handed to the parent class additionally contains the
        language tokens (for a multilingual tokenizer) and the required
        special tokens.

        Args:
            vocabulary: string path to a JSON file, dict, or None.
            merges: merge rules, passed through to the parent class.
        """
        if vocabulary is not None:
            vocabulary = _load_dict(vocabulary)
            self._initial_vocabulary = dict(vocabulary)

            # Work on a copy so the caller's vocabulary dict is never
            # mutated, whether or not language tokens are present.
            vocabulary = dict(vocabulary)

            if self.language_tokens is not None:
                # Multilingual tokenizer.
                # Add language tokens to the vocabulary. This makes
                # detokenization easier for us.
                vocabulary.update(self.language_tokens)

            for token in self._required_special_tokens():
                vocabulary[token] = self.special_tokens[token]
        else:
            self._initial_vocabulary = None

        super().set_vocabulary_and_merges(vocabulary, merges)

    def get_config(self):
        """Return the tokenizer config, including the special and language tokens."""
        config = super().get_config()

        # In the constructor, we pass the list of special tokens to the
        # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
        # delete it from the config here.
        del config["unsplittable_tokens"]

        config.update(
            {
                "special_tokens": self.special_tokens,
                "language_tokens": self.language_tokens,
            }
        )
        return config
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
16
+ XLMRobertaBackbone,
17
+ )
18
+ from keras_hub.src.models.xlm_roberta.xlm_roberta_presets import (
19
+ backbone_presets,
20
+ )
21
+ from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import (
22
+ XLMRobertaTokenizer,
23
+ )
24
+ from keras_hub.src.utils.preset_utils import register_presets
25
+
26
+ register_presets(backbone_presets, (XLMRobertaBackbone, XLMRobertaTokenizer))