keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,383 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import itertools
16
+ from functools import partial
17
+
18
+ import keras
19
+ from keras import ops
20
+ from keras import tree
21
+
22
+ from keras_hub.src.api_export import keras_hub_export
23
+ from keras_hub.src.models.task import Task
24
+ from keras_hub.src.samplers.serialization import get as get_sampler
25
+ from keras_hub.src.utils.tensor_utils import tensor_to_list
26
+
27
+ try:
28
+ import tensorflow as tf
29
+ except ImportError:
30
+ tf = None
31
+
32
+
33
@keras_hub_export("keras_hub.models.CausalLM")
class CausalLM(Task):
    """Base class for generative language modeling tasks.

    `CausalLM` tasks wrap a `keras_hub.models.Backbone` and
    a `keras_hub.models.Preprocessor` to create a model that can be used for
    generation and generative fine-tuning.

    `CausalLM` tasks provide an additional, high-level `generate()` function
    which can be used to auto-regressively sample a model token by token with a
    string in, string out signature. The `compile()` method of all `CausalLM`
    classes contains an additional `sampler` argument, which can be used to pass
    a `keras_hub.samplers.Sampler` to control how the predicted distribution
    will be sampled.

    When calling `fit()`, the tokenized input will be predicted token-by-token
    with a causal mask applied, which gives both a pre-training and supervised
    fine-tuning setup for controlling inference-time generation.

    All `CausalLM` tasks include a `from_preset()` constructor which can be used
    to load a pre-trained config and weights.

    Example:
    ```python
    # Load a GPT2 backbone with pre-trained weights.
    causal_lm = keras_hub.models.CausalLM.from_preset(
        "gpt2_base_en",
    )
    causal_lm.compile(sampler="top_k")
    causal_lm.generate("Keras is a", max_length=64)

    # Load a Mistral instruction tuned checkpoint at bfloat16 precision.
    causal_lm = keras_hub.models.CausalLM.from_preset(
        "mistral_instruct_7b_en",
        dtype="bfloat16",
    )
    causal_lm.compile(sampler="greedy")
    causal_lm.generate("Keras is a", max_length=64)
    ```
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Default compilation, so `fit()` and `generate()` work without an
        # explicit `compile()` call.
        self.compile()

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        weighted_metrics="auto",
        sampler="top_k",
        **kwargs,
    ):
        """Configures the `CausalLM` task for training and generation.

        The `CausalLM` task extends the default compilation signature of
        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
        `weighted_metrics`. To override these defaults, pass any value
        to these arguments during compilation.

        The `CausalLM` task adds a new `sampler` to `compile`, which can be used
        to control the sampling strategy used with the `generate` function.

        Note that because training inputs include padded tokens which are
        excluded from the loss, it is almost always a good idea to compile with
        `weighted_metrics` and not `metrics`.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses the default optimizer
                for the given model and task. See `keras.Model.compile` and
                `keras.optimizers` for more info on possible `optimizer` values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, where a
                `keras.losses.SparseCategoricalCrossentropy` loss will be
                applied for the token classification `CausalLM` task. See
                `keras.Model.compile` and `keras.losses` for more info on
                possible `loss` values.
            weighted_metrics: `"auto"`, or a list of metrics to be evaluated by
                the model during training and testing. Defaults to `"auto"`,
                where a `keras.metrics.SparseCategoricalAccuracy` will be
                applied to track the accuracy of the model at predicting
                token values. See `keras.Model.compile` and `keras.metrics` for
                more info on possible `weighted_metrics` values.
            sampler: A sampler name, or a `keras_hub.samplers.Sampler` instance.
                Configures the sampling method used during `generate()` calls.
                See `keras_hub.samplers` for a full list of built-in sampling
                strategies.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        if optimizer == "auto":
            optimizer = keras.optimizers.Adam(2e-5)
        if loss == "auto":
            loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        if weighted_metrics == "auto":
            weighted_metrics = [keras.metrics.SparseCategoricalAccuracy()]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            weighted_metrics=weighted_metrics,
            **kwargs,
        )
        self.sampler = get_sampler(sampler)
        # Clear the compiled generate function so that the next call to
        # `make_generate_function()` rebuilds it with the new sampler.
        self.generate_function = None

    def generate_step(self):
        """Run generation on a single batch of input.

        Subclasses must override this method. The functions built by
        `make_generate_function()` invoke the override as
        `generate_step(inputs, stop_token_ids)`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def make_generate_function(self):
        """Create or return the compiled generation function.

        The function is cached on `self.generate_function`; `compile()` resets
        the cache. The wrapping applied to `generate_step` depends on the
        active Keras backend.

        Returns:
            A callable taking `(inputs, stop_token_ids=None)`.
        """
        if self.generate_function is not None:
            return self.generate_function

        # Eager fallback, used when no backend branch below applies (for
        # example when `run_eagerly` is set on TensorFlow or JAX).
        self.generate_function = self.generate_step
        if keras.config.backend() == "torch":
            import torch

            def wrapped_generate_function(
                inputs,
                stop_token_ids=None,
            ):
                # Generation never needs gradients.
                with torch.no_grad():
                    return self.generate_step(inputs, stop_token_ids)

            self.generate_function = wrapped_generate_function
        elif keras.config.backend() == "tensorflow" and not self.run_eagerly:
            # `jit_compile` is a property of keras.Model after TF 2.12.
            # Use `getattr()` for backwards compatibility.
            jit_compile = getattr(self, "jit_compile", True)
            self.generate_function = tf.function(
                self.generate_step, jit_compile=jit_compile
            )
        elif keras.config.backend() == "jax" and not self.run_eagerly:
            import jax

            @partial(jax.jit, static_argnames=["stop_token_ids"])
            def compiled_generate_function(inputs, stop_token_ids, state):
                (
                    sampler_variables,
                    trainable_variables,
                    non_trainable_variables,
                ) = state
                mapping = itertools.chain(
                    zip(self.sampler.variables, sampler_variables),
                    zip(self.trainable_variables, trainable_variables),
                    zip(self.non_trainable_variables, non_trainable_variables),
                )

                with keras.StatelessScope(state_mapping=mapping) as scope:
                    outputs = self.generate_step(inputs, stop_token_ids)

                # Get updated sampler variables from the stateless scope.
                sampler_variables = []
                for v in self.sampler.variables:
                    new_v = scope.get_current_value(v)
                    sampler_variables.append(new_v if new_v is not None else v)
                return outputs, sampler_variables

            def wrapped_generate_function(
                inputs,
                stop_token_ids=None,
            ):
                # `stop_token_ids` is a static (hashed) jit argument, so a
                # list must be converted to a tuple first.
                if isinstance(stop_token_ids, list):
                    stop_token_ids = tuple(stop_token_ids)

                # Create an explicit tuple of all variable state.
                state = (
                    self.sampler.variables,
                    # Use the explicit variable.value to preserve the
                    # sharding spec of distribution.
                    [v.value for v in self.trainable_variables],
                    [v.value for v in self.non_trainable_variables],
                )
                inputs = tree.map_structure(ops.convert_to_tensor, inputs)
                outputs, sampler_variables = compiled_generate_function(
                    inputs,
                    stop_token_ids,
                    state,
                )
                # Only assign the sampler variables (random seeds), as other
                # model variables should never be updated in generation.
                for ref_v, v in zip(self.sampler.variables, sampler_variables):
                    ref_v.assign(v)
                return outputs

            self.generate_function = wrapped_generate_function

        return self.generate_function

    def _normalize_generate_inputs(
        self,
        inputs,
    ):
        """Normalize user input to the generate function.

        This function converts string and list inputs to tensors, adds a batch
        dimension if necessary, and returns an iterable "dataset like" object
        (either an actual `tf.data.Dataset` or a list with a single batch
        element).

        Args:
            inputs: A `tf.data.Dataset`, a dict of inputs, or a single input
                (string, list, or tensor/array data).

        Returns:
            A tuple `(inputs, input_is_scalar)`. `input_is_scalar` is `True`
            when a batch dimension was added to a rank-0 input.

        Raises:
            ImportError: If a string or list input is passed but TensorFlow is
                not installed.
        """
        input_is_scalar = False

        # `tf` is `None` when TensorFlow is not installed; guard every
        # attribute access so non-TensorFlow inputs keep working.
        if tf is not None and isinstance(inputs, tf.data.Dataset):
            return inputs, input_is_scalar

        def normalize(x):
            x_is_scalar = False
            if isinstance(x, (str, list)):
                if tf is None:
                    raise ImportError(
                        "Passing python strings or lists to `generate()` "
                        "requires TensorFlow to be installed."
                    )
                x = tf.convert_to_tensor(x)

            if tf is not None and isinstance(x, tf.Tensor):
                if x.shape.rank == 0:
                    x_is_scalar = True
                    x = x[tf.newaxis]

            return x, x_is_scalar

        if isinstance(inputs, dict):
            # Build a new dict rather than overwriting the caller's entries.
            # NOTE: the scalar flag reflects the last key processed.
            normalized = {}
            for key in inputs:
                normalized[key], input_is_scalar = normalize(inputs[key])
            inputs = normalized
        else:
            inputs, input_is_scalar = normalize(inputs)

        # We avoid converting to a dataset purely for speed, for a single batch
        # of input, creating a dataset would add significant overhead.
        return [inputs], input_is_scalar

    def _normalize_generate_outputs(
        self,
        outputs,
        input_is_scalar,
    ):
        """Normalize user output from the generate function.

        This function converts all output to numpy (for integer output), or
        python strings (for string output). If a batch dimension was added to
        the input, it is removed from the output (so generate can be string in,
        string out).

        Args:
            outputs: A list with one entry per generated batch. Each entry is
                either a dict of outputs or a single output.
            input_is_scalar: Whether a batch dimension was added to the input
                and should be stripped from the output.

        Returns:
            The concatenated outputs of all batches, or a dict of such values
            when the per-batch outputs are dicts.
        """

        def normalize(x):
            if isinstance(x[0], list):
                # Flatten a list of per-batch python lists into one list.
                flat = [e for batch in x for e in batch]
                return flat[0] if input_is_scalar else flat
            # `tf` is `None` when TensorFlow is not installed.
            if (
                tf is not None
                and isinstance(x[0], tf.Tensor)
                and x[0].dtype == tf.string
            ):
                merged = tf.concat(x, axis=0)
                merged = tf.squeeze(merged, 0) if input_is_scalar else merged
                return tensor_to_list(merged)
            merged = ops.concatenate(x, axis=0)
            merged = ops.squeeze(merged, 0) if input_is_scalar else merged
            return ops.convert_to_numpy(merged)

        if isinstance(outputs[0], dict):
            normalized = {}
            for key in outputs[0]:
                normalized[key] = normalize([x[key] for x in outputs])
            return normalized
        return normalize(list(outputs))

    def generate(
        self,
        inputs,
        max_length=None,
        stop_token_ids="auto",
    ):
        """Generate text given prompt `inputs`.

        This method generates text based on given `inputs`. The sampling method
        used for generation can be set via the `compile()` method.

        If `inputs` are a `tf.data.Dataset`, outputs will be generated
        "batch-by-batch" and concatenated. Otherwise, all inputs will be handled
        as a single batch.

        If a `preprocessor` is attached to the model, `inputs` will be
        preprocessed inside the `generate()` function and should match the
        structure expected by the `preprocessor` layer (usually raw strings).
        If a `preprocessor` is not attached, inputs should match the structure
        expected by the `backbone`. See the example usage above for a
        demonstration of each.

        Args:
            inputs: python data, tensor data, or a `tf.data.Dataset`. If a
                `preprocessor` is attached to the model, `inputs` should match
                the structure expected by the `preprocessor` layer. If a
                `preprocessor` is not attached, `inputs` should match the
                structure expected by the `backbone` model.
            max_length: Optional. int. The max length of the generated sequence.
                Will default to the max configured `sequence_length` of the
                `preprocessor`. If `preprocessor` is `None`, `inputs` should be
                padded to the desired maximum length and this argument
                will be ignored.
            stop_token_ids: Optional. `None`, "auto", or tuple of token ids.
                Defaults to "auto" which uses the
                `preprocessor.tokenizer.end_token_id`. Not specifying a
                preprocessor will produce an error. None stops generation after
                generating `max_length` tokens. You may also specify a list of
                token ids the model should stop on. Note that sequences of
                tokens will each be interpreted as a stop token, multi-token
                stop sequences are not supported.

        Returns:
            The generated output, normalized by
            `_normalize_generate_outputs()`.

        Raises:
            ValueError: If `stop_token_ids="auto"` and no `preprocessor` is
                attached to the model.
        """
        # Setup our three main passes.
        # 1. Optionally preprocessing strings to dense integer tensors.
        # 2. Generate new tokens via a compiled function on dense tensors.
        # 3. Optionally postprocess dense integer tensors back to string.
        generate_function = self.make_generate_function()

        if self.preprocessor is None and stop_token_ids == "auto":
            raise ValueError(
                'A `preprocessor` must be attached to the model if `stop_token_ids="auto"`. '
                "Currently `preprocessor=None`. To call `generate()` with preprocessing "
                "detached, either pass `stop_token_ids=None` to always generate until "
                "`max_length` or pass a tuple of token ids that should terminate generation "
                "as `stop_token_ids`."
            )
        elif stop_token_ids == "auto":
            stop_token_ids = [self.preprocessor.tokenizer.end_token_id]

        def preprocess(x):
            return self.preprocessor.generate_preprocess(
                x, sequence_length=max_length
            )

        def generate(x):
            return generate_function(x, stop_token_ids=stop_token_ids)

        def postprocess(x):
            return self.preprocessor.generate_postprocess(x)

        # Normalize inputs, apply our three passes, and normalize outputs.
        inputs, input_is_scalar = self._normalize_generate_inputs(inputs)

        if self.preprocessor is not None:
            # `tf` is `None` when TensorFlow is not installed.
            if tf is not None and isinstance(inputs, tf.data.Dataset):
                inputs = inputs.map(preprocess, tf.data.AUTOTUNE)
                inputs = inputs.prefetch(tf.data.AUTOTUNE)
            else:
                # Fast path for non-dataset, single-batch input.
                inputs = [preprocess(x) for x in inputs]

        outputs = [generate(x) for x in inputs]

        if self.preprocessor is not None:
            outputs = [postprocess(x) for x in outputs]

        return self._normalize_generate_outputs(outputs, input_is_scalar)
@@ -0,0 +1,109 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+
16
+ from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.models.task import Task
18
+
19
+
20
@keras_hub_export("keras_hub.models.Classifier")
class Classifier(Task):
    """Base class for all classification tasks.

    A `Classifier` task pairs a `keras_hub.models.Backbone` with a
    `keras_hub.models.Preprocessor`, yielding a model usable for sequence
    classification. On top of the usual task arguments, `Classifier` tasks
    accept a `num_classes` argument that sets how many output classes are
    predicted.

    To fine-tune with `fit()`, pass a dataset of `(x, y)` tuples where `x`
    is a string and `y` is an integer label in `[0, num_classes)`.

    Every `Classifier` task provides a `from_preset()` constructor for
    loading a pre-trained config and weights.

    Example:
    ```python
    # Load a BERT classifier with pre-trained weights.
    classifier = keras_hub.models.Classifier.from_preset(
        "bert_base_en",
        num_classes=2,
    )
    # Fine-tune on IMDb movie reviews (or any dataset).
    imdb_train, imdb_test = tfds.load(
        "imdb_reviews",
        split=["train", "test"],
        as_supervised=True,
        batch_size=16,
    )
    classifier.fit(imdb_train, validation_data=imdb_test)
    # Predict two new examples.
    classifier.predict(["What an amazing movie!", "A total waste of my time."])
    ```
    """

    def __init__(self, *args, **kwargs):
        """Build the task, then compile it with the `"auto"` defaults."""
        super().__init__(*args, **kwargs)
        # Compile right away so a freshly constructed task is already
        # configured with the default optimizer, loss and metrics.
        self.compile()

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        metrics="auto",
        **kwargs,
    ):
        """Configures the `Classifier` task for training.

        This extends the `keras.Model.compile` signature by supplying
        defaults for `optimizer`, `loss`, and `metrics`. Passing any value
        other than `"auto"` for one of these arguments overrides the
        corresponding default.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which creates
                `keras.optimizers.Adam(5e-5)`. See `keras.Model.compile`
                and `keras.optimizers` for more info on possible
                `optimizer` values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, which creates a
                `keras.losses.SparseCategoricalCrossentropy` loss. The loss
                is told to expect logits unless this task's `activation`
                attribute resolves to softmax. See `keras.Model.compile`
                and `keras.losses` for more info on possible `loss` values.
            metrics: `"auto"`, or a list of metrics to be evaluated by
                the model during training and testing. Defaults to
                `"auto"`, which tracks a single
                `keras.metrics.SparseCategoricalAccuracy`. See
                `keras.Model.compile` and `keras.metrics` for more info on
                possible `metrics` values.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        optimizer = (
            keras.optimizers.Adam(5e-5) if optimizer == "auto" else optimizer
        )

        if loss == "auto":
            # A missing `activation` attribute is looked up as `None`, which
            # does not compare equal to softmax, so the outputs are treated
            # as logits in that case.
            resolved_activation = keras.activations.get(
                getattr(self, "activation", None)
            )
            loss = keras.losses.SparseCategoricalCrossentropy(
                from_logits=resolved_activation != keras.activations.softmax
            )

        metrics = (
            [keras.metrics.SparseCategoricalAccuracy()]
            if metrics == "auto"
            else metrics
        )

        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            **kwargs,
        )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.