keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,181 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.utils.tensor_utils import is_float_dtype
20
+
21
+
22
@keras_hub_export("keras_hub.metrics.Perplexity")
class Perplexity(keras.metrics.Metric):
    """Perplexity metric.

    This class implements the perplexity metric. In short, this class calculates
    the cross entropy loss and takes its exponent.
    Note: This implementation is not suitable for fixed-size windows.

    Args:
        from_logits: bool. If True, `y_pred` (input to `update_state()`) should
            be the logits as returned by the model. Otherwise, `y_pred` is a
            tensor of probabilities.
        mask_token_id: int. ID of the token to be masked. If provided, the mask
            is computed for this class. Note that if this field is provided, and
            if the `sample_weight` field in `update_state()` is also provided,
            we will compute the final `sample_weight` as the element-wise
            product of the mask and the `sample_weight`.
        dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
            not specified, it defaults to `"float32"`. Must be a floating point
            dtype. Target token IDs are always handled as integers, whatever
            this precision is.
        name: string. Name of the metric instance.
        **kwargs: Other keyword arguments.

    Batches that contain no (weighted) tokens, for example batches made up
    entirely of `mask_token_id`, are skipped rather than producing NaN.

    Examples:

    1. Calculate perplexity by calling update_state() and result().
    1.1. `sample_weight`, and `mask_token_id` are not provided.
    >>> np.random.seed(42)
    >>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
    >>> target = np.random.randint(10, size=[2, 5])
    >>> logits = np.random.uniform(size=(2, 5, 10))
    >>> perplexity.update_state(target, logits)
    >>> perplexity.result()
    <tf.Tensor: shape=(), dtype=float32, numpy=14.352535>

    1.2. `sample_weight` specified (masking token with ID 0).
    >>> np.random.seed(42)
    >>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
    >>> target = np.random.randint(10, size=[2, 5])
    >>> logits = np.random.uniform(size=(2, 5, 10))
    >>> sample_weight = (target != 0).astype("float32")
    >>> perplexity.update_state(target, logits, sample_weight)
    >>> perplexity.result()
    <tf.Tensor: shape=(), dtype=float32, numpy=14.352535>

    2. Call perplexity directly.
    >>> np.random.seed(42)
    >>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
    >>> target = np.random.randint(10, size=[2, 5])
    >>> logits = np.random.uniform(size=(2, 5, 10))
    >>> perplexity(target, logits)
    <tf.Tensor: shape=(), dtype=float32, numpy=14.352535>

    3. Provide the padding token ID and let the class compute the mask on its
    own.
    >>> np.random.seed(42)
    >>> perplexity = keras_hub.metrics.Perplexity(mask_token_id=0)
    >>> target = np.random.randint(10, size=[2, 5])
    >>> logits = np.random.uniform(size=(2, 5, 10))
    >>> perplexity(target, logits)
    <tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
    """

    def __init__(
        self,
        from_logits=False,
        mask_token_id=None,
        dtype="float32",
        name="perplexity",
        **kwargs,
    ):
        if not is_float_dtype(dtype):
            raise ValueError(
                "`dtype` must be a floating point type. "
                f"Received: dtype={dtype}"
            )

        super().__init__(name=name, dtype=dtype, **kwargs)

        # "sum" reduction: `update_state` divides by the token count itself.
        self._crossentropy = keras.losses.SparseCategoricalCrossentropy(
            from_logits=from_logits, reduction="sum"
        )

        self.from_logits = from_logits
        self.mask_token_id = mask_token_id

        # Running sum of (batch_size * per-token mean cross entropy).
        self._aggregate_crossentropy = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="aggregate_crossentropy",
        )
        # Running count of samples that contributed to the sum above.
        self._number_of_samples = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="number_of_samples",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the cross entropy statistics of one batch.

        Args:
            y_true: Target token IDs, shape `(batch_size, seq_len)`.
            y_pred: Predictions, shape `(batch_size, seq_len, vocab_size)`;
                logits if `from_logits=True`, probabilities otherwise.
            sample_weight: Optional weights. When `mask_token_id` is set,
                they are multiplied element-wise with the padding mask.
        """
        # Keep the labels integral. Casting token IDs to `self.dtype` would
        # round them whenever that dtype is a low-precision float (float16 is
        # exact only up to 2048, bfloat16 only up to 256). That corrupts both
        # the targets and the padding-mask comparison below.
        y_true = ops.cast(y_true, "int32")
        y_pred = ops.cast(y_pred, self.dtype)

        if sample_weight is not None:
            sample_weight = ops.cast(sample_weight, self.dtype)

        batch_size = ops.cast(ops.shape(y_true)[0], self.dtype)

        if self.mask_token_id is not None:
            mask = ops.cast(
                ops.logical_not(ops.equal(y_true, self.mask_token_id)),
                self.dtype,
            )
            if sample_weight is None:
                sample_weight = mask
            else:
                sample_weight = ops.multiply(mask, sample_weight)

        # Summed (not averaged) cross entropy over all weighted tokens.
        crossentropy_value = ops.cast(
            self._crossentropy(y_true, y_pred, sample_weight=sample_weight),
            self.dtype,
        )  # scalar

        # Number of (weighted) tokens that the sum above runs over.
        if sample_weight is not None:
            num_tokens = ops.sum(sample_weight)
        else:
            num_tokens = ops.cast(ops.shape(y_true)[0], self.dtype) * ops.cast(
                ops.shape(y_true)[1], self.dtype
            )

        # A batch with no (weighted) tokens has nothing to average, for
        # example when every token equals `mask_token_id`. Dividing by zero
        # would add NaN to the running total and poison every later
        # `result()`, so such a batch is skipped instead of being counted.
        has_tokens = ops.greater(num_tokens, 0)
        safe_num_tokens = ops.where(
            has_tokens, num_tokens, ops.ones_like(num_tokens)
        )
        mean_crossentropy = crossentropy_value / safe_num_tokens  # scalar

        zero = ops.zeros_like(batch_size)
        self._aggregate_crossentropy.assign_add(
            ops.where(has_tokens, batch_size * mean_crossentropy, zero)
        )
        self._number_of_samples.assign_add(
            ops.where(has_tokens, batch_size, zero)
        )

    def result(self):
        """Return `exp(aggregate_crossentropy / number_of_samples)`.

        Returns 0 if no samples have been accumulated yet.
        """
        perplexity_score = ops.where(
            ops.equal(ops.convert_to_tensor(self._number_of_samples), 0),
            0,
            ops.exp(self._aggregate_crossentropy / self._number_of_samples),
        )
        return perplexity_score

    def reset_state(self):
        """Clear the accumulated cross entropy and sample count."""
        self._aggregate_crossentropy.assign(0.0)
        self._number_of_samples.assign(0.0)

    def get_config(self):
        """Return the constructor arguments needed to recreate the metric."""
        config = super().get_config()
        config.update(
            {
                "from_logits": self.from_logits,
                "mask_token_id": self.mask_token_id,
            }
        )
        return config
@@ -0,0 +1,204 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+
18
+ from keras_hub.src.utils.tensor_utils import is_float_dtype
19
+ from keras_hub.src.utils.tensor_utils import tensor_to_list
20
+
21
+ try:
22
+ import tensorflow as tf
23
+ except ImportError:
24
+ tf = None
25
+
26
+ try:
27
+ from rouge_score import rouge_scorer
28
+ except ImportError:
29
+ rouge_scorer = None
30
+
31
+
32
class RougeBase(keras.metrics.Metric):
    """ROUGE metric.

    This class implements two variants of the ROUGE metric - ROUGE-N,
    and ROUGE-L.

    Scores are computed in Python by the `rouge_score` package, one
    reference/hypothesis pair at a time. Input handling uses TensorFlow, so
    both `rouge_score` and `tensorflow` must be installed. `result()` returns
    the mean precision, recall and F1 score over all pairs seen since the
    last `reset_state()`.

    Note on input shapes:
    For `y_true` and `y_pred`, this class supports scalar values and batch
    inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`.

    Args:
        variant: string. One of "rougeN", "rougeL". For "rougeN", N lies in
            the range [1, 9]. Defaults to `"rouge2"`.
        use_stemmer: bool. Whether Porter Stemmer should be used to strip word
            suffixes to improve matching. Defaults to `False`.
        dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
            not specified, it defaults to `"float32"`.
        name: string. Name of the metric instance.
        **kwargs: Other keyword arguments.

    Raises:
        ImportError: If `rouge_score` or `tensorflow` is not installed.
        ValueError: If `dtype` is not a floating point type or `variant` is
            not a supported ROUGE variant.

    References:
        - [Lin et al., 2004](https://aclanthology.org/W04-1013/)
    """

    def __init__(
        self,
        variant="rouge2",
        use_stemmer=False,
        dtype="float32",
        name="rouge",
        **kwargs,
    ):
        # Fail fast on missing optional dependencies. `update_state()` relies
        # on `tf` for input handling. Without TensorFlow it would otherwise
        # fail later with an opaque `AttributeError` on `None`.
        if rouge_scorer is None:
            raise ImportError(
                f"{self.__class__.__name__} requires the `rouge_score` "
                "package. Please install it with `pip install rouge-score`."
            )
        if tf is None:
            raise ImportError(
                f"{self.__class__.__name__} requires the `tensorflow` "
                "package. Please install it with `pip install tensorflow`."
            )

        if not is_float_dtype(dtype):
            raise ValueError(
                "`dtype` must be a floating point type. "
                f"Received: dtype={dtype}"
            )

        super().__init__(name=name, dtype=dtype, **kwargs)

        valid_variants = tuple(f"rouge{order}" for order in range(1, 10)) + (
            "rougeL",
        )
        if variant not in valid_variants:
            raise ValueError(
                "Invalid variant of ROUGE. Should be one of: rougeN, rougeL, "
                "with N ranging from 1 to 9. Received: "
                f"variant={variant}"
            )

        self.variant = variant
        self.use_stemmer = use_stemmer

        # To-do: Add split_summaries and tokenizer options after the maintainers
        # of rouge_scorer have released a new version.
        self._rouge_scorer = rouge_scorer.RougeScorer(
            rouge_types=[self.variant],
            use_stemmer=use_stemmer,
        )

        # Running sums of per-pair scores; divided by the sample count in
        # `result()`.
        self._rouge_precision = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="rouge_precision",
        )
        self._rouge_recall = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="rouge_recall",
        )
        self._rouge_f1_score = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="rouge_f1_score",
        )

        self._number_of_samples = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="number_of_samples",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Score each reference/hypothesis pair and accumulate the scores.

        Args:
            y_true: Reference string(s): a Python string, or a tensor/list
                of shape `[batch_size]` or `[batch_size, 1]`.
            y_pred: Hypothesis string(s), shaped like `y_true`.
            sample_weight: Accepted for API compatibility but currently
                ignored; every pair contributes equally.
        """
        # Three possible shapes for y_true and y_pred: Python string,
        # [batch_size] and [batch_size, 1]. In the latter two cases, we have
        # strings in the tensor/list.

        def validate_and_fix_rank(inputs, tensor_name):
            if not isinstance(inputs, tf.Tensor):
                inputs = tf.convert_to_tensor(inputs)

            if inputs.shape.rank == 0:
                return inputs[tf.newaxis]
            elif inputs.shape.rank == 1:
                return inputs
            elif inputs.shape.rank == 2:
                if inputs.shape[1] != 1:
                    raise ValueError(
                        f"{tensor_name} must be of shape `[batch_size, 1]`. "
                        f"Found shape: {inputs.shape}"
                    )
                else:
                    return tf.squeeze(inputs, axis=1)
            else:
                raise ValueError(
                    f"{tensor_name} must be of rank 0 (scalar input), 1 or 2. "
                    f"Found rank: {inputs.shape.rank}"
                )

        y_true = validate_and_fix_rank(y_true, "y_true")
        y_pred = validate_and_fix_rank(y_pred, "y_pred")

        batch_size = tf.shape(y_true)[0]

        def calculate_rouge_score(reference, hypothesis):
            reference = tensor_to_list(reference)
            hypothesis = tensor_to_list(hypothesis)
            score = self._rouge_scorer.score(reference, hypothesis)[
                self.variant
            ]
            return score.precision, score.recall, score.fmeasure

        for batch_idx in range(batch_size):
            score = calculate_rouge_score(y_true[batch_idx], y_pred[batch_idx])
            self._rouge_precision.assign_add(score[0])
            self._rouge_recall.assign_add(score[1])
            self._rouge_f1_score.assign_add(score[2])

        self._number_of_samples.assign_add(
            ops.cast(batch_size, dtype=self.dtype)
        )

    def result(self):
        """Return the mean `precision`, `recall` and `f1_score` as a dict.

        All three values are 0.0 if no samples have been accumulated.
        """
        if self._number_of_samples == 0:
            return {
                "precision": 0.0,
                "recall": 0.0,
                "f1_score": 0.0,
            }

        rouge_precision = self._rouge_precision / self._number_of_samples
        rouge_recall = self._rouge_recall / self._number_of_samples
        rouge_f1_score = self._rouge_f1_score / self._number_of_samples
        return {
            "precision": rouge_precision,
            "recall": rouge_recall,
            "f1_score": rouge_f1_score,
        }

    def reset_state(self):
        """Clear all accumulated scores and the sample count."""
        self._rouge_precision.assign(0.0)
        self._rouge_recall.assign(0.0)
        self._rouge_f1_score.assign(0.0)
        self._number_of_samples.assign(0.0)

    def get_config(self):
        """Return the constructor arguments needed to recreate the metric."""
        config = super().get_config()
        config.update(
            {
                "variant": self.variant,
                "use_stemmer": self.use_stemmer,
            }
        )
        return config
@@ -0,0 +1,97 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.api_export import keras_hub_export
16
+ from keras_hub.src.metrics.rouge_base import RougeBase
17
+
18
+
19
@keras_hub_export("keras_hub.metrics.RougeL")
class RougeL(RougeBase):
    """ROUGE-L metric.

    This class implements the ROUGE-L variant of the ROUGE metric. The ROUGE-L
    metric is traditionally used for evaluating summarisation systems.
    Succinctly put, ROUGE-L is a score based on the length of the longest
    common subsequence present in the reference text and the hypothesis text.

    Note on input shapes:
    For `y_true` and `y_pred`, this class supports scalar values and batch
    inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`.

    Args:
        use_stemmer: bool. Whether Porter Stemmer should be used to strip word
            suffixes to improve matching. Defaults to `False`.
        dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
            not specified, it defaults to `"float32"`.
        name: string. Name of the metric instance. Defaults to `"rouge-l"`.
        **kwargs: Other keyword arguments, forwarded to the base class.

    References:
        - [Lin et al., 2004](https://aclanthology.org/W04-1013/)

    Examples:

    1. Python string.
    >>> rouge_l = keras_hub.metrics.RougeL()
    >>> y_true = "the tiny little cat was found under the big funny bed"
    >>> y_pred = "the cat was under the bed"
    >>> rouge_l(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.7058824>

    2. Python list inputs.
    >>> rouge_l = keras_hub.metrics.RougeL()
    >>> y_true = [
    ...     "the tiny little cat was found under the big funny bed",
    ...     "i really love contributing to KerasHub",
    ... ]
    >>> y_pred = [
    ...     "the cat was under the bed",
    ...     "i love contributing to KerasHub",
    ... ]
    >>> rouge_l(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.80748665>

    3. 2D inputs.
    >>> rouge_l = keras_hub.metrics.RougeL()
    >>> y_true = [
    ...     ["the tiny little cat was found under the big funny bed"],
    ...     ["i really love contributing to KerasHub"],
    ... ]
    >>> y_pred = [
    ...     ["the cat was under the bed"],
    ...     ["i love contributing to KerasHub"],
    ... ]
    >>> rouge_l(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.80748665>
    """

    def __init__(self, use_stemmer=False, name="rouge-l", **kwargs):
        # ROUGE-L has exactly one variant, so it is fixed here rather than
        # exposed as a constructor argument.
        super().__init__(
            name=name,
            variant="rougeL",
            use_stemmer=use_stemmer,
            **kwargs,
        )

    def get_config(self):
        """Return the base config without the class-implied `variant` key."""
        config = super().get_config()
        # `variant` is hard-coded in `__init__`. Feeding it back through
        # `from_config` would collide with that fixed keyword argument.
        config.pop("variant")
        return config
@@ -0,0 +1,125 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.api_export import keras_hub_export
16
+ from keras_hub.src.metrics.rouge_base import RougeBase
17
+
18
+
19
@keras_hub_export("keras_hub.metrics.RougeN")
class RougeN(RougeBase):
    """ROUGE-N metric.

    This class implements the ROUGE-N variant of the ROUGE metric. The ROUGE-N
    metric is traditionally used for evaluating summarisation systems.
    Succinctly put, ROUGE-N is a score based on the number of matching n-grams
    between the reference text and the hypothesis text.

    Note on input shapes:
    For `y_true` and `y_pred`, this class supports scalar values and batch
    inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`.

    Args:
        order: The order of n-grams which are to be matched. It should lie in
            range [1, 9]. Defaults to `2`.
        use_stemmer: bool. Whether Porter Stemmer should be used to strip word
            suffixes to improve matching. Defaults to `False`.
        dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
            not specified, it defaults to `"float32"`.
        name: string. Name of the metric instance. Defaults to `"rouge-n"`.
        **kwargs: Other keyword arguments, forwarded to the base class.

    Raises:
        ValueError: If `order` is not in the range [1, 9].

    References:
        - [Lin et al., 2004](https://aclanthology.org/W04-1013/)

    Examples:

    1. Python string.
    >>> rouge_n = keras_hub.metrics.RougeN(order=2)
    >>> y_true = "the tiny little cat was found under the big funny bed"
    >>> y_pred = "the cat was under the bed"
    >>> rouge_n(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.26666668>

    2. List inputs.
    >>> rouge_n = keras_hub.metrics.RougeN(order=2)
    >>> y_true = [
    ...     "the tiny little cat was found under the big funny bed",
    ...     "i really love contributing to KerasHub",
    ... ]
    >>> y_pred = [
    ...     "the cat was under the bed",
    ...     "i love contributing to KerasHub",
    ... ]
    >>> rouge_n(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.4666667>

    3. 2D inputs.
    >>> rouge_n = keras_hub.metrics.RougeN(order=2)
    >>> y_true = [
    ...     ["the tiny little cat was found under the big funny bed"],
    ...     ["i really love contributing to KerasHub"],
    ... ]
    >>> y_pred = [
    ...     ["the cat was under the bed"],
    ...     ["i love contributing to KerasHub"],
    ... ]
    >>> rouge_n(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.4666667>

    4. Trigrams.
    >>> rouge_n = keras_hub.metrics.RougeN(order=3)
    >>> y_true = [
    ...     "the tiny little cat was found under the big funny bed",
    ...     "i really love contributing to KerasHub",
    ... ]
    >>> y_pred = [
    ...     "the cat was under the bed",
    ...     "i love contributing to KerasHub",
    ... ]
    >>> rouge_n(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.2857143>
    """

    def __init__(
        self,
        order=2,
        use_stemmer=False,
        name="rouge-n",
        **kwargs,
    ):
        if order not in range(1, 10):
            # The two literals are concatenated, so the separating space
            # must be explicit. It was previously missing ("[1, 9].Received").
            raise ValueError(
                "Invalid `order` value. Should lie in the range [1, 9]. "
                f"Received order={order}"
            )

        # The base class receives the n-gram order encoded in the variant
        # string, e.g. `order=2` -> "rouge2".
        super().__init__(
            variant=f"rouge{order}",
            use_stemmer=use_stemmer,
            name=name,
            **kwargs,
        )

        self.order = order

    def get_config(self):
        """Return a config that round-trips through `__init__`.

        `variant` is derived from `order`, so it is replaced by `order` in the
        serialized config.
        """
        config = super().get_config()
        del config["variant"]

        config.update(
            {
                "order": self.order,
            }
        )
        return config
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.models.albert.albert_backbone import AlbertBackbone
16
+ from keras_hub.src.models.albert.albert_presets import backbone_presets
17
+ from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer
18
+ from keras_hub.src.utils.preset_utils import register_presets
19
+
20
# Register the ALBERT `backbone_presets` with both the backbone and the
# tokenizer classes (presumably so their preset-loading constructors can
# resolve these preset names — confirm against `register_presets`).
register_presets(
    backbone_presets,
    (AlbertBackbone, AlbertTokenizer),
)