keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,206 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+
18
+ try:
19
+ import tensorflow as tf
20
+ import tensorflow_text as tf_text
21
+ except ImportError:
22
+ tf = None
23
+ tf_text = None
24
+
25
+
26
+ def _decode_strings_to_utf8(inputs):
27
+ """Recursively decodes to list of strings with 'utf-8' encoding."""
28
+ if isinstance(inputs, bytes):
29
+ # Handles the case when the input is a scalar string.
30
+ return inputs.decode("utf-8", errors="ignore")
31
+ else:
32
+ # Recursively iterate when input is a list.
33
+ return [_decode_strings_to_utf8(x) for x in inputs]
34
+
35
+
36
def tensor_to_list(inputs):
    """Convert a tensor (or tensor-like value) to nested Python lists.

    Args:
        inputs: A `tf.Tensor`, a `tf.RaggedTensor`, or any value accepted by
            `tf.convert_to_tensor` (e.g. a numpy array or a nested Python
            list). There is no special handling for dicts.

    Returns:
        Nested Python lists mirroring the tensor's structure. For a rank-0
        dense tensor the value from `Tensor.numpy()` is returned as-is rather
        than wrapped in a list. `tf.string` data is decoded to UTF-8 `str`,
        with undecodable bytes dropped.
    """
    if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
        inputs = tf.convert_to_tensor(inputs)
    if isinstance(inputs, tf.RaggedTensor):
        list_outputs = inputs.to_list()
    elif isinstance(inputs, tf.Tensor):
        list_outputs = inputs.numpy()
        # A rank-0 string tensor yields a plain `bytes` object, which has no
        # `tolist()`; only call it for tensors of rank >= 1.
        if inputs.shape.rank != 0:
            list_outputs = list_outputs.tolist()
    if inputs.dtype == tf.string:
        list_outputs = _decode_strings_to_utf8(list_outputs)
    return list_outputs
53
+
54
+
55
def convert_to_backend_tensor_or_python_list(x):
    """Convert `x` into the representation best suited to the active backend.

    This is a thin wrapper around `ops.convert_to_tensor`. Torch and JAX have
    no native ragged or string tensor types, so two kinds of input are turned
    into plain (nested) Python lists via `tensor_to_list` instead:
    `tf.RaggedTensor` values, and values whose `dtype` equals `tf.string`.

    Everything else becomes a backend tensor that keeps the input's dtype.
    Inputs without a `dtype` attribute are converted as `"float32"`.
    """
    # Ragged data has no torch/jax equivalent: fall back to Python lists.
    if isinstance(x, tf.RaggedTensor):
        return tensor_to_list(x)
    # Same for string data.
    if getattr(x, "dtype", None) == tf.string:
        return tensor_to_list(x)
    target_dtype = keras.backend.standardize_dtype(
        getattr(x, "dtype", "float32")
    )
    return ops.convert_to_tensor(x, dtype=target_dtype)
70
+
71
+
72
def convert_to_ragged_batch(inputs):
    """Convert pythonic or numpy-like input to a 2-D `tf.RaggedTensor`.

    This is useful for text preprocessing layers which deal with already
    tokenized or split text.

    Args:
        inputs: A pythonic or numpy-like input to convert. This input should
            represent a possibly batched list of token sequences.

    Returns:
        An `(inputs, unbatched, rectangular)` tuple, where `inputs` is a
        2-D `tf.RaggedTensor`, `unbatched` is `True` if the inputs were
        originally rank 1, and `rectangular` is `True` if the input rows are
        all of equal length.

    Raises:
        ValueError: If `inputs` is of an unsupported type, or if it is not
            rank 1 or rank 2.
    """
    # `tf.keras.layers.Layer` does a weird conversion in __call__, where a list
    # of lists of ints will become a list of list of scalar tensors. We could
    # clean this up if we no longer need to care about that case.
    if isinstance(inputs, (list, tuple)):
        if isinstance(inputs[0], (list, tuple)):
            rectangular = len({len(row) for row in inputs}) == 1
            rows = [
                tf.convert_to_tensor(row, dtype_hint="int32") for row in inputs
            ]
            inputs = tf.ragged.stack(rows).with_row_splits_dtype("int64")
        else:
            inputs = tf.convert_to_tensor(inputs)
            rectangular = True
    elif isinstance(inputs, tf.Tensor):
        rectangular = True
    elif isinstance(inputs, tf.RaggedTensor):
        rectangular = False
    elif hasattr(inputs, "__array__"):
        inputs = tf.convert_to_tensor(ops.convert_to_numpy(inputs))
        rectangular = True
    else:
        raise ValueError(
            "Unknown tensor type. Tensor input can be passed as "
            "tensors, numpy arrays, or python lists. Received: "
            f"`type(inputs)={type(inputs)}`"
        )
    if inputs.shape.rank < 1 or inputs.shape.rank > 2:
        # Previously this referenced the builtin `input` (not `inputs`),
        # which raised an AttributeError instead of this ValueError.
        raise ValueError(
            "Tokenized tensor input should be rank 1 (unbatched) or "
            f"rank 2 (batched). Received: `inputs.shape={inputs.shape}`"
        )
    unbatched = inputs.shape.rank == 1
    # A single sequence is trivially rectangular.
    rectangular = rectangular or unbatched
    if unbatched:
        inputs = tf.expand_dims(inputs, 0)
    if isinstance(inputs, tf.Tensor):
        inputs = tf.RaggedTensor.from_tensor(inputs)
    return inputs, unbatched, rectangular
126
+
127
+
128
def truncate_at_token(inputs, token, mask):
    """Truncate each row at the first unmasked occurrence of `token`.

    Args:
        inputs: A dense `tf.Tensor` of token ids. The search and the resulting
            row lengths use the last axis (presumably `(batch, sequence)`;
            verify against callers).
        token: The token id to truncate at.
        mask: A boolean tensor shaped like `inputs`. Positions where `mask` is
            `True` are ignored when searching for `token`.

    Returns:
        A `tf.RaggedTensor` in which each row keeps only the tokens strictly
        before the first unmasked `token`. Rows that contain no unmasked
        `token` are kept whole.
    """
    matches = (inputs == token) & (~mask)
    end_indices = tf.cast(tf.math.argmax(matches, -1), "int32")
    # `argmax` returns 0 both when the first match is at index 0 and when
    # there is no match at all. Test for "any match" explicitly, so a token at
    # position 0 truncates the row to empty instead of keeping it whole.
    has_match = tf.reduce_any(matches, axis=-1)
    end_indices = tf.where(has_match, end_indices, tf.shape(inputs)[-1])
    return tf.RaggedTensor.from_tensor(inputs, end_indices)
134
+
135
+
136
def strip_to_ragged(token_ids, mask, ids_to_strip):
    """Remove masked and special tokens from a sequence before detokenizing.

    Args:
        token_ids: Token ids, as anything `ops.convert_to_numpy` accepts.
            They are cast to `int32`.
        mask: Values shaped like `token_ids`, cast to `bool`. Positions where
            the mask is `True` are candidates to keep.
        ids_to_strip: Iterable of token ids to drop wherever they appear.

    Returns:
        A `tf.RaggedTensor` built with `tf.ragged.boolean_mask`. It keeps only
        positions where `mask` is `True` and the id is not in
        `ids_to_strip`.
    """
    token_ids = ops.convert_to_numpy(token_ids).astype("int32")
    mask = ops.convert_to_numpy(mask).astype("bool")
    # `token_id` rather than `id`, which would shadow the builtin.
    for token_id in ids_to_strip:
        mask = mask & (token_ids != token_id)
    return tf.ragged.boolean_mask(token_ids, mask)
145
+
146
+
147
def assert_tf_libs_installed(symbol_name):
    """Ensure the optional `tensorflow` / `tensorflow-text` imports succeeded.

    Both module-level names are set to `None` when the guarded import at the
    top of this module fails.

    Args:
        symbol_name: Name of the caller-facing symbol, used in the message.

    Raises:
        ImportError: If either `tf` or `tf_text` is `None`.
    """
    libs_available = tf is not None and tf_text is not None
    if libs_available:
        return
    message = (
        f"{symbol_name} requires `tensorflow` and `tensorflow-text` for "
        "text processing. Run `pip install tensorflow-text` to install "
        "both packages or visit https://www.tensorflow.org/install\n\n"
        "If `tensorflow-text` is already installed, try importing it "
        "in a clean python session. Your installation may have errors.\n\n"
        "KerasHub uses `tf.data` and `tensorflow-text` to preprocess text "
        "on all Keras backends. If you are running on Jax or Torch, this "
        "installation does not need GPU support."
    )
    raise ImportError(message)
159
+
160
+
161
def assert_tf_backend(symbol_name):
    """Ensure Keras is running on the TensorFlow backend.

    Args:
        symbol_name: Name of the caller-facing symbol, used in the message.

    Raises:
        RuntimeError: If `keras.config.backend()` is not `"tensorflow"`.
    """
    active_backend = keras.config.backend()
    if active_backend == "tensorflow":
        return
    raise RuntimeError(
        f"{symbol_name} requires the `tensorflow` backend. "
        "Please set `KERAS_BACKEND=tensorflow` when running your program."
    )
167
+
168
+
169
def is_tensor_type(x):
    """Return `True` if `x` is array-like, i.e. exposes an `__array__` attribute."""
    try:
        getattr(x, "__array__")
    except AttributeError:
        return False
    return True
171
+
172
+
173
def is_float_dtype(dtype):
    """Return whether the standardized name of `dtype` contains "float".

    This is a substring test on the name returned by
    `keras.backend.standardize_dtype`.
    """
    standardized_name = keras.backend.standardize_dtype(dtype)
    return "float" in standardized_name
175
+
176
+
177
def is_int_dtype(dtype):
    """Return whether the standardized name of `dtype` contains "int".

    This is a substring test on the name returned by
    `keras.backend.standardize_dtype`, so an unsigned name such as "uint8"
    (if that is what standardization yields) also matches.
    """
    standardized_name = keras.backend.standardize_dtype(dtype)
    return "int" in standardized_name
179
+
180
+
181
def is_string_dtype(dtype):
    """Return whether the standardized name of `dtype` contains "string".

    This is a substring test on the name returned by
    `keras.backend.standardize_dtype`.
    """
    standardized_name = keras.backend.standardize_dtype(dtype)
    return "string" in standardized_name
183
+
184
+
185
def any_equal(inputs, values, padding_mask):
    """Return a mask that is True anywhere `inputs` has a value in `values`.

    The final mask has `padding_mask` applied.

    Args:
        inputs: Input tensor.
        values: Any iterable (list, tuple, generator, ...) of tensors or
            scalars that `ops.equal` can compare against `inputs`, i.e. shaped
            like `inputs` or broadcastable to it. It may be empty, in which
            case nothing matches.
        padding_mask: Boolean tensor, broadcast-compatible with `inputs`, that
            gates the output.

    Returns:
        A boolean tensor (`inputs` shape after broadcasting) that is True
        where the position equals any entry of `values` AND `padding_mask`
        is True.
    """
    # Fold over `values` without indexing, so non-sequence iterables work.
    # The original `values[0]` / `values[1:]` required a sequence.
    output = None
    for value in values:
        value_equality = ops.equal(inputs, value)
        if output is None:
            output = value_equality
        else:
            output = ops.logical_or(output, value_equality)
    if output is None:
        # No candidate values: no position can match.
        output = ops.zeros_like(inputs, dtype="bool")
    return ops.logical_and(output, padding_mask)
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,37 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Convert timm models to KerasHub."""
15
+
16
+ from keras_hub.src.utils.timm.convert_resnet import load_resnet_backbone
17
+
18
+
19
def load_timm_backbone(cls, preset, load_weights, **kwargs):
    """Build a KerasHub backbone from a timm preset's config and weights.

    Args:
        cls (class): The KerasHub backbone class to instantiate.
        preset (str): Preset configuration name.
        load_weights (bool): Whether to also load the weights.
        **kwargs: Forwarded unchanged to the architecture-specific loader.

    Returns:
        backbone: Whatever the architecture-specific loader returns (the
            initialized Keras backbone).

    Raises:
        ValueError: If `cls` is `None`, or if no timm converter exists for
            `cls`.
    """
    if cls is None:
        raise ValueError("Backbone class is None")
    # Dispatch on the class name; only ResNet has a timm converter so far.
    if cls.__name__ != "ResNetBackbone":
        raise ValueError(
            f"{cls} has not been ported from the Hugging Face format yet. "
            "Please check Hugging Face Hub for the Keras model. "
        )
    return load_resnet_backbone(cls, preset, load_weights, **kwargs)
@@ -0,0 +1,171 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+
16
+ from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
17
+ from keras_hub.src.utils.preset_utils import jax_memory_cleanup
18
+ from keras_hub.src.utils.preset_utils import load_config
19
+ from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
20
+
21
+
22
def convert_backbone_config(timm_config):
    """Translate a timm ResNet config dict into `ResNetBackbone` kwargs.

    Only the `"architecture"` entry of `timm_config` is read.

    Args:
        timm_config (dict): Parsed timm model config.

    Returns:
        dict: Constructor keyword arguments (filters, blocks, strides,
        block type and pre-activation flag).

    Raises:
        ValueError: If the architecture is not a supported ResNet variant.
    """
    arch = timm_config["architecture"]

    # timm "resnetv2_*" architectures are the pre-activation variants.
    pre_activation = "resnetv2_" in arch

    # (accepted architecture names, blocks per stack, residual block type).
    # A v1 model and its v2 counterpart of equal depth share one layout.
    layout_table = (
        (("resnet18",), (2, 2, 2, 2), "basic_block"),
        (("resnet26",), (2, 2, 2, 2), "bottleneck_block"),
        (("resnet34",), (3, 4, 6, 3), "basic_block"),
        (("resnet50", "resnetv2_50"), (3, 4, 6, 3), "bottleneck_block"),
        (("resnet101", "resnetv2_101"), (3, 4, 23, 3), "bottleneck_block"),
        (("resnet152", "resnetv2_152"), (3, 8, 36, 3), "bottleneck_block"),
    )
    for names, num_blocks, residual_kind in layout_table:
        if arch in names:
            break
    else:
        raise ValueError(
            f"Currently, the architecture {arch} is not supported."
        )

    return {
        "stackwise_num_filters": [64, 128, 256, 512],
        # Copy so callers never share (or mutate) the table entry.
        "stackwise_num_blocks": list(num_blocks),
        "stackwise_num_strides": [1, 2, 2, 2],
        "block_type": residual_kind,
        "use_pre_activation": pre_activation,
    }
60
+
61
+
62
def convert_weights(backbone, loader, timm_config):
    """Copy timm ResNet weights from `loader` into the layers of `backbone`.

    Which weight keys are read depends on `backbone.use_pre_activation`
    (v1 `layerN.M.*` keys vs. v2 `stages.N.blocks.M.*` keys) and on
    `backbone.block_type`. After the weights are ported, the backbone's
    `"normalization"` layer is rebuilt using the mean/std from
    `timm_config["pretrained_cfg"]`.

    Args:
        backbone: A ResNet backbone whose layers follow the
            `v1_stack{i}_block{j}_*` / `v2_stack{i}_block{j}_*` naming
            used below.
        loader: Weight loader exposing `port_weight(keras_variable,
            hf_weight_key=..., hook_fn=...)`.
        timm_config (dict): Parsed timm config containing
            `pretrained_cfg.mean` and `pretrained_cfg.std`.
    """

    def port_conv2d(keras_layer_name, hf_weight_prefix):
        # Port a conv kernel. The hook permutes axes (0, 1, 2, 3) ->
        # (2, 3, 1, 0); presumably PyTorch (out, in, kh, kw) into Keras
        # (kh, kw, in, out) layout -- NOTE(review): confirm. No bias is ported.
        loader.port_weight(
            backbone.get_layer(keras_layer_name).kernel,
            hf_weight_key=f"{hf_weight_prefix}.weight",
            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
        )

    def port_batch_normalization(keras_layer_name, hf_weight_prefix):
        # Map the four BN tensors: weight->gamma, bias->beta,
        # running_mean->moving_mean, running_var->moving_variance.
        loader.port_weight(
            backbone.get_layer(keras_layer_name).gamma,
            hf_weight_key=f"{hf_weight_prefix}.weight",
        )
        loader.port_weight(
            backbone.get_layer(keras_layer_name).beta,
            hf_weight_key=f"{hf_weight_prefix}.bias",
        )
        loader.port_weight(
            backbone.get_layer(keras_layer_name).moving_mean,
            hf_weight_key=f"{hf_weight_prefix}.running_mean",
        )
        loader.port_weight(
            backbone.get_layer(keras_layer_name).moving_variance,
            hf_weight_key=f"{hf_weight_prefix}.running_var",
        )

    version = "v1" if not backbone.use_pre_activation else "v2"
    block_type = backbone.block_type

    # Stem
    # v1 has conv + BN in the stem; v2 (pre-activation) ports only the conv.
    if version == "v1":
        port_conv2d("conv1_conv", "conv1")
        port_batch_normalization("conv1_bn", "bn1")
    else:
        port_conv2d("conv1_conv", "stem.conv")

    # Stages
    num_stacks = len(backbone.stackwise_num_filters)
    for stack_index in range(num_stacks):
        for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
            # timm v1 stages are 1-indexed (`layer1`..); v2 are 0-indexed.
            if version == "v1":
                keras_name = f"v1_stack{stack_index}_block{block_idx}"
                hf_name = f"layer{stack_index+1}.{block_idx}"
            else:
                keras_name = f"v2_stack{stack_index}_block{block_idx}"
                hf_name = f"stages.{stack_index}.blocks.{block_idx}"

            if version == "v1":
                # A downsample (shortcut projection) is ported only for the
                # first block of a stack, and never for stack 0 of a
                # basic-block network. Presumably the checkpoint has no such
                # projection there because shapes already match -- verify.
                if block_idx == 0 and (
                    block_type == "bottleneck_block" or stack_index > 0
                ):
                    port_conv2d(
                        f"{keras_name}_0_conv", f"{hf_name}.downsample.0"
                    )
                    port_batch_normalization(
                        f"{keras_name}_0_bn", f"{hf_name}.downsample.1"
                    )
                port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
                port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.bn1")
                port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")
                port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.bn2")
                # Bottleneck blocks carry a third conv/BN pair.
                if block_type == "bottleneck_block":
                    port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3")
                    port_batch_normalization(
                        f"{keras_name}_3_bn", f"{hf_name}.bn3"
                    )
            else:
                # v2 shortcut projection is a bare conv (no BN is ported).
                if block_idx == 0 and (
                    block_type == "bottleneck_block" or stack_index > 0
                ):
                    port_conv2d(
                        f"{keras_name}_0_conv", f"{hf_name}.downsample.conv"
                    )
                # Every v2 block starts with a pre-activation BN (`norm1`);
                # the BN indices are shifted one ahead of the conv indices.
                port_batch_normalization(
                    f"{keras_name}_pre_activation_bn", f"{hf_name}.norm1"
                )
                port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
                port_batch_normalization(
                    f"{keras_name}_1_bn", f"{hf_name}.norm2"
                )
                port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")
                if block_type == "bottleneck_block":
                    port_batch_normalization(
                        f"{keras_name}_2_bn", f"{hf_name}.norm3"
                    )
                    port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3")

    # Post
    # Pre-activation networks end with a final BN (`norm`) after the stages.
    if version == "v2":
        port_batch_normalization("post_bn", "norm")

    # Rebuild normalization layer with pretrained mean & std
    # The layer takes a variance, so each std entry is squared.
    mean = timm_config["pretrained_cfg"]["mean"]
    std = timm_config["pretrained_cfg"]["std"]
    normalization_layer = backbone.get_layer("normalization")
    normalization_layer.input_mean = mean
    normalization_layer.input_variance = [s**2 for s in std]
    # NOTE(review): relies on the private `_build_input_shape` attribute to
    # re-run `build` with the new statistics.
    normalization_layer.build(normalization_layer._build_input_shape)
160
+
161
+
162
def load_resnet_backbone(cls, preset, load_weights, **kwargs):
    """Instantiate a ResNet backbone from a timm preset.

    The backbone is configured from the preset's `HF_CONFIG_FILE`. Any
    `kwargs` are passed to the `cls` constructor alongside the converted
    config.

    Args:
        cls (class): Backbone class to instantiate.
        preset (str): Preset configuration name.
        load_weights (bool): If true, port the safetensors weights into the
            new backbone.
        **kwargs: Extra constructor keyword arguments for `cls`.

    Returns:
        The constructed backbone.
    """
    timm_config = load_config(preset, HF_CONFIG_FILE)
    backbone = cls(**convert_backbone_config(timm_config), **kwargs)
    if not load_weights:
        return backbone

    jax_memory_cleanup(backbone)
    # An empty prefix keeps weight keys exactly as stored, bypassing
    # `get_prefixed_key`.
    with SafetensorLoader(preset, prefix="") as weight_loader:
        convert_weights(backbone, weight_loader, timm_config)
    return backbone
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,101 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Convert Hugging Face models to KerasHub."""
15
+
16
+
17
+ from keras_hub.src.utils.transformers.convert_bert import load_bert_backbone
18
+ from keras_hub.src.utils.transformers.convert_bert import load_bert_tokenizer
19
+ from keras_hub.src.utils.transformers.convert_distilbert import (
20
+ load_distilbert_backbone,
21
+ )
22
+ from keras_hub.src.utils.transformers.convert_distilbert import (
23
+ load_distilbert_tokenizer,
24
+ )
25
+ from keras_hub.src.utils.transformers.convert_gemma import load_gemma_backbone
26
+ from keras_hub.src.utils.transformers.convert_gemma import load_gemma_tokenizer
27
+ from keras_hub.src.utils.transformers.convert_gpt2 import load_gpt2_backbone
28
+ from keras_hub.src.utils.transformers.convert_gpt2 import load_gpt2_tokenizer
29
+ from keras_hub.src.utils.transformers.convert_llama3 import load_llama3_backbone
30
+ from keras_hub.src.utils.transformers.convert_llama3 import (
31
+ load_llama3_tokenizer,
32
+ )
33
+ from keras_hub.src.utils.transformers.convert_pali_gemma import (
34
+ load_pali_gemma_backbone,
35
+ )
36
+ from keras_hub.src.utils.transformers.convert_pali_gemma import (
37
+ load_pali_gemma_tokenizer,
38
+ )
39
+
40
+
41
def load_transformers_backbone(cls, preset, load_weights):
    """Build a KerasHub backbone from a Hugging Face `transformers` preset.

    The converter is chosen by `cls.__name__`.

    Args:
        cls (class): Keras backbone class.
        preset (str): Preset configuration name.
        load_weights (bool): Whether to load the weights.

    Returns:
        backbone: Initialized Keras model backbone.

    Raises:
        ValueError: If `cls` is `None` or has no converter.
    """
    if cls is None:
        raise ValueError("Backbone class is None")

    class_name = cls.__name__
    # Each converter name is looked up only in its own branch.
    if class_name == "BertBackbone":
        convert = load_bert_backbone
    elif class_name == "GemmaBackbone":
        convert = load_gemma_backbone
    elif class_name == "Llama3Backbone":
        convert = load_llama3_backbone
    elif class_name == "PaliGemmaBackbone":
        convert = load_pali_gemma_backbone
    elif class_name == "GPT2Backbone":
        convert = load_gpt2_backbone
    elif class_name == "DistilBertBackbone":
        convert = load_distilbert_backbone
    else:
        raise ValueError(
            f"{cls} has not been ported from the Hugging Face format yet. "
            "Please check Hugging Face Hub for the Keras model. "
        )
    return convert(cls, preset, load_weights)
71
+
72
+
73
def load_transformers_tokenizer(cls, preset):
    """Build a KerasHub tokenizer from Hugging Face `transformers` assets.

    The converter is chosen by `cls.__name__`.

    Args:
        cls (class): Tokenizer class.
        preset (str): Preset configuration name.

    Returns:
        tokenizer: Initialized tokenizer.

    Raises:
        ValueError: If `cls` is `None` or has no converter.
    """
    if cls is None:
        raise ValueError("Tokenizer class is None")

    class_name = cls.__name__
    # Each converter name is looked up only in its own branch.
    if class_name == "BertTokenizer":
        convert = load_bert_tokenizer
    elif class_name == "GemmaTokenizer":
        convert = load_gemma_tokenizer
    elif class_name == "Llama3Tokenizer":
        convert = load_llama3_tokenizer
    elif class_name == "PaliGemmaTokenizer":
        convert = load_pali_gemma_tokenizer
    elif class_name == "GPT2Tokenizer":
        convert = load_gpt2_tokenizer
    elif class_name == "DistilBertTokenizer":
        convert = load_distilbert_tokenizer
    else:
        raise ValueError(
            f"{cls} has not been ported from the Hugging Face format yet. "
            "Please check Hugging Face Hub for the Keras model. "
        )
    return convert(cls, preset)