keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -7,19 +7,10 @@ import re
7
7
 
8
8
  import keras
9
9
  from absl import logging
10
- from packaging.version import parse
11
10
 
12
11
  from keras_hub.src.api_export import keras_hub_export
13
12
  from keras_hub.src.utils.keras_utils import print_msg
14
13
 
15
- try:
16
- import tensorflow as tf
17
- except ImportError:
18
- raise ImportError(
19
- "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
20
- "The TensorFlow package is required for data preprocessing with any backend."
21
- )
22
-
23
14
  try:
24
15
  import kagglehub
25
16
  from kagglehub.exceptions import KaggleApiHTTPError
@@ -172,26 +163,13 @@ def get_file(preset, path):
172
163
  )
173
164
  else:
174
165
  raise ValueError(message)
175
-
176
- elif scheme in tf.io.gfile.get_registered_schemes():
177
- url = os.path.join(preset, path)
178
- subdir = preset.replace("://", "_").replace("-", "_").replace("/", "_")
179
- filename = os.path.basename(path)
180
- subdir = os.path.join(subdir, os.path.dirname(path))
181
- try:
182
- return copy_gfile_to_cache(
183
- filename,
184
- url,
185
- cache_subdir=os.path.join("models", subdir),
186
- )
187
- except (tf.errors.PermissionDeniedError, tf.errors.NotFoundError) as e:
188
- raise FileNotFoundError(
189
- f"`{path}` doesn't exist in preset directory `{preset}`.",
190
- ) from e
166
+ elif scheme in tf_registered_schemes():
167
+ return tf_copy_gfile_to_cache(preset, path)
191
168
  elif scheme == HF_SCHEME:
192
169
  if huggingface_hub is None:
193
170
  raise ImportError(
194
- f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
171
+ "`from_preset()` requires the `huggingface_hub` package to "
172
+ "load from '{preset}'. "
195
173
  "Please install with `pip install huggingface_hub`."
196
174
  )
197
175
  hf_handle = preset.removeprefix(HF_SCHEME + "://")
@@ -225,7 +203,8 @@ def get_file(preset, path):
225
203
  raise ValueError(
226
204
  "Unknown preset identifier. A preset must be a one of:\n"
227
205
  "1) a built-in preset identifier like `'bert_base_en'`\n"
228
- "2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n"
206
+ "2) a Kaggle Models handle like "
207
+ "`'kaggle://keras/bert/keras/bert_base_en'`\n"
229
208
  "3) a Hugging Face handle like `'hf://username/bert_base_en'`\n"
230
209
  "4) a path to a local preset directory like `'./bert_base_en`\n"
231
210
  "Use `print(cls.presets.keys())` to view all built-in presets for "
@@ -234,29 +213,48 @@ def get_file(preset, path):
234
213
  )
235
214
 
236
215
 
237
- def copy_gfile_to_cache(filename, url, cache_subdir):
216
+ def tf_registered_schemes():
217
+ try:
218
+ import tensorflow as tf
219
+
220
+ return tf.io.gfile.get_registered_schemes()
221
+ except ImportError:
222
+ return []
223
+
224
+
225
+ def tf_copy_gfile_to_cache(preset, path):
238
226
  """Much of this is adapted from get_file of keras core."""
239
227
  if "KERAS_HOME" in os.environ:
240
- cachdir_base = os.environ.get("KERAS_HOME")
228
+ base_dir = os.environ.get("KERAS_HOME")
241
229
  else:
242
- cachdir_base = os.path.expanduser(os.path.join("~", ".keras"))
243
- if not os.access(cachdir_base, os.W_OK):
244
- cachdir_base = os.path.join("/tmp", ".keras")
245
- cachedir = os.path.join(cachdir_base, cache_subdir)
246
- os.makedirs(cachedir, exist_ok=True)
247
-
248
- fpath = os.path.join(cachedir, filename)
249
- if not os.path.exists(fpath):
230
+ base_dir = os.path.expanduser(os.path.join("~", ".keras"))
231
+ if not os.access(base_dir, os.W_OK):
232
+ base_dir = os.path.join("/tmp", ".keras")
233
+
234
+ url = os.path.join(preset, path)
235
+ model_dir = preset.replace("://", "_").replace("-", "_").replace("/", "_")
236
+ local_path = os.path.join(base_dir, "models", model_dir, path)
237
+
238
+ if not os.path.exists(local_path):
250
239
  print_msg(f"Downloading data from {url}")
251
240
  try:
252
- tf.io.gfile.copy(url, fpath)
241
+ import tensorflow as tf
242
+
243
+ os.make_dirs(os.path.dirname(local_path), exist_ok=True)
244
+ tf.io.gfile.copy(url, local_path)
253
245
  except Exception as e:
254
246
  # gfile.copy will leave an empty file after an error.
255
247
  # Work around this bug.
256
- os.remove(fpath)
248
+ os.remove(local_path)
249
+ if isinstance(
250
+ e, tf.errors.PermissionDeniedError, tf.errors.NotFoundError
251
+ ):
252
+ raise FileNotFoundError(
253
+ f"`{path}` doesn't exist in preset directory `{preset}`.",
254
+ ) from e
257
255
  raise e
258
256
 
259
- return fpath
257
+ return local_path
260
258
 
261
259
 
262
260
  def check_file_exists(preset, path):
@@ -267,64 +265,6 @@ def check_file_exists(preset, path):
267
265
  return True
268
266
 
269
267
 
270
- def get_tokenizer(layer):
271
- """Get the tokenizer from any KerasHub model or layer."""
272
- # Avoid circular import.
273
- from keras_hub.src.tokenizers.tokenizer import Tokenizer
274
-
275
- if isinstance(layer, Tokenizer):
276
- return layer
277
- if hasattr(layer, "tokenizer"):
278
- return layer.tokenizer
279
- if hasattr(layer, "preprocessor"):
280
- return getattr(layer.preprocessor, "tokenizer", None)
281
- return None
282
-
283
-
284
- def recursive_pop(config, key):
285
- """Remove a key from a nested config object"""
286
- config.pop(key, None)
287
- for value in config.values():
288
- if isinstance(value, dict):
289
- recursive_pop(value, key)
290
-
291
-
292
- # TODO: refactor saving routines into a PresetSaver class?
293
- def make_preset_dir(preset):
294
- os.makedirs(preset, exist_ok=True)
295
-
296
-
297
- def save_serialized_object(
298
- layer,
299
- preset,
300
- config_file=CONFIG_FILE,
301
- config_to_skip=[],
302
- ):
303
- make_preset_dir(preset)
304
- config_path = os.path.join(preset, config_file)
305
- config = keras.saving.serialize_keras_object(layer)
306
- config_to_skip += ["compile_config", "build_config"]
307
- for c in config_to_skip:
308
- recursive_pop(config, c)
309
- with open(config_path, "w") as config_file:
310
- config_file.write(json.dumps(config, indent=4))
311
-
312
-
313
- def save_metadata(layer, preset):
314
- from keras_hub.src.version_utils import __version__ as keras_hub_version
315
-
316
- keras_version = keras.version() if hasattr(keras, "version") else None
317
- metadata = {
318
- "keras_version": keras_version,
319
- "keras_hub_version": keras_hub_version,
320
- "parameter_count": layer.count_params(),
321
- "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
322
- }
323
- metadata_path = os.path.join(preset, METADATA_FILE)
324
- with open(metadata_path, "w") as metadata_file:
325
- metadata_file.write(json.dumps(metadata, indent=4))
326
-
327
-
328
268
  def _validate_backbone(preset):
329
269
  config_path = os.path.join(preset, CONFIG_FILE)
330
270
  if not os.path.exists(config_path):
@@ -400,8 +340,8 @@ def create_model_card(preset):
400
340
  markdown_content += f"* **{k}:** {v}\n"
401
341
  markdown_content += "\n"
402
342
  markdown_content += (
403
- "This model card has been generated automatically and should be completed "
404
- "by the model author. See [Model Cards documentation]"
343
+ "This model card has been generated automatically and should be "
344
+ "completed by the model author. See [Model Cards documentation]"
405
345
  "(https://huggingface.co/docs/hub/model-cards) for more information.\n"
406
346
  )
407
347
 
@@ -446,20 +386,16 @@ def upload_preset(
446
386
  if uri.startswith(KAGGLE_PREFIX):
447
387
  if kagglehub is None:
448
388
  raise ImportError(
449
- "Uploading a model to Kaggle Hub requires the `kagglehub` package. "
450
- "Please install with `pip install kagglehub`."
451
- )
452
- if parse(kagglehub.__version__) < parse("0.2.4"):
453
- raise ImportError(
454
- "Uploading a model to Kaggle Hub requires the `kagglehub` package version `0.2.4` or higher. "
455
- "Please upgrade with `pip install --upgrade kagglehub`."
389
+ "Uploading a model to Kaggle Hub requires the `kagglehub` "
390
+ "package. Please install with `pip install kagglehub`."
456
391
  )
457
392
  kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
458
393
  kagglehub.model_upload(kaggle_handle, preset)
459
394
  elif uri.startswith(HF_PREFIX):
460
395
  if huggingface_hub is None:
461
396
  raise ImportError(
462
- f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. "
397
+ f"`upload_preset()` requires the `huggingface_hub` package "
398
+ f"to upload to '{uri}'. "
463
399
  "Please install with `pip install huggingface_hub`."
464
400
  )
465
401
  hf_handle = uri.removeprefix(HF_PREFIX)
@@ -471,14 +407,15 @@ def upload_preset(
471
407
  raise ValueError(
472
408
  "Unexpected Hugging Face URI. Hugging Face model handles "
473
409
  "should have the form 'hf://[{org}/]{model}'. For example, "
474
- "'hf://username/bert_base_en' or 'hf://bert_case_en' to implicitly"
475
- f"upload to your user account. Received: URI={uri}."
410
+ "'hf://username/bert_base_en' or 'hf://bert_case_en' to "
411
+ f"implicitly upload to your user account. Received: URI={uri}."
476
412
  ) from e
477
413
  has_model_card = huggingface_hub.file_exists(
478
414
  repo_id=repo_url.repo_id, filename=README_FILE
479
415
  )
480
416
  if not has_model_card:
481
- # Remote repo doesn't have a model card so a basic model card is automatically generated.
417
+ # Remote repo doesn't have a model card so a basic model card is
418
+ # automatically generated.
482
419
  create_model_card(preset)
483
420
  try:
484
421
  huggingface_hub.upload_folder(
@@ -486,13 +423,14 @@ def upload_preset(
486
423
  )
487
424
  finally:
488
425
  if not has_model_card:
489
- # Clean up the preset directory in case user attempts to upload the
490
- # preset directory into Kaggle hub as well.
426
+ # Clean up the preset directory in case user attempts to upload
427
+ # the preset directory into Kaggle hub as well.
491
428
  delete_model_card(preset)
492
429
  else:
493
430
  raise ValueError(
494
431
  "Unknown URI. An URI must be a one of:\n"
495
- "1) a Kaggle Model handle like `'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n"
432
+ "1) a Kaggle Model handle like "
433
+ "`'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n"
496
434
  "2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n"
497
435
  f"Received: uri='{uri}'."
498
436
  )
@@ -505,19 +443,11 @@ def load_json(preset, config_file=CONFIG_FILE):
505
443
  return config
506
444
 
507
445
 
508
- def load_serialized_object(config, **kwargs):
509
- # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
510
- # Ensure that `dtype` is properly configured.
511
- dtype = kwargs.pop("dtype", None)
512
- config = set_dtype_in_config(config, dtype)
513
-
514
- config["config"] = {**config["config"], **kwargs}
515
- return keras.saving.deserialize_keras_object(config)
516
-
517
-
518
446
  def check_config_class(config):
519
447
  """Validate a preset is being loaded on the correct class."""
520
448
  registered_name = config["registered_name"]
449
+ if registered_name in ("Functional", "Sequential"):
450
+ return keras.Model
521
451
  cls = keras.saving.get_registered_object(registered_name)
522
452
  if cls is None:
523
453
  raise ValueError(
@@ -600,6 +530,13 @@ def get_preset_loader(preset):
600
530
  )
601
531
 
602
532
 
533
+ def get_preset_saver(preset):
534
+ # Unlike loading, we only support one form of saving; Keras serialized
535
+ # configs and saved weights. We keep the rough API structure as loading
536
+ # just for simplicity.
537
+ return KerasPresetSaver(preset)
538
+
539
+
603
540
  class PresetLoader:
604
541
  def __init__(self, preset, config):
605
542
  self.config = config
@@ -612,10 +549,8 @@ class PresetLoader:
612
549
  backbone_kwargs["dtype"] = kwargs.pop("dtype", None)
613
550
 
614
551
  # Forward `height` and `width` to backbone when using `TextToImage`.
615
- if "height" in kwargs:
616
- backbone_kwargs["height"] = kwargs.pop("height", None)
617
- if "width" in kwargs:
618
- backbone_kwargs["width"] = kwargs.pop("width", None)
552
+ if "image_shape" in kwargs:
553
+ backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None)
619
554
 
620
555
  return backbone_kwargs, kwargs
621
556
 
@@ -627,7 +562,7 @@ class PresetLoader:
627
562
  """Load the backbone model from the preset."""
628
563
  raise NotImplementedError
629
564
 
630
- def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs):
565
+ def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
631
566
  """Load a tokenizer layer from the preset."""
632
567
  raise NotImplementedError
633
568
 
@@ -658,7 +593,7 @@ class PresetLoader:
658
593
  return cls(**kwargs)
659
594
 
660
595
  def load_preprocessor(
661
- self, cls, config_name=PREPROCESSOR_CONFIG_FILE, **kwargs
596
+ self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs
662
597
  ):
663
598
  """Load a prepocessor layer from the preset.
664
599
 
@@ -675,25 +610,26 @@ class KerasPresetLoader(PresetLoader):
675
610
  return check_config_class(self.config)
676
611
 
677
612
  def load_backbone(self, cls, load_weights, **kwargs):
678
- backbone = load_serialized_object(self.config, **kwargs)
613
+ backbone = self._load_serialized_object(self.config, **kwargs)
679
614
  if load_weights:
680
615
  jax_memory_cleanup(backbone)
681
616
  backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
682
617
  return backbone
683
618
 
684
- def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs):
685
- tokenizer_config = load_json(self.preset, config_name)
686
- tokenizer = load_serialized_object(tokenizer_config, **kwargs)
687
- tokenizer.load_preset_assets(self.preset)
619
+ def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
620
+ tokenizer_config = load_json(self.preset, config_file)
621
+ tokenizer = self._load_serialized_object(tokenizer_config, **kwargs)
622
+ if hasattr(tokenizer, "load_preset_assets"):
623
+ tokenizer.load_preset_assets(self.preset)
688
624
  return tokenizer
689
625
 
690
626
  def load_audio_converter(self, cls, **kwargs):
691
627
  converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE)
692
- return load_serialized_object(converter_config, **kwargs)
628
+ return self._load_serialized_object(converter_config, **kwargs)
693
629
 
694
630
  def load_image_converter(self, cls, **kwargs):
695
631
  converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
696
- return load_serialized_object(converter_config, **kwargs)
632
+ return self._load_serialized_object(converter_config, **kwargs)
697
633
 
698
634
  def load_task(self, cls, load_weights, load_task_weights, **kwargs):
699
635
  # If there is no `task.json` or it's for the wrong class delegate to the
@@ -708,8 +644,16 @@ class KerasPresetLoader(PresetLoader):
708
644
  cls, load_weights, load_task_weights, **kwargs
709
645
  )
710
646
  # We found a `task.json` with a complete config for our class.
711
- task = load_serialized_object(task_config, **kwargs)
712
- if task.preprocessor:
647
+ # Forward backbone args.
648
+ backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
649
+ if "backbone" in task_config["config"]:
650
+ backbone_config = task_config["config"]["backbone"]["config"]
651
+ backbone_config = {**backbone_config, **backbone_kwargs}
652
+ task_config["config"]["backbone"]["config"] = backbone_config
653
+ task = self._load_serialized_object(task_config, **kwargs)
654
+ if task.preprocessor and hasattr(
655
+ task.preprocessor, "load_preset_assets"
656
+ ):
713
657
  task.preprocessor.load_preset_assets(self.preset)
714
658
  if load_weights:
715
659
  has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE)
@@ -724,16 +668,124 @@ class KerasPresetLoader(PresetLoader):
724
668
  return task
725
669
 
726
670
  def load_preprocessor(
727
- self, cls, config_name=PREPROCESSOR_CONFIG_FILE, **kwargs
671
+ self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs
728
672
  ):
729
673
  # If there is no `preprocessing.json` or it's for the wrong class,
730
674
  # delegate to the super class loader.
731
- if not check_file_exists(self.preset, config_name):
675
+ if not check_file_exists(self.preset, config_file):
732
676
  return super().load_preprocessor(cls, **kwargs)
733
- preprocessor_json = load_json(self.preset, config_name)
677
+ preprocessor_json = load_json(self.preset, config_file)
734
678
  if not issubclass(check_config_class(preprocessor_json), cls):
735
679
  return super().load_preprocessor(cls, **kwargs)
736
680
  # We found a `preprocessing.json` with a complete config for our class.
737
- preprocessor = load_serialized_object(preprocessor_json, **kwargs)
738
- preprocessor.load_preset_assets(self.preset)
681
+ preprocessor = self._load_serialized_object(preprocessor_json, **kwargs)
682
+ if hasattr(preprocessor, "load_preset_assets"):
683
+ preprocessor.load_preset_assets(self.preset)
739
684
  return preprocessor
685
+
686
+ def _load_serialized_object(self, config, **kwargs):
687
+ # `dtype` in config might be a serialized `DTypePolicy` or
688
+ # `DTypePolicyMap`. Ensure that `dtype` is properly configured.
689
+ dtype = kwargs.pop("dtype", None)
690
+ config = set_dtype_in_config(config, dtype)
691
+
692
+ config["config"] = {**config["config"], **kwargs}
693
+ return keras.saving.deserialize_keras_object(config)
694
+
695
+
696
+ class KerasPresetSaver:
697
+ def __init__(self, preset_dir):
698
+ os.makedirs(preset_dir, exist_ok=True)
699
+ self.preset_dir = preset_dir
700
+
701
+ def save_backbone(self, backbone):
702
+ self._save_serialized_object(backbone, config_file=CONFIG_FILE)
703
+ backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
704
+ backbone.save_weights(backbone_weight_path)
705
+ self._save_metadata(backbone)
706
+
707
+ def save_tokenizer(self, tokenizer):
708
+ config_file = TOKENIZER_CONFIG_FILE
709
+ if hasattr(tokenizer, "config_file"):
710
+ config_file = tokenizer.config_file
711
+ self._save_serialized_object(tokenizer, config_file)
712
+ # Save assets.
713
+ subdir = config_file.split(".")[0]
714
+ asset_dir = os.path.join(self.preset_dir, ASSET_DIR, subdir)
715
+ os.makedirs(asset_dir, exist_ok=True)
716
+ tokenizer.save_assets(asset_dir)
717
+
718
+ def save_audio_converter(self, converter):
719
+ self._save_serialized_object(converter, AUDIO_CONVERTER_CONFIG_FILE)
720
+
721
+ def save_image_converter(self, converter):
722
+ self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)
723
+
724
+ def save_task(self, task):
725
+ # Save task specific config and weights.
726
+ self._save_serialized_object(task, TASK_CONFIG_FILE)
727
+ if task.has_task_weights():
728
+ task_weight_path = os.path.join(self.preset_dir, TASK_WEIGHTS_FILE)
729
+ task.save_task_weights(task_weight_path)
730
+ # Save backbone.
731
+ if hasattr(task.backbone, "save_to_preset"):
732
+ task.backbone.save_to_preset(self.preset_dir)
733
+ else:
734
+ # Allow saving a `keras.Model` that is not a backbone subclass.
735
+ self.save_backbone(task.backbone)
736
+ # Save preprocessor.
737
+ if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
738
+ task.preprocessor.save_to_preset(self.preset_dir)
739
+ else:
740
+ # Allow saving a `keras.Layer` that is not a preprocessor subclass.
741
+ self.save_preprocessor(task.preprocessor)
742
+
743
+ def save_preprocessor(self, preprocessor):
744
+ config_file = PREPROCESSOR_CONFIG_FILE
745
+ if hasattr(preprocessor, "config_file"):
746
+ config_file = preprocessor.config_file
747
+ self._save_serialized_object(preprocessor, config_file)
748
+ for layer in preprocessor._flatten_layers(include_self=False):
749
+ if hasattr(layer, "save_to_preset"):
750
+ layer.save_to_preset(self.preset_dir)
751
+
752
+ def _recursive_pop(self, config, key):
753
+ """Remove a key from a nested config object"""
754
+ config.pop(key, None)
755
+ for value in config.values():
756
+ if isinstance(value, dict):
757
+ self._recursive_pop(value, key)
758
+
759
+ def _save_serialized_object(self, layer, config_file):
760
+ config_path = os.path.join(self.preset_dir, config_file)
761
+ config = keras.saving.serialize_keras_object(layer)
762
+ config_to_skip = ["compile_config", "build_config"]
763
+ for key in config_to_skip:
764
+ self._recursive_pop(config, key)
765
+ with open(config_path, "w") as config_file:
766
+ config_file.write(json.dumps(config, indent=4))
767
+
768
+ def _save_metadata(self, layer):
769
+ from keras_hub.src.models.task import Task
770
+ from keras_hub.src.version_utils import __version__ as keras_hub_version
771
+
772
+ # Find all tasks that are compatible with the backbone.
773
+ # E.g. for `BertBackbone` we would have `TextClassifier` and `MaskedLM`.
774
+ # For `ResNetBackbone` we would have `ImageClassifier`.
775
+ tasks = list_subclasses(Task)
776
+ tasks = filter(lambda x: x.backbone_cls is type(layer), tasks)
777
+ tasks = [task.__base__.__name__ for task in tasks]
778
+ # Keep task list alphabetical.
779
+ tasks = sorted(tasks)
780
+
781
+ keras_version = keras.version() if hasattr(keras, "version") else None
782
+ metadata = {
783
+ "keras_version": keras_version,
784
+ "keras_hub_version": keras_hub_version,
785
+ "parameter_count": layer.count_params(),
786
+ "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
787
+ "tasks": tasks,
788
+ }
789
+ metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
790
+ with open(metadata_path, "w") as metadata_file:
791
+ metadata_file.write(json.dumps(metadata, indent=4))
@@ -293,10 +293,10 @@ def any_equal(inputs, values, padding_mask):
293
293
 
294
294
  Args:
295
295
  inputs: Input tensor.
296
- values: List or iterable of tensors shaped like `inputs` or broadcastable
297
- by bit operators.
298
- padding_mask: Tensor with shape compatible with inputs that will condition
299
- output.
296
+ values: List or iterable of tensors shaped like `inputs` or
297
+ broadcastable by bit operators.
298
+ padding_mask: Tensor with shape compatible with inputs that will
299
+ condition output.
300
300
 
301
301
  Returns:
302
302
  A tensor with `inputs` shape where each position is True if it contains
@@ -59,9 +59,11 @@ def convert_weights(backbone, loader, timm_config):
59
59
  num_stacks = len(backbone.stackwise_num_repeats)
60
60
  for stack_index in range(num_stacks):
61
61
  for block_idx in range(backbone.stackwise_num_repeats[stack_index]):
62
- keras_name = f"stack{stack_index+1}_block{block_idx+1}"
62
+ keras_name = f"stack{stack_index + 1}_block{block_idx + 1}"
63
63
  hf_name = (
64
- f"features.denseblock{stack_index+1}.denselayer{block_idx+1}"
64
+ "features."
65
+ f"denseblock{stack_index + 1}"
66
+ f".denselayer{block_idx + 1}"
65
67
  )
66
68
  port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.norm1")
67
69
  port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
@@ -69,8 +71,8 @@ def convert_weights(backbone, loader, timm_config):
69
71
  port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")
70
72
 
71
73
  for stack_index in range(num_stacks - 1):
72
- keras_transition_name = f"transition{stack_index+1}"
73
- hf_transition_name = f"features.transition{stack_index+1}"
74
+ keras_transition_name = f"transition{stack_index + 1}"
75
+ hf_transition_name = f"features.transition{stack_index + 1}"
74
76
  port_batch_normalization(
75
77
  f"{keras_transition_name}_bn", f"{hf_transition_name}.norm"
76
78
  )