keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,419 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import keras
18
+ from rich import console as rich_console
19
+ from rich import markup
20
+ from rich import table as rich_table
21
+
22
+ from keras_hub.src.api_export import keras_hub_export
23
+ from keras_hub.src.utils.keras_utils import print_msg
24
+ from keras_hub.src.utils.pipeline_model import PipelineModel
25
+ from keras_hub.src.utils.preset_utils import CONFIG_FILE
26
+ from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
27
+ from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE
28
+ from keras_hub.src.utils.preset_utils import TASK_WEIGHTS_FILE
29
+ from keras_hub.src.utils.preset_utils import check_config_class
30
+ from keras_hub.src.utils.preset_utils import check_file_exists
31
+ from keras_hub.src.utils.preset_utils import check_format
32
+ from keras_hub.src.utils.preset_utils import get_file
33
+ from keras_hub.src.utils.preset_utils import jax_memory_cleanup
34
+ from keras_hub.src.utils.preset_utils import list_presets
35
+ from keras_hub.src.utils.preset_utils import list_subclasses
36
+ from keras_hub.src.utils.preset_utils import load_serialized_object
37
+ from keras_hub.src.utils.preset_utils import save_serialized_object
38
+ from keras_hub.src.utils.python_utils import classproperty
39
+
40
+
41
@keras_hub_export("keras_hub.models.Task")
class Task(PipelineModel):
    """Base class for all Task models.

    A `Task` wraps a `keras_hub.models.Backbone` and
    a `keras_hub.models.Preprocessor` to create a model that can be directly
    used for training, fine-tuning, and prediction for a given text problem.

    All `Task` models have `backbone` and `preprocessor` properties. By
    default `fit()`, `predict()` and `evaluate()` will preprocess all inputs
    automatically. To preprocess inputs separately or with a custom function,
    you can set `task.preprocessor = None`, which disables any automatic
    preprocessing on inputs.

    All `Task` classes include a `from_preset()` constructor which can be used
    to load a pre-trained config and weights. Calling `from_preset()` on a task
    will automatically instantiate a `keras_hub.models.Backbone` and
    `keras_hub.models.Preprocessor`.
    """

    # Subclasses set these to the backbone / preprocessor classes they pair
    # with. `from_preset()` and `presets` consult them.
    backbone_cls = None
    preprocessor_cls = None

    def __init__(self, *args, **kwargs):
        """Build the task and record its post-construction state.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Snapshot the ids of every layer present right after construction.
        self._functional_layer_ids = {
            id(layer) for layer in self._flatten_layers()
        }
        # `__setattr__` checks for this marker to detect that construction
        # has finished.
        self._initialized = True
        if self.backbone is not None:
            self.dtype_policy = self._backbone.dtype_policy

    def preprocess_samples(self, x, y=None, sample_weight=None):
        """Preprocess one `(x, y, sample_weight)` sample.

        Uses `self.preprocessor` when one is set, otherwise defers to the
        parent class implementation.
        """
        if self.preprocessor is not None:
            return self.preprocessor(x, y=y, sample_weight=sample_weight)
        else:
            return super().preprocess_samples(x, y, sample_weight)

    def __setattr__(self, name, value):
        """Set an attribute, bypassing custom `__setattr__` where needed."""
        # Work around setattr issues for Keras 2 and Keras 3 torch backend.
        # Since all our state is covered by functional model we can route
        # around custom setattr calls.
        is_property = isinstance(getattr(type(self), name, None), property)
        is_unitialized = not hasattr(self, "_initialized")
        is_torch = keras.config.backend() == "torch"
        if is_torch and (is_property or is_unitialized):
            return object.__setattr__(self, name, value)
        return super().__setattr__(name, value)

    @property
    def backbone(self):
        """A `keras_hub.models.Backbone` model with the core architecture."""
        return getattr(self, "_backbone", None)

    @backbone.setter
    def backbone(self, value):
        self._backbone = value

    @property
    def preprocessor(self):
        """A `keras_hub.models.Preprocessor` layer used to preprocess input."""
        return getattr(self, "_preprocessor", None)

    @preprocessor.setter
    def preprocessor(self, value):
        self._preprocessor = value

    def get_config(self):
        """Return a flat config with the serialized backbone and preprocessor.

        Returns:
            A dict with the keys `"backbone"`, `"preprocessor"` and `"name"`.
        """
        # Don't chain to super here. The default `get_config()` for functional
        # models is nested and cannot be passed to our Task constructors.
        return {
            "backbone": keras.layers.serialize(self.backbone),
            "preprocessor": keras.layers.serialize(self.preprocessor),
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        """Create a task instance from a config dict.

        Args:
            config: dict. `"backbone"` and `"preprocessor"` entries that are
                dicts are deserialized into objects; all entries are then
                passed as keyword arguments to the class constructor.

        Returns:
            An instance of `cls`.
        """
        # The default `from_config()` for functional models will return a
        # vanilla `keras.Model`. We override it to get a subclass instance back.
        # Work on a shallow copy so the caller's dict is not modified.
        config = dict(config)
        if "backbone" in config and isinstance(config["backbone"], dict):
            config["backbone"] = keras.layers.deserialize(config["backbone"])
        if "preprocessor" in config and isinstance(
            config["preprocessor"], dict
        ):
            config["preprocessor"] = keras.layers.deserialize(
                config["preprocessor"]
            )
        return cls(**config)

    @classproperty
    def presets(cls):
        """List built-in presets for a `Task` subclass."""
        presets = list_presets(cls)
        # We can also load backbone presets.
        if cls.backbone_cls is not None:
            presets.update(cls.backbone_cls.presets)
        for subclass in list_subclasses(cls):
            presets.update(subclass.presets)
        return presets

    @classmethod
    def from_preset(
        cls,
        preset,
        load_weights=True,
        **kwargs,
    ):
        """Instantiate a `keras_hub.models.Task` from a model preset.

        A preset is a directory of configs, weights and other file assets used
        to save and load a pre-trained model. The `preset` can be passed as
        one of:

        1. a built in preset identifier like `'bert_base_en'`
        2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
        3. a Hugging Face handle like `'hf://user/bert_base_en'`
        4. a path to a local preset directory like `'./bert_base_en'`

        For any `Task` subclass, you can run `cls.presets.keys()` to list all
        built-in presets available on the class.

        This constructor can be called in one of two ways. Either from a task
        specific base class like `keras_hub.models.CausalLM.from_preset()`, or
        from a model class like `keras_hub.models.BertClassifier.from_preset()`.
        If calling from a base class, the subclass of the returning object
        will be inferred from the config in the preset directory.

        Args:
            preset: string. A built in preset identifier, a Kaggle Models
                handle, a Hugging Face handle, or a path to a local directory.
            load_weights: bool. If `True`, the weights will be loaded into the
                model architecture. If `False`, the weights will be randomly
                initialized.
            **kwargs: extra constructor arguments. When the task is built
                from the backbone config, a `dtype` entry is forwarded to
                the backbone and a `preprocessor` entry is used instead of
                loading one from the preset. A `backbone` entry is rejected.

        Raises:
            ValueError: if called on `Task` itself, if `backbone` is passed
                in `kwargs`, if the required backbone or preprocessor class
                is unset for a `"transformers"` preset, or if zero or several
                subclasses match the backbone class found in the preset.

        Examples:
        ```python
        # Load a Gemma generative task.
        causal_lm = keras_hub.models.CausalLM.from_preset(
            "gemma_2b_en",
        )

        # Load a Bert classification task.
        model = keras_hub.models.Classifier.from_preset(
            "bert_base_en",
            num_classes=2,
        )
        ```
        """
        # Named `preset_format` to avoid shadowing the `format` builtin.
        preset_format = check_format(preset)

        if preset_format == "transformers":
            if cls.backbone_cls is None:
                raise ValueError("Backbone class is None")
            if cls.preprocessor_cls is None:
                raise ValueError("Preprocessor class is None")

            # NOTE(review): `load_weights` is not forwarded on this path;
            # confirm that is intended.
            backbone = cls.backbone_cls.from_preset(preset)
            preprocessor = cls.preprocessor_cls.from_preset(preset)
            return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

        if cls == Task:
            raise ValueError(
                "Do not call `Task.from_preset()` directly. Instead call a "
                "particular task class, e.g. "
                "`keras_hub.models.Classifier.from_preset()` or "
                "`keras_hub.models.BertClassifier.from_preset()`."
            )
        if "backbone" in kwargs:
            raise ValueError(
                "You cannot pass a `backbone` argument to the `from_preset` "
                f"method. Instead, call the {cls.__name__} default "
                "constructor with a `backbone` argument. "
                f"Received: backbone={kwargs['backbone']}."
            )

        # Check if we should load a `task.json` directly.
        load_task_config = False
        if check_file_exists(preset, TASK_CONFIG_FILE):
            task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
            if issubclass(task_preset_cls, cls):
                load_task_config = True
        if load_task_config:
            # Task case.
            task = load_serialized_object(preset, TASK_CONFIG_FILE)
            if load_weights:
                jax_memory_cleanup(task)
                # Task-specific weights are optional; backbone weights are
                # always loaded when `load_weights` is set.
                if check_file_exists(preset, TASK_WEIGHTS_FILE):
                    task.load_task_weights(get_file(preset, TASK_WEIGHTS_FILE))
                task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
            task.preprocessor.tokenizer.load_preset_assets(preset)
            return task

        # Backbone case.
        # If `task.json` doesn't exist or the task preset class is different
        # from the calling class, create the task based on `config.json`.
        backbone_preset_cls = check_config_class(preset, CONFIG_FILE)
        if backbone_preset_cls is not cls.backbone_cls:
            # Find the single subclass whose backbone matches the preset.
            subclasses = list_subclasses(cls)
            subclasses = tuple(
                filter(
                    lambda x: x.backbone_cls == backbone_preset_cls,
                    subclasses,
                )
            )
            if len(subclasses) == 0:
                raise ValueError(
                    f"No registered subclass of `{cls.__name__}` can load "
                    f"a `{backbone_preset_cls.__name__}`."
                )
            if len(subclasses) > 1:
                names = ", ".join(f"`{x.__name__}`" for x in subclasses)
                raise ValueError(
                    f"Ambiguous call to `{cls.__name__}.from_preset()`. "
                    f"Found multiple possible subclasses {names}. "
                    "Please call `from_preset` on a subclass directly."
                )
            cls = subclasses[0]
        # Forward dtype to the backbone.
        backbone_kwargs = {}
        if "dtype" in kwargs:
            backbone_kwargs = {"dtype": kwargs.pop("dtype")}
        backbone = backbone_preset_cls.from_preset(
            preset, load_weights=load_weights, **backbone_kwargs
        )
        if "preprocessor" in kwargs:
            preprocessor = kwargs.pop("preprocessor")
        else:
            preprocessor = cls.preprocessor_cls.from_preset(preset)
        return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

    def load_task_weights(self, filepath):
        """Load only the task-specific weights not in the backbone.

        Args:
            filepath: path of the weights file. Must end in `.weights.h5`.

        Raises:
            ValueError: if `filepath` does not end in `.weights.h5`.
        """
        if not str(filepath).endswith(".weights.h5"):
            # The second fragment must be an f-string so the offending path
            # is actually interpolated into the message.
            raise ValueError(
                "The filename must end in `.weights.h5`. "
                f"Received: filepath={filepath}"
            )
        # Backbone layers are skipped so only the remaining layers load.
        backbone_layer_ids = {
            id(layer) for layer in self.backbone._flatten_layers()
        }
        keras.saving.load_weights(
            self,
            filepath,
            objects_to_skip=backbone_layer_ids,
        )

    def has_task_weights(self):
        """Return whether the task has any weight not held by the backbone."""
        task_weight_ids = {id(w) for w in self.weights}
        backbone_weight_ids = {id(w) for w in self.backbone.weights}
        return not task_weight_ids.issubset(backbone_weight_ids)

    def save_task_weights(self, filepath):
        """Save only the task-specific weights not in the backbone.

        Args:
            filepath: path of the weights file. Must end in `.weights.h5`.

        Raises:
            ValueError: if `filepath` does not end in `.weights.h5`, or if
                the task has no weights outside the backbone.
        """
        if not str(filepath).endswith(".weights.h5"):
            raise ValueError(
                "The filename must end in `.weights.h5`. "
                f"Received: filepath={filepath}"
            )

        backbone_layer_ids = {
            id(layer) for layer in self.backbone._flatten_layers()
        }
        if not self.has_task_weights():
            raise ValueError(
                f"Task {self} has no weights not in the `backbone`. "
                "`save_task_weights()` has nothing to save."
            )
        keras.saving.save_weights(
            self,
            filepath=filepath,
            objects_to_skip=backbone_layer_ids,
        )

    def save_to_preset(self, preset_dir):
        """Save task to a preset directory.

        Args:
            preset_dir: The path to the local model preset directory.

        Raises:
            ValueError: if the task has no preprocessor.
        """
        if self.preprocessor is None:
            raise ValueError(
                "Cannot save `task` to preset: `Preprocessor` is not initialized."
            )

        save_serialized_object(self, preset_dir, config_file=TASK_CONFIG_FILE)
        # Task weights are written only when some exist outside the backbone.
        if self.has_task_weights():
            self.save_task_weights(os.path.join(preset_dir, TASK_WEIGHTS_FILE))

        self.preprocessor.save_to_preset(preset_dir)
        self.backbone.save_to_preset(preset_dir)

    @property
    def layers(self):
        """The parent class's `layers`, with the preprocessor removed."""
        # Remove preprocessor from layers so it does not show up in the summary.
        layers = super().layers
        if self.preprocessor and self.preprocessor in layers:
            layers.remove(self.preprocessor)
        return layers

    def summary(
        self,
        line_length=None,
        positions=None,
        print_fn=None,
        **kwargs,
    ):
        """Override `model.summary()` to show a preprocessor if set.

        When a preprocessor is attached, a table with the tokenizer name,
        tokenizer class and vocabulary size is printed before the parent
        class summary.

        Args:
            line_length: total width of the printed table. A falsy value is
                replaced with 108.
            positions: forwarded unchanged to the parent `summary()`.
            print_fn: callable used for output. If unset and interactive
                logging is disabled, `print_msg` is used.
            **kwargs: forwarded unchanged to the parent `summary()`.
        """

        # Compat fixes for tf.keras.
        if not hasattr(self, "compiled"):
            self.compiled = getattr(self.optimizer, "_is_compiled", False)
        if (
            self.compiled
            and self.optimizer
            and not hasattr(self.optimizer, "built")
        ):
            self.optimizer.built = getattr(self.optimizer, "_built", False)

        # Below is copied from keras-core for now.
        # We should consider an API contract.
        line_length = line_length or 108

        if not print_fn and not keras.utils.is_interactive_logging_enabled():
            print_fn = print_msg

        def highlight_number(x):
            return f"[color(45)]{x}[/]" if x is None else f"[color(34)]{x}[/]"

        def highlight_symbol(x):
            return f"[color(33)]{x}[/]"

        def bold_text(x):
            return f"[bold]{x}[/]"

        if self.preprocessor:
            # Create a rich console for printing. Capture for non-interactive
            # logging.
            if print_fn:
                console = rich_console.Console(
                    highlight=False, force_terminal=False, color_system=None
                )
                console.begin_capture()
            else:
                console = rich_console.Console(highlight=False)

            column_1 = rich_table.Column(
                "Tokenizer (type)",
                justify="left",
                width=int(0.5 * line_length),
            )
            column_2 = rich_table.Column(
                "Vocab #",
                justify="right",
                width=int(0.5 * line_length),
            )
            table = rich_table.Table(
                column_1, column_2, width=line_length, show_lines=True
            )
            tokenizer = self.preprocessor.tokenizer
            # Escape names so they are not interpreted as rich markup.
            tokenizer_name = markup.escape(tokenizer.name)
            tokenizer_class = highlight_symbol(
                markup.escape(tokenizer.__class__.__name__)
            )
            table.add_row(
                f"{tokenizer_name} ({tokenizer_class})",
                highlight_number(f"{tokenizer.vocabulary_size():,}"),
            )

            # Print the table to the console.
            preprocessor_name = markup.escape(self.preprocessor.name)
            console.print(bold_text(f'Preprocessor: "{preprocessor_name}"'))
            console.print(table)

            # Output captured summary for non-interactive logging.
            if print_fn:
                print_fn(console.end_capture(), line_break=False)

        super().summary(
            line_length=line_length,
            positions=positions,
            print_fn=print_fn,
            **kwargs,
        )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,158 @@
1
+ # Copyright 2023 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+ from keras import layers
16
+
17
+ from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.models.backbone import Backbone
19
+
20
+
21
@keras_hub_export("keras_hub.models.VGGBackbone")
class VGGBackbone(Backbone):
    """This class represents Keras Backbone of VGG model.

    This class implements a VGG backbone as described in [Very Deep
    Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556)(ICLR 2015).

    The network is a sequence of VGG blocks. Block `i` applies
    `stackwise_num_repeats[i]` 3x3 ReLU convolutions with
    `stackwise_num_filters[i]` filters and "same" padding, followed by a
    2x2 max pooling layer. One block is built for every entry of the
    stackwise lists.

    Args:
        stackwise_num_repeats: list of ints, number of repeated convolutional
            blocks per VGG block. For VGG16 this is [2, 2, 3, 3, 3] and for
            VGG19 this is [2, 2, 4, 4, 4].
        stackwise_num_filters: list of ints, filter size for convolutional
            blocks per VGG block. For both VGG16 and VGG19 this is [
            64, 128, 256, 512, 512]. Must have the same length as
            `stackwise_num_repeats`.
        include_rescaling: bool, whether to rescale the inputs. If set to
            True, inputs will be passed through a `Rescaling(1/255.0)` layer.
        image_shape: tuple, optional shape tuple, defaults to (224, 224, 3).
        pooling: str or `None`, optional pooling mode applied to the output
            of the last convolutional block. Defaults to `"avg"`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional block.
            - `"avg"` means that global average pooling
                will be applied to the output of the
                last convolutional block, and thus
                the output of the model will be a 2D tensor.
            - `"max"` means that global max pooling will
                be applied.
            Any other value leaves the output unpooled.

    Raises:
        ValueError: if `stackwise_num_repeats` and `stackwise_num_filters`
            do not have the same length.

    Examples:
    ```python
    input_data = np.ones((2, 224, 224, 3), dtype="float32")

    # Pretrained VGG backbone.
    model = keras_hub.models.VGGBackbone.from_preset("vgg16")
    model(input_data)

    # Randomly initialized VGG backbone with a custom config.
    model = keras_hub.models.VGGBackbone(
        stackwise_num_repeats = [2, 2, 3, 3, 3],
        stackwise_num_filters = [64, 128, 256, 512, 512],
        image_shape = (224, 224, 3),
        include_rescaling = False,
        pooling = "avg",
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        stackwise_num_repeats,
        stackwise_num_filters,
        include_rescaling,
        image_shape=(224, 224, 3),
        pooling="avg",
        **kwargs,
    ):
        # Each block needs both a repeat count and a filter count; fail
        # early with a clear message instead of an IndexError (or a
        # silently truncated network) further down.
        if len(stackwise_num_repeats) != len(stackwise_num_filters):
            raise ValueError(
                "`stackwise_num_repeats` and `stackwise_num_filters` must "
                "have the same length. Received: "
                f"stackwise_num_repeats={stackwise_num_repeats}, "
                f"stackwise_num_filters={stackwise_num_filters}"
            )

        # === Functional Model ===
        img_input = keras.layers.Input(shape=image_shape)
        x = img_input

        if include_rescaling:
            x = layers.Rescaling(scale=1 / 255.0)(x)
        # Build one block per configured stack. Iterating over every entry
        # (rather than all but the last) ensures the final stack is not
        # dropped, e.g. VGG16 gets all five of its blocks.
        stacks = zip(stackwise_num_repeats, stackwise_num_filters)
        for block_number, (num_repeats, num_filters) in enumerate(
            stacks, start=1
        ):
            x = apply_vgg_block(
                x=x,
                num_layers=num_repeats,
                filters=num_filters,
                kernel_size=(3, 3),
                activation="relu",
                padding="same",
                max_pool=True,
                name=f"block{block_number}",
            )
        if pooling == "avg":
            x = layers.GlobalAveragePooling2D()(x)
        elif pooling == "max":
            x = layers.GlobalMaxPooling2D()(x)

        super().__init__(inputs=img_input, outputs=x, **kwargs)

        # === Config ===
        self.stackwise_num_repeats = stackwise_num_repeats
        self.stackwise_num_filters = stackwise_num_filters
        self.include_rescaling = include_rescaling
        self.image_shape = image_shape
        self.pooling = pooling

    def get_config(self):
        """Return the constructor arguments stored on this backbone."""
        return {
            "stackwise_num_repeats": self.stackwise_num_repeats,
            "stackwise_num_filters": self.stackwise_num_filters,
            "include_rescaling": self.include_rescaling,
            "image_shape": self.image_shape,
            "pooling": self.pooling,
        }
120
+
121
+
122
def apply_vgg_block(
    x,
    num_layers,
    filters,
    kernel_size,
    activation,
    padding,
    max_pool,
    name,
):
    """Apply one VGG block: a stack of convolutions and an optional pool.

    The convolutions are named `{name}_conv1` ... `{name}_conv{num_layers}`
    and the pooling layer, when present, is named `{name}_pool`.

    Args:
        x: Tensor, input tensor to pass through the block.
        num_layers: int, number of `Conv2D` layers in the block.
        filters: int, number of filters of each `Conv2D` layer.
        kernel_size: int (or) tuple, kernel size of each `Conv2D` layer.
        activation: str (or) callable, activation of each `Conv2D` layer.
        padding: padding argument forwarded to each `Conv2D` layer.
        max_pool: bool, whether to append a `MaxPooling2D` layer with a
            (2, 2) pool size and (2, 2) strides after the convolutions.
        name: str, name of the block, used as prefix for the layer names.

    Returns:
        keras.KerasTensor
    """
    # Settings shared by every convolution in the block.
    conv_kwargs = {"activation": activation, "padding": padding}

    for conv_index in range(num_layers):
        conv = layers.Conv2D(
            filters,
            kernel_size,
            name=f"{name}_conv{conv_index + 1}",  # Layer names are 1-based.
            **conv_kwargs,
        )
        x = conv(x)

    if not max_pool:
        return x

    pool = layers.MaxPooling2D((2, 2), (2, 2), name=f"{name}_pool")
    return pool(x)