keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297):
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,621 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import collections
16
+ import datetime
17
+ import inspect
18
+ import json
19
+ import os
20
+ import re
21
+
22
+ import keras
23
+ from absl import logging
24
+ from packaging.version import parse
25
+
26
+ from keras_hub.src.api_export import keras_hub_export
27
+ from keras_hub.src.utils.keras_utils import print_msg
28
+
29
+ try:
30
+ import tensorflow as tf
31
+ except ImportError:
32
+ raise ImportError(
33
+ "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
34
+ "The TensorFlow package is required for data preprocessing with any backend."
35
+ )
36
+
37
+ try:
38
+ import kagglehub
39
+ from kagglehub.exceptions import KaggleApiHTTPError
40
+ except ImportError:
41
+ kagglehub = None
42
+
43
+ try:
44
+ import huggingface_hub
45
+ from huggingface_hub.utils import EntryNotFoundError
46
+ from huggingface_hub.utils import HFValidationError
47
+ except ImportError:
48
+ huggingface_hub = None
49
+
50
# URI schemes understood by `get_file()`, and the matching "scheme://"
# prefixes.
KAGGLE_SCHEME = "kaggle"
GS_SCHEME = "gs"
HF_SCHEME = "hf"

KAGGLE_PREFIX = f"{KAGGLE_SCHEME}://"
GS_PREFIX = f"{GS_SCHEME}://"
HF_PREFIX = f"{HF_SCHEME}://"

# Sub-directory of a preset that holds tokenizer asset files.
TOKENIZER_ASSET_DIR = "assets/tokenizer"

# Config file names.
CONFIG_FILE = "config.json"
TOKENIZER_CONFIG_FILE = "tokenizer.json"
TASK_CONFIG_FILE = "task.json"
PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
METADATA_FILE = "metadata.json"

# Weight file names.
MODEL_WEIGHTS_FILE = "model.weights.h5"
TASK_WEIGHTS_FILE = "task.weights.h5"

# HuggingFace filenames.
README_FILE = "README.md"
HF_CONFIG_FILE = "config.json"
HF_TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
SAFETENSOR_CONFIG_FILE = "model.safetensors.index.json"
SAFETENSOR_FILE = "model.safetensors"

# Global state for the built-in preset registry (filled by
# `register_presets()`).
# preset name -> preset info dict (e.g. holding a "kaggle_handle" entry).
BUILTIN_PRESETS = {}
# class -> {preset name -> preset info dict}.
BUILTIN_PRESETS_FOR_CLASS = collections.defaultdict(dict)
81
+
82
+
83
def register_presets(presets, classes):
    """Add built-in presets to the global preset registry.

    Every entry of `presets` is stored in `BUILTIN_PRESETS`, and also in
    `BUILTIN_PRESETS_FOR_CLASS[cls]` for each class in `classes`.

    Note that this is intended only for models and presets shipped in the
    library itself.

    Args:
        presets: dict mapping preset names to their preset info.
        classes: the classes the presets should be listed for.
    """
    for name, info in presets.items():
        BUILTIN_PRESETS[name] = info
        for klass in classes:
            BUILTIN_PRESETS_FOR_CLASS[klass][name] = info
93
+
94
+
95
def list_presets(cls):
    """Return the built-in presets registered for `cls`.

    The result is a new (shallow-copied) dict mapping preset name to preset
    info, so callers can modify it without touching the global registry.
    """
    registered = BUILTIN_PRESETS_FOR_CLASS[cls]
    return {name: info for name, info in registered.items()}
98
+
99
+
100
def list_subclasses(cls):
    """Return every registered Keras custom object that subclasses `cls`.

    Candidates come from `keras.saving.get_custom_objects()`. Non-class
    entries and `cls` itself are skipped, and the registry's iteration order
    is preserved.
    """
    registered = keras.saving.get_custom_objects().values()
    return [
        obj
        for obj in registered
        if inspect.isclass(obj) and obj != cls and issubclass(obj, cls)
    ]
108
+
109
+
110
def get_file(preset, path):
    """Download a preset file if necessary and return its local path.

    `preset` may be one of:
    * a built-in preset name, resolved to its Kaggle handle through
      `BUILTIN_PRESETS`;
    * a `kaggle://` handle;
    * a URI whose scheme is registered with `tf.io.gfile` (e.g. `gs://`),
      which is copied into the local Keras cache;
    * an `hf://` Hugging Face handle;
    * a path to an existing local preset directory.

    Args:
        preset: str. The preset identifier.
        path: str. Path of the requested file relative to the preset root.

    Returns:
        The local filesystem path of the requested file.

    Raises:
        ValueError: If `preset` is not a string, is a malformed handle, is not
            recognized, or the hub reports an unexpected error.
        ImportError: If the optional package needed for the scheme
            (`kagglehub` or `huggingface_hub`) is not installed.
        FileNotFoundError: If `path` does not exist in the preset.
    """
    # TODO: Add tests for FileNotFound exceptions.
    if not isinstance(preset, str):
        raise ValueError(
            f"A preset identifier must be a string. Received: preset={preset}"
        )
    if preset in BUILTIN_PRESETS:
        preset = BUILTIN_PRESETS[preset]["kaggle_handle"]

    scheme = None
    if "://" in preset:
        scheme = preset.split("://")[0].lower()

    if scheme == KAGGLE_SCHEME:
        if kagglehub is None:
            raise ImportError(
                "`from_preset()` requires the `kagglehub` package. "
                "Please install with `pip install kagglehub`."
            )
        kaggle_handle = preset.removeprefix(KAGGLE_SCHEME + "://")
        num_segments = len(kaggle_handle.split("/"))
        if num_segments not in (4, 5):
            raise ValueError(
                "Unexpected Kaggle preset. Kaggle model handles should have "
                "the form kaggle://{org}/{model}/keras/{variant}[/{version}]. "
                "For example, 'kaggle://username/bert/keras/bert_base_en' or "
                "'kaggle://username/bert/keras/bert_base_en/1' (to specify a "
                f"version). Received: preset={preset}"
            )
        try:
            return kagglehub.model_download(kaggle_handle, path)
        except KaggleApiHTTPError as e:
            message = str(e)
            # Use substring membership, not `str.find()`: `find` returns -1
            # (truthy) when the text is absent and 0 (falsy) when the message
            # starts with it, which inverted this check. 403 and 404 responses
            # both mean the file is not available in the preset.
            if "403 Client Error" in message or "404 Client Error" in message:
                raise FileNotFoundError(
                    f"`{path}` doesn't exist in preset directory `{preset}`."
                ) from e
            raise ValueError(message) from e
        except ValueError as e:
            if "is not present in the model files" in str(e):
                raise FileNotFoundError(
                    f"`{path}` doesn't exist in preset directory `{preset}`."
                ) from e
            # Any other ValueError is unrelated to a missing file; re-raise it
            # unchanged.
            raise

    elif scheme in tf.io.gfile.get_registered_schemes():
        url = os.path.join(preset, path)
        # Derive a filesystem-safe cache sub-directory from the preset URI.
        subdir = preset.replace("://", "_").replace("-", "_").replace("/", "_")
        filename = os.path.basename(path)
        subdir = os.path.join(subdir, os.path.dirname(path))
        try:
            return copy_gfile_to_cache(
                filename,
                url,
                cache_subdir=os.path.join("models", subdir),
            )
        except (tf.errors.PermissionDeniedError, tf.errors.NotFoundError) as e:
            raise FileNotFoundError(
                f"`{path}` doesn't exist in preset directory `{preset}`.",
            ) from e
    elif scheme == HF_SCHEME:
        if huggingface_hub is None:
            raise ImportError(
                "`from_preset()` requires the `huggingface_hub` package to "
                f"load from '{preset}'. "
                "Please install with `pip install huggingface_hub`."
            )
        hf_handle = preset.removeprefix(HF_SCHEME + "://")
        try:
            return huggingface_hub.hf_hub_download(
                repo_id=hf_handle, filename=path
            )
        except HFValidationError as e:
            raise ValueError(
                "Unexpected Hugging Face preset. Hugging Face model handles "
                "should have the form 'hf://{org}/{model}'. For example, "
                f"'hf://username/bert_base_en'. Received: preset={preset}."
            ) from e
        except EntryNotFoundError as e:
            # The repository was found but it has no file at `path`.
            raise FileNotFoundError(
                f"`{path}` doesn't exist in preset directory `{preset}`."
            ) from e
    elif os.path.exists(preset):
        # Assume a local filepath.
        local_path = os.path.join(preset, path)
        if not os.path.exists(local_path):
            raise FileNotFoundError(
                f"`{path}` doesn't exist in preset directory `{preset}`."
            )
        return local_path
    else:
        raise ValueError(
            "Unknown preset identifier. A preset must be one of:\n"
            "1) a built-in preset identifier like `'bert_base_en'`\n"
            "2) a Kaggle Models handle like "
            "`'kaggle://keras/bert/keras/bert_base_en'`\n"
            "3) a Hugging Face handle like `'hf://username/bert_base_en'`\n"
            "4) a path to a local preset directory like `'./bert_base_en'`\n"
            "Use `print(cls.presets.keys())` to view all built-in presets for "
            "API symbol `cls`.\n"
            f"Received: preset='{preset}'"
        )
218
+
219
+
220
def copy_gfile_to_cache(filename, url, cache_subdir):
    """Copy a `tf.io.gfile`-readable file into the local Keras cache.

    Much of this is adapted from `get_file` in Keras core. The cache root is
    `$KERAS_HOME` when that variable is set, otherwise `~/.keras`. If that
    directory is not writable (or does not exist), `/tmp/.keras` is used
    instead. The file is copied only when it is not already cached.

    Args:
        filename: Name of the file inside the cache sub-directory.
        url: Source location passed to `tf.io.gfile.copy`.
        cache_subdir: Sub-directory of the cache root to store the file in.

    Returns:
        The local path of the cached file.
    """
    if "KERAS_HOME" in os.environ:
        cache_dir_base = os.environ["KERAS_HOME"]
    else:
        cache_dir_base = os.path.expanduser(os.path.join("~", ".keras"))
    if not os.access(cache_dir_base, os.W_OK):
        cache_dir_base = os.path.join("/tmp", ".keras")
    cache_dir = os.path.join(cache_dir_base, cache_subdir)
    os.makedirs(cache_dir, exist_ok=True)

    fpath = os.path.join(cache_dir, filename)
    if not os.path.exists(fpath):
        print_msg(f"Downloading data from {url}")
        try:
            tf.io.gfile.copy(url, fpath)
        except Exception:
            # gfile.copy can leave an empty file behind after an error; remove
            # it so a later call retries instead of treating it as cached.
            # Only remove it if it exists, so a copy that failed before
            # creating the file does not mask the original error with a
            # FileNotFoundError from `os.remove`.
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

    return fpath
243
+
244
+
245
def check_file_exists(preset, path):
    """Return whether `path` can be fetched from `preset` via `get_file()`.

    Because this calls `get_file()`, a remote file may be downloaded as a
    side effect. Only `FileNotFoundError` counts as "missing"; any other
    exception raised by `get_file()` propagates to the caller.
    """
    try:
        get_file(preset, path)
    except FileNotFoundError:
        found = False
    else:
        found = True
    return found
251
+
252
+
253
def get_tokenizer(layer):
    """Return the tokenizer associated with a KerasHub model or layer.

    The following are tried in order:
    1. `layer` itself, if it is a `Tokenizer`;
    2. `layer.tokenizer`, if that attribute exists;
    3. `layer.preprocessor.tokenizer` (or `None` if the preprocessor is
       missing or has no `tokenizer` attribute).
    """
    # Imported here to avoid a circular import.
    from keras_hub.src.tokenizers.tokenizer import Tokenizer

    if isinstance(layer, Tokenizer):
        return layer
    if hasattr(layer, "tokenizer"):
        return layer.tokenizer
    # `getattr(None, "tokenizer", None)` is None, which covers a layer with
    # no `preprocessor` attribute.
    preprocessor = getattr(layer, "preprocessor", None)
    return getattr(preprocessor, "tokenizer", None)
265
+
266
+
267
def recursive_pop(config, key):
    """Delete `key` from `config` and from every dict nested inside it.

    `config` is modified in place and nothing is returned. Only values that
    are themselves dicts are descended into; other containers (e.g. lists)
    are left untouched. A missing key is ignored.
    """
    config.pop(key, None)
    nested = [value for value in config.values() if isinstance(value, dict)]
    for child in nested:
        recursive_pop(child, key)
273
+
274
+
275
def make_preset_dir(preset):
    """Create the preset directory `preset`, including any missing parents.

    Does nothing if the directory already exists.
    """
    os.makedirs(name=preset, exist_ok=True)
277
+
278
+
279
def save_tokenizer_assets(tokenizer, preset):
    """Save the tokenizer's asset files under `TOKENIZER_ASSET_DIR` in `preset`.

    The asset directory is created if needed. Nothing happens when
    `tokenizer` is falsy (e.g. `None`).
    """
    if not tokenizer:
        return
    asset_dir = os.path.join(preset, TOKENIZER_ASSET_DIR)
    os.makedirs(asset_dir, exist_ok=True)
    tokenizer.save_assets(asset_dir)
284
+
285
+
286
def save_serialized_object(
    layer,
    preset,
    config_file=CONFIG_FILE,
    config_to_skip=None,
):
    """Serialize `layer` into a JSON config file in the preset directory.

    The `compile_config` and `build_config` entries, plus any keys listed in
    `config_to_skip`, are removed at every level of nested dicts before the
    config is written.

    Args:
        layer: The Keras object to serialize.
        preset: Path of the preset directory; created if it does not exist.
        config_file: Name of the JSON file to write, relative to `preset`.
        config_to_skip: Optional iterable of additional config keys to drop.
            The caller's object is never modified.
    """
    make_preset_dir(preset)
    config_path = os.path.join(preset, config_file)
    config = keras.saving.serialize_keras_object(layer)
    # Build a fresh list. Extending the argument in place (as `+=` would)
    # mutates the caller's list, and with a list default it also makes the
    # shared default grow on every call.
    keys_to_skip = [*(config_to_skip or []), "compile_config", "build_config"]
    for key in keys_to_skip:
        recursive_pop(config, key)
    with open(config_path, "w") as f:
        f.write(json.dumps(config, indent=4))
300
+
301
+
302
def save_metadata(layer, preset):
    """Write a `METADATA_FILE` JSON file describing `layer` into `preset`.

    The file records:
    * the Keras version (`None` if `keras.version` is unavailable);
    * the KerasHub version;
    * the parameter count of `layer`;
    * the local save time, formatted as `YYYY-MM-DD@HH:MM:SS`.
    """
    from keras_hub.src.version_utils import __version__ as keras_hub_version

    keras_version = None
    if hasattr(keras, "version"):
        keras_version = keras.version()
    metadata = dict(
        keras_version=keras_version,
        keras_hub_version=keras_hub_version,
        parameter_count=layer.count_params(),
        date_saved=datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
    )
    with open(os.path.join(preset, METADATA_FILE), "w") as metadata_file:
        metadata_file.write(json.dumps(metadata, indent=4))
315
+
316
+
317
def _validate_tokenizer(preset, allow_incomplete=False):
    """Check that `preset` contains a loadable tokenizer config and assets.

    The check proceeds as follows:
    1. Load `TOKENIZER_CONFIG_FILE` and deserialize it.
    2. Verify that every file in the resulting layer's `file_assets` can be
       fetched from `TOKENIZER_ASSET_DIR`.
    3. Load those assets into the tokenizer.

    Args:
        preset: The preset identifier or directory.
        allow_incomplete: If True, a missing tokenizer config only logs a
            warning and the function returns early instead of raising.

    Raises:
        FileNotFoundError: If the tokenizer config is missing (and
            `allow_incomplete` is False), or if a tokenizer asset is missing.
        ValueError: If the tokenizer config cannot be read as JSON, or if the
            deserialized object has no tokenizer.
    """
    if not check_file_exists(preset, TOKENIZER_CONFIG_FILE):
        if allow_incomplete:
            logging.warning(
                f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset "
                f"directory `{preset}`."
            )
            return
        raise FileNotFoundError(
            f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory "
            f"`{preset}`. To upload the model without a tokenizer, "
            "set `allow_incomplete=True`."
        )
    config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
    try:
        with open(config_path, encoding="utf-8") as config_file:
            config = json.load(config_file)
    except Exception as e:
        raise ValueError(
            f"Tokenizer config file `{config_path}` is an invalid json file. "
            f"Error message: {e}"
        ) from e
    layer = keras.saving.deserialize_keras_object(config)

    for asset in layer.file_assets:
        asset_path = get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset))
        if not os.path.exists(asset_path):
            tokenizer_asset_dir = os.path.dirname(asset_path)
            raise FileNotFoundError(
                f"Asset `{asset}` doesn't exist in the tokenizer asset directory"
                f" `{tokenizer_asset_dir}`."
            )
    # Assets are loaded relative to the directory the config was fetched
    # into.
    config_dir = os.path.dirname(config_path)
    asset_dir = os.path.join(config_dir, TOKENIZER_ASSET_DIR)

    tokenizer = get_tokenizer(layer)
    if not tokenizer:
        raise ValueError(f"Model or layer `{layer}` is missing tokenizer.")
    tokenizer.load_assets(asset_dir)
356
+
357
+
358
def _validate_backbone(preset):
    """Check that a local preset directory holds a backbone config and weights.

    Only the local filesystem is consulted (`os.path.exists` and `open`).

    Raises:
        FileNotFoundError: If `CONFIG_FILE` or `MODEL_WEIGHTS_FILE` is missing
            from `preset`.
        ValueError: If `CONFIG_FILE` cannot be read or parsed as JSON.
    """
    config_path = os.path.join(preset, CONFIG_FILE)
    if not os.path.exists(config_path):
        raise FileNotFoundError(
            f"`{CONFIG_FILE}` is missing from the preset directory `{preset}`."
        )
    try:
        with open(config_path, encoding="utf-8") as handle:
            json.load(handle)
    except Exception as err:
        raise ValueError(
            f"Config file `{config_path}` is an invalid json file. "
            f"Error message: {err}"
        )

    if not os.path.exists(os.path.join(preset, MODEL_WEIGHTS_FILE)):
        raise FileNotFoundError(
            f"The weights file is missing from the preset directory `{preset}`."
        )
378
+
379
+
380
def get_snake_case(name):
    """Convert a CamelCase class name to snake_case.

    For example, `"BertBackbone"` becomes `"bert_backbone"` and
    `"GPT2CausalLM"` becomes `"gpt2_causal_lm"`.
    """
    # Insert "_" before each capitalized word ("BertBackbone" -> "Bert_Backbone").
    spaced = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Split remaining lowercase/digit -> uppercase boundaries ("CausalLM" -> "Causal_LM").
    spaced = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", spaced)
    return spaced.lower()
383
+
384
+
385
def create_model_card(preset):
    """Write an auto-generated model card (`README_FILE`) into `preset`.

    The card has YAML front matter (`library_name` and, for `CausalLM` /
    `Classifier` tasks, a `pipeline_tag`), a short description linking to
    the keras.io page for the model, and a bullet list of the backbone
    config entries. Any existing file at that path is overwritten.

    Args:
        preset: Path to a local preset directory containing `CONFIG_FILE`
            and, optionally, `TASK_CONFIG_FILE`.
    """
    config = load_config(preset, CONFIG_FILE)
    task_config = None
    if check_file_exists(preset, TASK_CONFIG_FILE):
        task_config = load_config(preset, TASK_CONFIG_FILE)

    markdown_content = _build_model_card_markdown(config, task_config)

    model_card_path = os.path.join(preset, README_FILE)
    # Explicit UTF-8, matching the JSON readers in this module, so config
    # values with non-ASCII characters don't depend on the platform locale.
    with open(model_card_path, "w", encoding="utf-8") as md_file:
        md_file.write(markdown_content)


def _build_model_card_markdown(config, task_config=None):
    """Return model card markdown for a backbone config and optional task config."""
    # Only strip a trailing "Backbone" (e.g. "GemmaBackbone" -> "Gemma").
    model_name = config["class_name"].removesuffix("Backbone")

    task_type = None
    if task_config is not None:
        # Only strip a leading model name (e.g. "GemmaCausalLM" -> "CausalLM").
        task_type = task_config["class_name"].removeprefix(model_name)

    pipeline_tag = {
        "CausalLM": "text-generation",
        "Classifier": "text-classification",
    }.get(task_type)

    # YAML front matter.
    parts = ["---\n", "library_name: keras-hub\n"]
    if pipeline_tag:
        parts.append(f"pipeline_tag: {pipeline_tag}\n")
    parts.append("---\n")

    model_link = (
        f"https://keras.io/api/keras_hub/models/{get_snake_case(model_name)}"
    )
    parts.append(
        f"This is a [`{model_name}` model]({model_link}) "
        "uploaded using the KerasHub library and can be used with JAX, "
        "TensorFlow, and PyTorch backends.\n"
    )
    if task_type:
        parts.append(f"This model is related to a `{task_type}` task.\n")
    # Blank line so "Model config:" starts its own paragraph whether or not
    # a task line was emitted.
    parts.append("\n")

    parts.append("Model config:\n")
    parts.extend(f"* **{k}:** {v}\n" for k, v in config["config"].items())
    parts.append("\n")
    parts.append(
        "This model card has been generated automatically and should be completed "
        "by the model author. See [Model Cards documentation]"
        "(https://huggingface.co/docs/hub/model-cards) for more information.\n"
    )
    return "".join(parts)
440
+
441
+
442
def delete_model_card(preset):
    """Remove `README_FILE` from `preset`.

    A missing file is not an error: a warning is logged instead.

    Args:
        preset: Path to a local preset directory.
    """
    readme_path = os.path.join(preset, README_FILE)
    try:
        os.remove(readme_path)
    except FileNotFoundError:
        warning_message = (
            f"There was an attempt to delete file `{readme_path}` but this"
            " file doesn't exist."
        )
        logging.warning(warning_message)
451
+
452
+
453
@keras_hub_export("keras_hub.upload_preset")
def upload_preset(
    uri,
    preset,
    allow_incomplete=False,
):
    """Upload a preset directory to a model hub.

    Args:
        uri: The URI identifying model to upload to.
            URIs with format
            `kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`
            will be uploaded to Kaggle Hub while URIs with format
            `hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging
            Face Hub.
        preset: The path to the local model preset directory.
        allow_incomplete: If True, allows the upload of presets without
            a tokenizer configuration. Otherwise, a tokenizer
            is required.

    Raises:
        FileNotFoundError: If `preset` does not exist or lacks the backbone
            config or weights file.
        ValueError: If the backbone config is not valid JSON, the Hugging
            Face handle is malformed, or `uri` has an unknown prefix.
        ImportError: If the client library needed for `uri` is missing
            (or, for Kaggle, older than `0.2.4`).
    """
    # Check if preset directory exists.
    if not os.path.exists(preset):
        raise FileNotFoundError(f"The preset directory {preset} doesn't exist.")

    _validate_backbone(preset)
    _validate_tokenizer(preset, allow_incomplete)

    if uri.startswith(KAGGLE_PREFIX):
        _upload_preset_to_kaggle(uri, preset)
    elif uri.startswith(HF_PREFIX):
        _upload_preset_to_hf(uri, preset)
    else:
        raise ValueError(
            "Unknown URI. A URI must be one of:\n"
            "1) a Kaggle Model handle like `'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n"
            "2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n"
            f"Received: uri='{uri}'."
        )


def _upload_preset_to_kaggle(uri, preset):
    """Upload `preset` to Kaggle Hub under the handle in a `kaggle://` URI."""
    if kagglehub is None:
        raise ImportError(
            "Uploading a model to Kaggle Hub requires the `kagglehub` package. "
            "Please install with `pip install kagglehub`."
        )
    if parse(kagglehub.__version__) < parse("0.2.4"):
        raise ImportError(
            "Uploading a model to Kaggle Hub requires the `kagglehub` package version `0.2.4` or higher. "
            "Please upgrade with `pip install --upgrade kagglehub`."
        )
    kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
    kagglehub.model_upload(kaggle_handle, preset)


def _upload_preset_to_hf(uri, preset):
    """Upload `preset` to the Hugging Face Hub repo named by an `hf://` URI."""
    if huggingface_hub is None:
        raise ImportError(
            f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. "
            "Please install with `pip install huggingface_hub`."
        )
    hf_handle = uri.removeprefix(HF_PREFIX)
    try:
        repo_url = huggingface_hub.create_repo(repo_id=hf_handle, exist_ok=True)
    except HFValidationError as e:
        raise ValueError(
            "Unexpected Hugging Face URI. Hugging Face model handles "
            "should have the form 'hf://[{org}/]{model}'. For example, "
            "'hf://username/bert_base_en' or 'hf://bert_base_en' to implicitly "
            f"upload to your user account. Received: URI={uri}."
        ) from e

    has_remote_model_card = huggingface_hub.file_exists(
        repo_id=repo_url.repo_id, filename=README_FILE
    )
    has_local_model_card = os.path.exists(os.path.join(preset, README_FILE))
    # Generate a basic card only when neither side has one. A README the
    # user already put in the preset is uploaded as-is; it must never be
    # overwritten and then deleted by the cleanup below.
    generate_model_card = not has_remote_model_card and not has_local_model_card
    if generate_model_card:
        create_model_card(preset)
    try:
        huggingface_hub.upload_folder(
            repo_id=repo_url.repo_id, folder_path=preset
        )
    finally:
        if generate_model_card:
            # Remove the generated card so the preset directory is left as
            # it was, e.g. in case the user also uploads it to Kaggle Hub.
            delete_model_card(preset)
534
+
535
+
536
def load_config(preset, config_file=CONFIG_FILE):
    """Read and parse a JSON file from a preset.

    Args:
        preset: Preset identifier, passed to `get_file` to obtain a local path.
        config_file: Name of the JSON file within the preset.

    Returns:
        The decoded JSON content.
    """
    resolved_path = get_file(preset, config_file)
    with open(resolved_path, encoding="utf-8") as json_handle:
        return json.load(json_handle)
541
+
542
+
543
def check_format(preset):
    """Infer which checkpoint format a preset uses.

    Returns:
        `"timm"` or `"transformers"` when `SAFETENSOR_FILE` or
        `SAFETENSOR_CONFIG_FILE` is present. It is `"timm"` if the preset
        string contains `"hf://timm"` or the `HF_CONFIG_FILE` JSON has an
        `"architecture"` key, and `"transformers"` otherwise.
        `"keras"` when `METADATA_FILE` exists and has a `keras_version` key.

    Raises:
        FileNotFoundError: If no safetensors file is found and
            `METADATA_FILE` is missing or inaccessible.
        ValueError: If `METADATA_FILE` lacks `keras_version`.
    """
    has_safetensors = check_file_exists(
        preset, SAFETENSOR_FILE
    ) or check_file_exists(preset, SAFETENSOR_CONFIG_FILE)
    if has_safetensors:
        # The config file content decides between the two HF-style formats.
        hf_config = load_config(preset, HF_CONFIG_FILE)
        is_timm = "hf://timm" in preset or "architecture" in hf_config
        return "timm" if is_timm else "transformers"

    if not check_file_exists(preset, METADATA_FILE):
        raise FileNotFoundError(
            f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`, "
            "or you do not have access to it. This file is required to load a Keras model "
            "preset. Please verify that the model you are trying to load is a Keras model."
        )

    metadata = load_config(preset, METADATA_FILE)
    if "keras_version" in metadata:
        return "keras"
    raise ValueError(
        f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. "
        "Please verify that the model you are trying to load is a Keras model."
    )
566
+
567
+
568
def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
    """Deserialize the Keras object described by `config_file` in `preset`.

    Keyword arguments override (or extend) entries of the serialized
    `"config"` dict before deserialization. `dtype` is handled separately
    through `set_dtype_in_config`, because the stored `dtype` might be a
    serialized `DTypePolicy` or `DTypePolicyMap` rather than a plain value.
    """
    serialized = load_config(preset, config_file)

    requested_dtype = kwargs.pop("dtype", None)
    serialized = set_dtype_in_config(serialized, requested_dtype)

    merged_config = dict(serialized["config"])
    merged_config.update(kwargs)
    serialized["config"] = merged_config
    return keras.saving.deserialize_keras_object(serialized)
579
+
580
+
581
def check_config_class(
    preset,
    config_file=CONFIG_FILE,
):
    """Return the registered Keras class named in a preset's config file.

    Reads `config_file` from `preset` and looks up its `registered_name`
    in Keras' serialization registry. Callers can use the result to check
    that a preset is being loaded on a compatible class.

    Returns:
        The class registered under `registered_name`, or `None` if no class
        is registered under that name.

    Raises:
        KeyError: If the config has no `registered_name` entry.
    """
    # Reuse the shared JSON loader instead of duplicating get_file/open/json.
    config = load_config(preset, config_file)
    return keras.saving.get_registered_object(config["registered_name"])
590
+
591
+
592
def jax_memory_cleanup(layer):
    """Delete the arrays currently held by `layer`'s weights on JAX.

    This avoids temporarily duplicating variable allocations. torch and
    tensorflow have stateful variable types and do not need this, so on any
    other backend this is a no-op.
    """
    if keras.config.backend() != "jax":
        return
    for variable in layer.weights:
        current_value = getattr(variable, "_value", None)
        if current_value is not None:
            current_value.delete()
600
+
601
+
602
def set_dtype_in_config(config, dtype=None):
    """Return a copy of a serialized Keras config with `dtype` applied.

    Args:
        config: A serialized Keras object, i.e. a dict with a nested
            `"config"` dict.
        dtype: The dtype to apply. If `None`, `config` itself is returned
            unchanged.

    Returns:
        A deep copy of `config` in which:
        - if `config["config"]` has no `"dtype"` key, `dtype` is added;
        - if the existing `"dtype"` is a serialized `DTypePolicyMap`,
          `dtype` becomes its `default_policy` and the `source_name` of
          every entry in its `policy_map`;
        - otherwise the existing `"dtype"` is kept as-is.
        The input `config` is never modified.
    """
    import copy

    if dtype is None:
        return config

    # Deep copy: the nested dicts below are mutated, and a shallow
    # `dict.copy()` would share them with the caller's `config`.
    config = copy.deepcopy(config)
    inner_config = config["config"]

    if "dtype" not in inner_config:
        # Forward `dtype` to the config.
        inner_config["dtype"] = dtype
        return config

    stored_dtype = inner_config["dtype"]
    if isinstance(stored_dtype, dict) and "DTypePolicyMap" in stored_dtype.get(
        "class_name", ""
    ):
        # If it is `DTypePolicyMap` in `config`, forward `dtype` as its
        # default policy and as the source dtype of every mapped policy.
        policy_map_config = stored_dtype["config"]
        policy_map_config["default_policy"] = dtype
        for policy in policy_map_config["policy_map"].values():
            policy["config"]["source_name"] = dtype
    return config
@@ -0,0 +1,21 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utilities with miscellaneous python extensions."""
15
+
16
+
17
class classproperty(property):
    """Define a class level property.

    Like `property`, but the getter receives the owner class instead of an
    instance, so the value is available as both `Cls.attr` and
    `instance.attr`.
    """

    def __get__(self, instance, owner_cls=None):
        # The descriptor protocol allows `owner` to be omitted when
        # `__get__` is called directly with an instance; fall back to the
        # instance's type in that case.
        if owner_cls is None:
            owner_cls = type(instance)
        return self.fget(owner_cls)