keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,394 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import collections
16
+ import math
17
+
18
+ import keras
19
+ from keras import ops
20
+
21
+ from keras_hub.src.api_export import keras_hub_export
22
+ from keras_hub.src.utils.tensor_utils import is_float_dtype
23
+ from keras_hub.src.utils.tensor_utils import tensor_to_list
24
+
25
+ try:
26
+ import tensorflow as tf
27
+ except ImportError:
28
+ tf = None
29
+
30
+
31
# Substitutions applied, in order, before the regex tokenization rules below.
# Each entry is a `(pattern, replacement)` pair; `Bleu._tokenizer` passes the
# first element to a regex-replace call, so it must stay a valid regex.
REPLACE_SUBSTRINGS = [
    # Drop the "<skipped>" marker.
    ("<skipped>", ""),
    # Re-join words hyphenated across a line break.
    ("-\n", ""),
    # Remaining line breaks become plain spaces.
    ("\n", " "),
    # Unescape the XML/HTML entities.
    ("&quot;", '"'),
    ("&amp;", "&"),
    ("&lt;", "<"),
    ("&gt;", ">"),
]


# Regex tokenization rules, applied in order after `REPLACE_SUBSTRINGS`. Each
# entry is a `(pattern, rewrite)` pair where `\N` in the rewrite refers to the
# N-th capturing group of the pattern.
REGEX_PATTERNS = [
    # language-dependent part (assuming Western languages)
    (r"([\{-\~\[-\` -\&\(-\+\:-\@\/])", r" \1 "),
    # tokenize period and comma unless preceded by a digit
    (r"([^0-9])([\.,])", r"\1 \2 "),
    # tokenize period and comma unless followed by a digit
    (r"([\.,])([^0-9])", r" \1 \2"),
    # tokenize dash when preceded by a digit
    (r"([0-9])(-)", r"\1 \2 "),
    # If the last character is "." or ",", split it off as its own token.
    # The punctuation is captured explicitly so that the rewrite only refers
    # to a group that exists; a rewrite naming a missing group makes the
    # whole replacement fail instead of inserting the space.
    (r"([\.,])$", r" \1 "),
    # one space only between words
    (r"\s+", r" "),
]
56
+
57
+
58
@keras_hub_export("keras_hub.metrics.Bleu")
class Bleu(keras.metrics.Metric):
    """BLEU metric.

    This class implements the BLEU metric. BLEU is generally used to evaluate
    machine translation systems. By default, this implementation replicates
    SacreBLEU, but user-defined tokenizers can be passed to deal with other
    languages.

    For BLEU score, we count the number of matching n-grams in the candidate
    translation and the reference text. We find the "clipped count" of matching
    n-grams so as to not give a high score to a (reference, prediction) pair
    with redundant, repeated tokens. Secondly, BLEU score tends to reward
    shorter predictions more, which is why a brevity penalty is applied to
    penalise short predictions. For more details, see the following article:
    https://cloud.google.com/translate/automl/docs/evaluate#bleu.

    Note on input shapes:
    For unbatched inputs, `y_pred` should be a tensor of shape `()`, and
    `y_true` should be a tensor of shape `(num_references,)`. For batched
    inputs, `y_pred` should be a tensor of shape `(batch_size,)`,
    and `y_true` should be a tensor of shape `(batch_size, num_references)`. In
    case of batched inputs, `y_true` can also be a ragged tensor of shape
    `(batch_size, None)` if different samples have different number of
    references.

    Args:
        tokenizer: callable. A function that takes a string `tf.RaggedTensor`
            (of any shape), and tokenizes the strings in the tensor. If the
            tokenizer is not specified, the default tokenizer is used. The
            default tokenizer replicates the behaviour of SacreBLEU's
            `"tokenizer_13a"` tokenizer
            (https://github.com/mjpost/sacrebleu/blob/v2.1.0/sacrebleu/tokenizers/tokenizer_13a.py).
        max_order: int. The maximum n-gram order to use. For example, if
            `max_order` is set to 3, unigrams, bigrams, and trigrams will be
            considered. Defaults to `4`.
        smooth: bool. Whether to apply Lin et al. 2004 smoothing to the BLEU
            score. Adds 1 to the matched n-gram count (i.e., numerator) and 1
            to the total n-gram count (i.e., denominator) for every order while
            calculating precision. Defaults to `False`.
        dtype: string or tf.dtypes.DType. Precision of metric computation. If
            not specified, it defaults to `"float32"`.
        name: string. Name of the metric instance.
        **kwargs: Other keyword arguments.

    References:
        - [Papineni et al., 2002](https://aclanthology.org/P02-1040/)
        - [SacreBLEU](https://github.com/mjpost/sacrebleu)
        - [Lin et al., 2004](https://aclanthology.org/P04-1077/)
    """

    def __init__(
        self,
        tokenizer=None,
        max_order=4,
        smooth=False,
        dtype="float32",
        name="bleu",
        **kwargs,
    ):
        super().__init__(name=name, dtype=dtype, **kwargs)

        if not is_float_dtype(dtype):
            raise ValueError(
                "`dtype` must be a floating point type. "
                f"Received: dtype={dtype}"
            )

        self.tokenizer = tokenizer
        self.max_order = max_order
        self.smooth = smooth

        # Running corpus-level statistics. They are accumulated across
        # `update_state` calls so that the score is a corpus BLEU over
        # everything seen since the last `reset_state`.
        self._matches = self.add_weight(
            shape=(self.max_order,),
            initializer="zeros",
            dtype=self.dtype,
            name="bleu_matches",
        )
        self._possible_matches = self.add_weight(
            shape=(self.max_order,),
            initializer="zeros",
            dtype=self.dtype,
            name="bleu_possible_matches",
        )
        self._translation_length = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="bleu_translation_length",
        )
        self._reference_length = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="bleu_reference_length",
        )
        # Latest BLEU score computed from the statistics above.
        self._bleu = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="bleu",
        )

    def _tokenizer(self, inputs):
        """
        Tokenizes the input strings. By default, replicates the behaviour of
        SacreBLEU's default tokenizer, namely, `tokenizer_13a`.
        """
        if self.tokenizer:
            return self.tokenizer(inputs)

        # Apply every substitution rule in order, then split on whitespace.
        for pattern, replacement in REPLACE_SUBSTRINGS + REGEX_PATTERNS:
            inputs = tf.strings.regex_replace(
                input=inputs,
                pattern=pattern,
                rewrite=replacement,
                replace_global=True,
                name=None,
            )
        inputs = tf.strings.split(inputs)
        return inputs

    def _get_ngrams(self, segment, max_order):
        """Extracts all n-grams up to a given maximum order from a segment.

        Uses Python ops. Inspired from
        https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.

        Args:
            segment: list. Text segment from which n-grams will be
                extracted.
            max_order: int. Maximum length in tokens of the n-grams returned
                by this method.

        Returns:
            A `collections.Counter` mapping each n-gram (as a tuple of
            tokens) to the number of times it occurs in `segment`.
        """
        ngram_counts = collections.Counter()
        for order in range(1, max_order + 1):
            for i in range(0, len(segment) - order + 1):
                ngram = tuple(segment[i : i + order])
                ngram_counts[ngram] += 1
        return ngram_counts

    def _corpus_bleu(
        self,
        reference_corpus,
        translation_corpus,
        matches_by_order,
        possible_matches_by_order,
        translation_length,
        reference_length,
        max_order=4,
        smooth=False,
    ):
        """Corpus BLEU implementation using Python ops.

        Computes BLEU score of translated segments against one or more
        references. Inspired from
        https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.

        Args:
            reference_corpus: list of lists of references for each
                translation. Each reference should be tokenized into a list
                of tokens.
            translation_corpus: list of translations to score. Each
                translation should be tokenized into a list of tokens.
            matches_by_order: list of floats containing the initial number
                of matches for each order. Updated in place.
            possible_matches_by_order: list of floats containing the initial
                number of possible matches for each order. Updated in place.
            translation_length: float. Initial number of tokens in all the
                translations.
            reference_length: float. Initial number of tokens in all the
                references.
            max_order: int. Maximum n-gram order to use when computing
                BLEU score.
            smooth: boolean. Whether or not to apply Lin et al. 2004
                smoothing.

        Returns:
            A tuple `(bleu, matches_by_order, possible_matches_by_order,
            translation_length, reference_length)` holding the score and the
            updated running statistics.
        """
        for references, translation in zip(
            reference_corpus, translation_corpus
        ):
            # A sample without any reference contributes no reference tokens
            # instead of raising on an empty `min()`.
            reference_length += min((len(r) for r in references), default=0)
            translation_length += len(translation)

            # `|` keeps, per n-gram, the maximum count over all references;
            # `&` then clips the translation counts against that maximum.
            merged_ref_ngram_counts = collections.Counter()
            for reference in references:
                merged_ref_ngram_counts |= self._get_ngrams(
                    reference, max_order
                )
            translation_ngram_counts = self._get_ngrams(translation, max_order)
            overlap = translation_ngram_counts & merged_ref_ngram_counts
            for ngram in overlap:
                matches_by_order[len(ngram) - 1] += overlap[ngram]
            for order in range(1, max_order + 1):
                possible_matches = len(translation) - order + 1
                if possible_matches > 0:
                    possible_matches_by_order[order - 1] += possible_matches

        precisions = [0] * max_order
        for i in range(0, max_order):
            if smooth:
                precisions[i] = (matches_by_order[i] + 1.0) / (
                    possible_matches_by_order[i] + 1.0
                )
            else:
                if possible_matches_by_order[i] > 0:
                    precisions[i] = (
                        float(matches_by_order[i])
                        / possible_matches_by_order[i]
                    )
                else:
                    precisions[i] = 0.0

        if min(precisions) > 0:
            p_log_sum = sum(
                (1.0 / max_order) * math.log(p) for p in precisions
            )
            geo_mean = math.exp(p_log_sum)
        else:
            geo_mean = 0

        # Brevity penalty. The degenerate cases are spelled out so that empty
        # corpora neither divide by zero nor produce NaN.
        if translation_length == 0:
            # Nothing was translated: limit of exp(1 - 1 / ratio) as the
            # ratio goes to zero.
            bp = 0.0
        elif reference_length == 0:
            # The translation is longer than the (empty) references.
            bp = 1.0
        else:
            ratio = float(translation_length) / reference_length
            if ratio > 1.0:
                bp = 1.0
            else:
                bp = math.exp(1 - 1.0 / ratio)

        bleu = geo_mean * bp

        return (
            bleu,
            matches_by_order,
            possible_matches_by_order,
            translation_length,
            reference_length,
        )

    def _calculate_bleu_score(self, references, translation):
        """Scores a tokenized batch on top of the current running statistics.

        Args:
            references: tokenized references, either as a tensor / ragged
                tensor or as nested Python lists.
            translation: tokenized translations, either as a tensor / ragged
                tensor or as nested Python lists.

        Returns:
            The tuple returned by `_corpus_bleu`.
        """
        if isinstance(references, (tf.Tensor, tf.RaggedTensor)):
            references = tensor_to_list(references)
        if isinstance(translation, (tf.Tensor, tf.RaggedTensor)):
            translation = tensor_to_list(translation)

        # Seed the computation with the statistics accumulated so far.
        matches = self._matches.numpy()
        possible_matches = self._possible_matches.numpy()
        translation_length = self._translation_length.numpy()
        reference_length = self._reference_length.numpy()

        return self._corpus_bleu(
            reference_corpus=references,
            translation_corpus=translation,
            matches_by_order=matches,
            possible_matches_by_order=possible_matches,
            translation_length=translation_length,
            reference_length=reference_length,
            max_order=self.max_order,
            smooth=self.smooth,
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates BLEU statistics for a batch and refreshes the score.

        Args:
            y_true: reference strings; see the class docstring for the
                accepted shapes.
            y_pred: predicted strings; see the class docstring for the
                accepted shapes.
            sample_weight: unused.

        Raises:
            ImportError: if TensorFlow is not installed.
            ValueError: if an input has an unsupported rank, or an extra
                trailing dimension whose size is not 1.
        """
        if tf is None:
            raise ImportError(
                "`Bleu` requires the `tensorflow` package, which could not "
                "be imported."
            )

        def validate_and_fix_rank(inputs, tensor_name, base_rank=0):
            if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
                inputs = tf.convert_to_tensor(inputs)

            if inputs.shape.rank == base_rank:
                # Unbatched input: add the batch dimension.
                return inputs[tf.newaxis]
            elif inputs.shape.rank == base_rank + 1:
                return inputs
            elif inputs.shape.rank == base_rank + 2:
                if tf.shape(inputs)[-1] != 1:
                    raise ValueError(
                        f"{tensor_name} is of rank {inputs.shape.rank}. The "
                        "last dimension must be of size 1."
                    )
                return tf.squeeze(inputs, axis=-1)
            else:
                raise ValueError(
                    f"{tensor_name} must be of rank {base_rank}, {base_rank+1} "
                    f"or {base_rank+2}. Found rank: {inputs.shape.rank}"
                )

        y_true = validate_and_fix_rank(y_true, "y_true", 1)
        y_pred = validate_and_fix_rank(y_pred, "y_pred", 0)

        # Tokenize the inputs.
        y_true = self._tokenizer(y_true)
        y_pred = self._tokenizer(y_pred)

        (
            bleu_score,
            matches,
            possible_matches,
            translation_length,
            reference_length,
        ) = self._calculate_bleu_score(y_true, y_pred)

        self._matches.assign(matches)
        self._possible_matches.assign(possible_matches)
        self._translation_length.assign(translation_length)
        self._reference_length.assign(reference_length)
        self._bleu.assign(bleu_score)

    def result(self):
        """Returns the BLEU score computed by the last `update_state` call."""
        return self._bleu

    def reset_state(self):
        """Resets all running statistics and the score to zero."""
        self._matches.assign(
            ops.zeros(shape=(self.max_order,), dtype=self.dtype)
        )
        self._possible_matches.assign(
            ops.zeros(shape=(self.max_order,), dtype=self.dtype)
        )
        self._translation_length.assign(0.0)
        self._reference_length.assign(0.0)
        self._bleu.assign(0.0)

    def get_config(self):
        """Returns the base metric config extended with the BLEU options."""
        config = super().get_config()
        config.update(
            {
                "tokenizer": self.tokenizer,
                "max_order": self.max_order,
                "smooth": self.smooth,
            }
        )
        return config
@@ -0,0 +1,197 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+
17
+ from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.utils.tensor_utils import is_float_dtype
19
+
20
+ try:
21
+ import tensorflow as tf
22
+ except ImportError:
23
+ tf = None
24
+
25
+
26
@keras_hub_export("keras_hub.metrics.EditDistance")
class EditDistance(keras.metrics.Metric):
    """Edit Distance metric.

    This class implements the edit distance metric, sometimes called
    Levenshtein Distance, as a `keras.metrics.Metric`. Essentially, edit
    distance is the least number of operations required to convert one string to
    another, where an operation can be one of substitution, deletion or
    insertion. By default, this metric will compute the normalized score, where
    the unnormalized edit distance score is divided by the number of tokens in
    the reference text.

    This class can be used to compute character error rate (CER) and word error
    rate (WER). You simply have to pass the appropriate tokenized text, and set
    `normalize` to True.

    Note on input shapes:
    `y_true` and `y_pred` can either be tensors of rank 1 or ragged tensors of
    rank 2. These tensors contain tokenized text. Dense tensors of rank 2 are
    accepted as well; each row is then used at its full length.

    Args:
        normalize: bool. If True, the computed number of operations
            (substitutions + deletions + insertions) across all samples is
            divided by the aggregate number of tokens in all reference texts. If
            False, number of operations are calculated for every sample, and
            averaged over all the samples.
        dtype: string or tf.dtypes.DType. Precision of metric computation. If
            not specified, it defaults to `"float32"`.
        name: string. Name of the metric instance.
        **kwargs: Other keyword arguments.

    References:
        - [Morris et al.](https://www.researchgate.net/publication/221478089)

    Examples:

    Various Input Types.

    Single-level Python list.
    >>> edit_distance = keras_hub.metrics.EditDistance()
    >>> y_true = "the tiny little cat was found under the big funny bed".split()
    >>> y_pred = "the cat was found under the bed".split()
    >>> edit_distance(y_true, y_pred)
    <tf.Tensor: shape=(), dtype=float32, numpy=0.36363637>

    Nested Python list.
    >>> edit_distance = keras_hub.metrics.EditDistance()
    >>> y_true = [
    ...     "the tiny little cat was found under the big funny bed".split(),
    ...     "it is sunny today".split(),
    ... ]
    >>> y_pred = [
    ...     "the cat was found under the bed".split(),
    ...     "it is sunny but with a hint of cloud cover".split(),
    ... ]
    >>> edit_distance(y_true, y_pred)
    <tf.Tensor: shape=(), dtype=float32, numpy=0.73333335>
    """

    def __init__(
        self,
        normalize=True,
        dtype="float32",
        name="edit_distance",
        **kwargs,
    ):
        super().__init__(name=name, dtype=dtype, **kwargs)

        if not is_float_dtype(dtype):
            raise ValueError(
                "`dtype` must be a floating point type. "
                f"Received: dtype={dtype}"
            )

        self.normalize = normalize

        # Sum of the raw (unnormalized) edit distances of all samples seen.
        self._aggregate_unnormalized_edit_distance = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="aggregate_unnormalized_edit_distance",
        )
        # Only the denominator required by the selected mode is created:
        # total reference tokens when normalizing, sample count otherwise.
        if normalize:
            self._aggregate_reference_length = self.add_weight(
                shape=(),
                initializer="zeros",
                dtype=self.dtype,
                name="aggregate_reference_length",
            )
        else:
            self._number_of_samples = self.add_weight(
                shape=(),
                initializer="zeros",
                dtype=self.dtype,
                name="number_of_samples",
            )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate edit distance statistics for a batch of samples.

        Args:
            y_true: reference tokens. A rank 1 or rank 2 tensor, ragged
                tensor, or (nested) Python list.
            y_pred: hypothesis tokens, in the same formats as `y_true`.
            sample_weight: currently ignored by this implementation.

        Raises:
            ImportError: if TensorFlow is not installed.
            ValueError: if `y_true` or `y_pred` is not of rank 1 or 2.
        """
        # The module-level import is guarded, so `tf` may be `None`.
        if tf is None:
            raise ImportError(
                "`EditDistance` requires TensorFlow to be installed."
            )

        def validate_and_fix_rank(inputs, tensor_name):
            if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
                inputs = tf.ragged.constant(inputs)

            if inputs.shape.rank == 1:
                return tf.RaggedTensor.from_tensor(inputs[tf.newaxis])
            elif inputs.shape.rank == 2:
                if isinstance(inputs, tf.Tensor):
                    # Dense batches lack `flat_values`, which is read
                    # below; convert them so every path yields a ragged
                    # tensor.
                    return tf.RaggedTensor.from_tensor(inputs)
                return inputs
            else:
                raise ValueError(
                    f"{tensor_name} must be of rank 1 or 2. "
                    f"Found rank: {inputs.shape.rank}"
                )

        y_true = validate_and_fix_rank(y_true, "y_true")
        y_pred = validate_and_fix_rank(y_pred, "y_pred")

        if self.normalize:
            self._aggregate_reference_length.assign_add(
                tf.cast(tf.size(y_true.flat_values), dtype=self.dtype)
            )

        def calculate_edit_distance(args):
            reference, hypothesis = args

            # NOTE(review): `tf.sparse.from_dense` keeps only elements not
            # equal to zero, so zero-valued tokens would be left out of
            # the comparison - confirm whether that is intended.
            reference = tf.sparse.from_dense([reference])
            hypothesis = tf.sparse.from_dense([hypothesis])

            edit_distance = tf.squeeze(
                tf.edit_distance(
                    hypothesis=hypothesis,
                    truth=reference,
                    normalize=False,
                )
            )

            self._aggregate_unnormalized_edit_distance.assign_add(
                tf.cast(edit_distance, dtype=self.dtype)
            )
            if not self.normalize:
                self._number_of_samples.assign_add(tf.cast(1, dtype=self.dtype))
            return 0

        # The mapped function is used for its side effects on the metric
        # variables; its dummy return value is discarded.
        _ = tf.map_fn(
            fn=calculate_edit_distance,
            elems=(y_true, y_pred),
            fn_output_signature="int8",
        )

    def result(self):
        """Return the aggregated edit distance.

        Returns:
            With `normalize=True`, the summed edit distance divided by the
            summed reference length. Otherwise, the summed edit distance
            divided by the number of samples. `0.0` is returned while the
            denominator is still zero.
        """
        if self.normalize:
            if self._aggregate_reference_length == 0:
                return 0.0
            return (
                self._aggregate_unnormalized_edit_distance
                / self._aggregate_reference_length
            )
        if self._number_of_samples == 0:
            return 0.0
        return (
            self._aggregate_unnormalized_edit_distance / self._number_of_samples
        )

    def reset_state(self):
        """Reset all accumulated statistics to zero."""
        self._aggregate_unnormalized_edit_distance.assign(0.0)
        if self.normalize:
            self._aggregate_reference_length.assign(0.0)
        else:
            self._number_of_samples.assign(0.0)

    def get_config(self):
        """Return the parent config extended with the `normalize` flag."""
        config = super().get_config()
        config.update({"normalize": self.normalize})
        return config