keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,237 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+ from keras import random
18
+
19
+ from keras_hub.src.api_export import keras_hub_export
20
+ from keras_hub.src.utils.tensor_utils import any_equal
21
+
22
+
23
@keras_hub_export("keras_hub.samplers.Sampler")
class Sampler:
    """Base sampler class.

    Args:
        temperature: float. optional. Used to control the
            randomness of the sampling. The higher the temperature, the
            more diverse the samples. Defaults to `1.0`.

    Call arguments:
        {{call_args}}

    This base class can be extended to implement different auto-regressive
    sampling methods. To do so, override the `get_next_token()` method, which
    computes the next token based on a probability distribution over all
    possible vocab entries.

    Example:

    ```python
    causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Greedy search with some tokens forbidden.
    class CustomSampler(keras_hub.samplers.Sampler):
        def __init__(self, forbidden_tokens, **kwargs):
            super().__init__(**kwargs)
            self.forbidden_tokens = forbidden_tokens

        def get_next_token(self, probs):
            batch_size, vocab_size = keras.ops.shape(probs)
            for id in self.forbidden_tokens:
                update = keras.ops.zeros((batch_size, 1))
                probs = keras.ops.slice_update(probs, (0, id), update)
            return keras.ops.argmax(probs, axis=-1)

    # 257 = "a" with a leading space, 262 = "the" with a leading space.
    causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262]))
    causal_lm.summary()
    causal_lm.generate(["That's strange"])
    ```
    """

    def __init__(
        self,
        temperature=1.0,
    ):
        self.temperature = temperature
        # A subclass may assign `SeedGenerator`s before calling
        # `super().__init__()`; `__setattr__` has then already created the
        # tracking list, so keep it instead of discarding those generators.
        if "_seed_generators" not in self.__dict__:
            self._seed_generators = []

    def __setattr__(self, name, value):
        # We could update to the `Tracker` class from keras-core if our needs
        # become more advanced (e.g. list assignment, nested trackables). For
        # now, we only track `SeedGenerator` instances directly on the sampler.
        if isinstance(value, random.SeedGenerator):
            # Go through `__dict__` so tracking also works when the generator
            # is assigned before `Sampler.__init__` has run.
            self.__dict__.setdefault("_seed_generators", []).append(value)
        return super().__setattr__(name, value)

    @property
    def variables(self):
        """List of the `state` variables of every tracked `SeedGenerator`."""
        variables = []
        for sg in self._seed_generators:
            variables.append(sg.state)
        return variables

    def __call__(
        self,
        next,
        prompt,
        cache=None,
        index=0,
        mask=None,
        stop_token_ids=None,
        hidden_states=None,
        model=None,
    ):
        """Fill `prompt` token by token, starting at position `index`.

        Args:
            next: callable invoked as `next(prompt, cache, index)`. It must
                return a `(logits, _, cache)` triple; the middle element is
                ignored here.
            prompt: token-id tensor indexed as `prompt[:, index]`. Its last
                dimension is the maximum generation length.
            cache: optional state passed through `next` on every step.
                `None` is replaced by `()` because `ops.while_loop` does not
                accept `None` loop variables.
            index: first position to generate.
            mask: optional boolean-castable tensor shaped like `prompt`.
                Positions where it is `True` keep their original token.
            stop_token_ids: optional token ids. Generation stops early once
                every row has produced one of them at an unmasked position.
            hidden_states: unused by this base implementation.
            model: optional model. On the JAX backend its variables are
                threaded through the loop (see `run_loop`).

        Returns:
            The updated `prompt` tensor.
        """
        max_length = ops.shape(prompt)[-1]
        # Make sure `max_length` and `index` are the same dtype.
        index = ops.cast(index, "int32")
        max_length = ops.cast(max_length, "int32")
        if mask is None:
            mask = ops.zeros_like(prompt, dtype="bool")
        else:
            mask = ops.cast(mask, dtype="bool")
        # `ops.while_loop` will not accept `None` as a value for `loop_vars`.
        cache = () if cache is None else cache

        def cond(prompt, cache, index):
            if stop_token_ids is None:
                return True
            # Stop if all sequences have produced a *new* id from stop_token_ids.
            end_tokens = any_equal(prompt, stop_token_ids, ~mask)
            prompt_done = ops.any(end_tokens, axis=-1)
            return ops.logical_not(ops.all(prompt_done))

        def body(prompt, cache, index):
            # Compute the softmax distribution for the next token.
            logits, _, cache = next(prompt, cache, index)
            probabilities = self.compute_probabilities(logits)
            # Compute the next token.
            next_token = self.get_next_token(probabilities)
            # Don't overwrite anywhere mask is True.
            next_token = ops.cast(next_token, prompt.dtype)
            next_token = ops.where(mask[:, index], prompt[:, index], next_token)
            # Update the prompt with the next token.
            next_token = next_token[:, None]
            prompt = ops.slice_update(prompt, [0, index], next_token)

            # Return the next prompt, cache and incremented index.
            return (prompt, cache, index + 1)

        prompt, _, _ = self.run_loop(
            cond,
            body,
            loop_vars=(prompt, cache, index),
            maximum_iterations=(max_length - index),
            model=model,
        )
        return prompt

    def compute_probabilities(self, logits):
        """Compute token probabilities from logits.

        This will always be done in full precision, regardless of dtype, and
        scale by `temperature`.
        """
        logits = ops.cast(logits, "float32")
        return keras.activations.softmax(logits / self.temperature)

    def run_loop(
        self, cond, body, model=None, loop_vars=None, maximum_iterations=None
    ):
        """Run `ops.while_loop`, using a `StatelessScope` on the JAX backend.

        On JAX, the sampler's seed-generator state and the model's
        trainable/non-trainable variables are passed explicitly as an extra
        loop variable. Each step runs inside a `keras.StatelessScope`, and the
        sampler variables are assigned their final values after the loop. On
        other backends, `ops.while_loop` is called directly.

        Returns:
            The final `loop_vars`.
        """
        if keras.config.backend() == "jax":
            import itertools

            if model:
                model_trainable_variables = model.trainable_variables
                model_non_trainable_variables = model.non_trainable_variables
            else:
                model_trainable_variables = []
                model_non_trainable_variables = []

            def stateless_cond(state, *loop_vars):
                return cond(*loop_vars)

            def stateless_body(state, *loop_vars):
                (
                    sampler_variables,
                    trainable_variables,
                    non_trainable_variables,
                ) = state
                mapping = itertools.chain(
                    zip(self.variables, sampler_variables),
                    zip(model_trainable_variables, trainable_variables),
                    zip(model_non_trainable_variables, non_trainable_variables),
                )
                with keras.StatelessScope(state_mapping=mapping) as scope:
                    loop_vars = body(*loop_vars)

                # Pick up sampler-variable values updated inside the scope
                # (e.g. advanced seed state); fall back to the input value.
                sampler_variables = []
                for v in self.variables:
                    new_v = scope.get_current_value(v)
                    sampler_variables.append(new_v if new_v is not None else v)
                state = (
                    sampler_variables,
                    trainable_variables,
                    non_trainable_variables,
                )
                return state, *loop_vars

            variables = [ops.convert_to_tensor(v) for v in self.variables]
            trainable_variables = [
                ops.convert_to_tensor(v) for v in model_trainable_variables
            ]
            non_trainable_variables = [
                ops.convert_to_tensor(v) for v in model_non_trainable_variables
            ]
            state = (
                variables,
                trainable_variables,
                non_trainable_variables,
            )
            state, *loop_vars = ops.while_loop(
                cond=stateless_cond,
                body=stateless_body,
                loop_vars=(state, *loop_vars),
                maximum_iterations=maximum_iterations,
            )
            # Write the final sampler state back to the real variables.
            for ref_v, v in zip(self.variables, state[0]):
                ref_v.assign(v)
        else:
            loop_vars = ops.while_loop(
                cond=cond,
                body=body,
                loop_vars=(loop_vars),
                maximum_iterations=maximum_iterations,
            )
        return loop_vars

    def get_next_token(self, probabilities):
        """Get the next token based on a probability distribution over tokens.

        Subclasses must implement this method.

        Args:
            probabilities: a Tensor, the probability distribution for next
                token over all vocab tokens.
        """
        raise NotImplementedError

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def get_config(self):
        return {"temperature": self.temperature}
@@ -0,0 +1,97 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+
17
+ from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.samplers.beam_sampler import BeamSampler
19
+ from keras_hub.src.samplers.contrastive_sampler import ContrastiveSampler
20
+ from keras_hub.src.samplers.greedy_sampler import GreedySampler
21
+ from keras_hub.src.samplers.random_sampler import RandomSampler
22
+ from keras_hub.src.samplers.top_k_sampler import TopKSampler
23
+ from keras_hub.src.samplers.top_p_sampler import TopPSampler
24
+
25
+
26
@keras_hub_export("keras_hub.samplers.serialize")
def serialize(sampler):
    """Serialize `sampler` with Keras' object serialization.

    Args:
        sampler: the sampler object to serialize.

    Returns:
        Whatever `keras.saving.serialize_keras_object` produces for `sampler`.
    """
    serialized = keras.saving.serialize_keras_object(sampler)
    return serialized
29
+
30
+
31
@keras_hub_export("keras_hub.samplers.deserialize")
def deserialize(config, custom_objects=None):
    """Return a `Sampler` object from its config.

    Args:
        config: a serialized sampler config, or one of the short built-in
            sampler names listed below.
        custom_objects: optional mapping of extra names to classes, passed
            through to `keras.saving.deserialize_keras_object`.
    """
    # Short aliases for the built-in samplers (e.g. `"top_k"`).
    builtin_samplers = dict(
        beam=BeamSampler,
        contrastive=ContrastiveSampler,
        greedy=GreedySampler,
        random=RandomSampler,
        top_k=TopKSampler,
        top_p=TopPSampler,
    )
    return keras.saving.deserialize_keras_object(
        config,
        module_objects=builtin_samplers,
        custom_objects=custom_objects,
        printable_module_name="samplers",
    )
48
+
49
+
50
@keras_hub_export("keras_hub.samplers.get")
def get(identifier):
    """Retrieve a KerasHub sampler by the identifier.

    The `identifier` may be the lowercase string name of a built-in sampler.

    >>> identifier = 'greedy'
    >>> sampler = keras_hub.samplers.get(identifier)

    You can also specify `config` of the sampler to this function by passing
    dict containing `class_name` and `config` as an identifier. Also note that
    the `class_name` must map to a `Sampler` class.

    >>> cfg = {'class_name': 'keras_hub>GreedySampler', 'config': {}}
    >>> sampler = keras_hub.samplers.get(cfg)

    Any other callable (for example, an existing sampler instance) is returned
    unchanged. Note that a class is also callable, so passing a class returns
    the class itself, not a new instance.

    Args:
        identifier: `None`, a lowercase string sampler name, a config dict,
            or a callable.

    Returns:
        `None` if `identifier` is `None`; otherwise a sampler deserialized
        from the string/dict, or the callable itself.

    Raises:
        KeyError: If `identifier` is a string that is not all lowercase.
        ValueError: If `identifier` is not one of the supported types.
    """

    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    elif isinstance(identifier, str):
        # Only lowercase aliases (e.g. "top_k") are accepted as strings.
        if not identifier.islower():
            raise KeyError(
                "`keras_hub.samplers.get()` must take a lowercase string "
                f"identifier, but received: {identifier}."
            )
        return deserialize(identifier)
    elif callable(identifier):
        return identifier
    else:
        raise ValueError(
            "Could not interpret sampler identifier: " + str(identifier)
        )
@@ -0,0 +1,92 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras import ops
16
+ from keras import random
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.samplers.sampler import Sampler
20
+
21
+
22
@keras_hub_export("keras_hub.samplers.TopKSampler")
class TopKSampler(Sampler):
    """Top-K Sampler class.

    This sampler implements top-k search algorithm. Briefly, top-k algorithm
    randomly selects a token from the tokens of top K probability, with
    selection chance determined by the probability.

    Args:
        k: int, the `k` value of top-k.
        seed: int. The random seed. Defaults to `None`.

    Call arguments:
        {{call_args}}

    Examples:
    ```python
    causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Pass by name to compile.
    causal_lm.compile(sampler="top_k")
    causal_lm.generate(["Keras is a"])

    # Pass by object to compile.
    sampler = keras_hub.samplers.TopKSampler(k=5, temperature=0.7)
    causal_lm.compile(sampler=sampler)
    causal_lm.generate(["Keras is a"])
    ```
    """

    def __init__(self, k=5, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.k = k
        self.seed = seed
        # Assigned after `super().__init__()` so the base class tracks it.
        self.seed_generator = random.SeedGenerator(seed)

    def get_next_token(self, probabilities):
        # Keep only the `k` most likely candidates along the last axis.
        candidate_probs, candidate_ids = ops.top_k(
            probabilities, k=self.k, sorted=False
        )
        # `random.categorical` expects logits; tf does not support half
        # precision multinomial sampling, so force full precision.
        candidate_logits = ops.cast(ops.log(candidate_probs), "float32")
        chosen = random.categorical(
            candidate_logits,
            1,
            seed=self.seed_generator,
            dtype="int32",
        )
        # Map positions within the top-k set back to vocabulary ids.
        next_token = ops.take_along_axis(candidate_ids, chosen, axis=-1)
        return ops.squeeze(next_token, axis=-1)

    def get_config(self):
        return {**super().get_config(), "k": self.k, "seed": self.seed}
@@ -0,0 +1,113 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras import ops
16
+ from keras import random
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.samplers.sampler import Sampler
20
+
21
+
22
+ @keras_hub_export("keras_hub.samplers.TopPSampler")
23
+ class TopPSampler(Sampler):
24
+ """Top-P Sampler class.
25
+
26
+ This sampler implements top-p search algorithm. Top-p search selects tokens
27
+ from the smallest subset of output probabilities that sum to greater than
28
+ `p`. Put in another way, top-p will first order token predictions by
29
+ likelihood, and ignore all tokens after the cumulative probability of
30
+ selected tokens exceeds `p`, then select a token from the remaining tokens.
31
+
32
+ Args:
33
+ p: float, the `p` value of top-p.
34
+ k: int. If set, this argument defines a
35
+ heuristic "top-k" cutoff applied before the "top-p" sampling. All
36
+ logits not in the top `k` will be discarded, and the remaining
37
+ logits will be sorted to find a cutoff point for `p`. Setting this
38
+ arg can significantly speed sampling up by reducing the number
39
+ of tokens to sort. Defaults to `None`.
40
+ seed: int. The random seed. Defaults to `None`.
41
+
42
+ Call arguments:
43
+ {{call_args}}
44
+
45
+ Examples:
46
+ ```python
47
+ causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
48
+
49
+ # Pass by name to compile.
50
+ causal_lm.compile(sampler="top_p")
51
+ causal_lm.generate(["Keras is a"])
52
+
53
+ # Pass by object to compile.
54
+ sampler = keras_hub.samplers.TopPSampler(p=0.1, k=1_000)
55
+ causal_lm.compile(sampler=sampler)
56
+ causal_lm.generate(["Keras is a"])
57
+ ```
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ p=0.1,
63
+ k=None,
64
+ seed=None,
65
+ **kwargs,
66
+ ):
67
+ super().__init__(**kwargs)
68
+ self.p = p
69
+ self.k = k
70
+ self.seed = seed
71
+ self.seed_generator = random.SeedGenerator(seed)
72
+
73
+ def get_next_token(self, probabilities):
74
+ cutoff = ops.shape(probabilities)[1]
75
+ if self.k is not None:
76
+ # If `k` is set, only sample from top `k` tokens.
77
+ cutoff = self.k
78
+ sorted_preds, sorted_indices = ops.top_k(
79
+ probabilities, k=cutoff, sorted=True
80
+ )
81
+ # Calculate cumulative probability distribution.
82
+ cumulative_probabilities = ops.cumsum(sorted_preds, axis=-1)
83
+ # Create a mask for the tokens to keep.
84
+ keep_mask = cumulative_probabilities <= self.p
85
+ # Shift to include the last token that exceed p.
86
+ shifted_keep_mask = ops.concatenate(
87
+ [ops.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
88
+ )
89
+ # Filter out unmasked tokens and sample from filtered distribution.
90
+ probabilities = ops.where(
91
+ shifted_keep_mask,
92
+ sorted_preds,
93
+ ops.zeros(ops.shape(sorted_preds), dtype=sorted_preds.dtype),
94
+ )
95
+ sorted_next_token = random.categorical(
96
+ ops.log(probabilities),
97
+ 1,
98
+ seed=self.seed_generator,
99
+ dtype="int32",
100
+ )
101
+ output = ops.take_along_axis(sorted_indices, sorted_next_token, axis=-1)
102
+ return ops.squeeze(output, axis=-1)
103
+
104
+ def get_config(self):
105
+ config = super().get_config()
106
+ config.update(
107
+ {
108
+ "p": self.p,
109
+ "k": self.k,
110
+ "seed": self.seed,
111
+ }
112
+ )
113
+ return config
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.