keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,621 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import collections
|
16
|
+
import datetime
|
17
|
+
import inspect
|
18
|
+
import json
|
19
|
+
import os
|
20
|
+
import re
|
21
|
+
|
22
|
+
import keras
|
23
|
+
from absl import logging
|
24
|
+
from packaging.version import parse
|
25
|
+
|
26
|
+
from keras_hub.src.api_export import keras_hub_export
|
27
|
+
from keras_hub.src.utils.keras_utils import print_msg
|
28
|
+
|
29
|
+
try:
|
30
|
+
import tensorflow as tf
|
31
|
+
except ImportError:
|
32
|
+
raise ImportError(
|
33
|
+
"To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
|
34
|
+
"The TensorFlow package is required for data preprocessing with any backend."
|
35
|
+
)
|
36
|
+
|
37
|
+
try:
|
38
|
+
import kagglehub
|
39
|
+
from kagglehub.exceptions import KaggleApiHTTPError
|
40
|
+
except ImportError:
|
41
|
+
kagglehub = None
|
42
|
+
|
43
|
+
try:
|
44
|
+
import huggingface_hub
|
45
|
+
from huggingface_hub.utils import EntryNotFoundError
|
46
|
+
from huggingface_hub.utils import HFValidationError
|
47
|
+
except ImportError:
|
48
|
+
huggingface_hub = None
|
49
|
+
|
50
|
+
# URI schemes understood by `get_file`, and the "<scheme>://" prefixes that
# introduce them in a preset identifier.
KAGGLE_SCHEME = "kaggle"
GS_SCHEME = "gs"
HF_SCHEME = "hf"

KAGGLE_PREFIX = f"{KAGGLE_SCHEME}://"
GS_PREFIX = f"{GS_SCHEME}://"
HF_PREFIX = f"{HF_SCHEME}://"

# Directory holding tokenizer assets, relative to the preset root.
TOKENIZER_ASSET_DIR = "assets/tokenizer"

# Config file names.
CONFIG_FILE = "config.json"
TOKENIZER_CONFIG_FILE = "tokenizer.json"
TASK_CONFIG_FILE = "task.json"
PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
METADATA_FILE = "metadata.json"

# Weight file names.
MODEL_WEIGHTS_FILE = "model.weights.h5"
TASK_WEIGHTS_FILE = "task.weights.h5"

# HuggingFace filenames.
README_FILE = "README.md"
HF_CONFIG_FILE = "config.json"
HF_TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
SAFETENSOR_CONFIG_FILE = "model.safetensors.index.json"
SAFETENSOR_FILE = "model.safetensors"

# Global state for the built-in preset registry, filled by `register_presets`.
# Preset name -> preset config dict (`get_file` reads its "kaggle_handle").
BUILTIN_PRESETS = {}
# Class -> {preset name -> preset config dict}.
BUILTIN_PRESETS_FOR_CLASS = collections.defaultdict(dict)
|
81
|
+
|
82
|
+
|
83
|
+
def register_presets(presets, classes):
    """Register built-in presets for a set of classes.

    Every entry of `presets` is recorded in the global `BUILTIN_PRESETS`
    table and, for each class in `classes`, in that class's slot of
    `BUILTIN_PRESETS_FOR_CLASS`.

    Note that this is intended only for models and presets shipped in the
    library itself.

    Args:
        presets: Mapping of preset name to preset config dict.
        classes: Classes that should expose these presets.
    """
    for name, preset_config in presets.items():
        BUILTIN_PRESETS[name] = preset_config
        for registered_cls in classes:
            BUILTIN_PRESETS_FOR_CLASS[registered_cls][name] = preset_config
|
93
|
+
|
94
|
+
|
95
|
+
def list_presets(cls):
    """Find all registered built-in presets for a class.

    Args:
        cls: The class to look up.

    Returns:
        A new dict mapping preset name to preset config. It is empty when
        no presets are registered for `cls`.
    """
    # Use `.get` rather than indexing: `BUILTIN_PRESETS_FOR_CLASS` is a
    # `defaultdict`, so indexing would insert an empty entry for every
    # class that is merely queried.
    return dict(BUILTIN_PRESETS_FOR_CLASS.get(cls, {}))
|
98
|
+
|
99
|
+
|
100
|
+
def list_subclasses(cls):
    """Find all registered subclasses of a class.

    Candidates come from `keras.saving.get_custom_objects()`. Only classes
    that subclass `cls` are kept, and `cls` itself is excluded.
    """
    return [
        candidate
        for candidate in keras.saving.get_custom_objects().values()
        if inspect.isclass(candidate)
        and candidate != cls
        and issubclass(candidate, cls)
    ]
|
108
|
+
|
109
|
+
|
110
|
+
def get_file(preset, path):
    """Download a preset file if necessary and return its local path.

    Args:
        preset: The preset identifier. One of:
            - the name of a built-in preset (looked up in `BUILTIN_PRESETS`
              and replaced by its `kaggle_handle`);
            - a `kaggle://` handle;
            - a URL whose scheme is registered with `tf.io.gfile`
              (e.g. `gs://`);
            - an `hf://` handle;
            - a path to an existing local directory.
        path: The file path relative to the preset root.

    Returns:
        A local path to the requested file.

    Raises:
        ValueError: If `preset` is not a string or not a recognized
            identifier, if a Kaggle/Hugging Face handle is malformed, or if
            a Kaggle download fails for a reason other than a missing file.
        ImportError: If the package required for the preset's scheme
            (`kagglehub` or `huggingface_hub`) is not installed.
        FileNotFoundError: If `path` does not exist in the preset.
    """
    # TODO: Add tests for FileNotFound exceptions.
    if not isinstance(preset, str):
        raise ValueError(
            f"A preset identifier must be a string. Received: preset={preset}"
        )
    if preset in BUILTIN_PRESETS:
        preset = BUILTIN_PRESETS[preset]["kaggle_handle"]

    scheme = None
    if "://" in preset:
        scheme = preset.split("://")[0].lower()

    not_found_message = (
        f"`{path}` doesn't exist in preset directory `{preset}`."
    )

    if scheme == KAGGLE_SCHEME:
        if kagglehub is None:
            raise ImportError(
                "`from_preset()` requires the `kagglehub` package. "
                "Please install with `pip install kagglehub`."
            )
        # `scheme` is matched case-insensitively, so strip everything up to
        # and including "://" instead of a fixed lowercase prefix.
        kaggle_handle = preset.split("://", 1)[1]
        num_segments = len(kaggle_handle.split("/"))
        if num_segments not in (4, 5):
            raise ValueError(
                "Unexpected Kaggle preset. Kaggle model handles should have "
                "the form kaggle://{org}/{model}/keras/{variant}[/{version}]. "
                "For example, 'kaggle://username/bert/keras/bert_base_en' or "
                "'kaggle://username/bert/keras/bert_base_en/1' (to specify a "
                f"version). Received: preset={preset}"
            )
        try:
            return kagglehub.model_download(kaggle_handle, path)
        except KaggleApiHTTPError as e:
            message = str(e)
            # `str.find` returns -1 (truthy) when the text is absent, so it
            # cannot be used as a boolean test; use membership instead.
            # 403 is the status the original code treated as "missing file";
            # 404 is included as the conventional "not found" status.
            if "403 Client Error" in message or "404 Client Error" in message:
                raise FileNotFoundError(not_found_message) from e
            raise ValueError(message) from e
        except ValueError as e:
            if "is not present in the model files" in str(e):
                raise FileNotFoundError(not_found_message) from e
            raise

    elif scheme in tf.io.gfile.get_registered_schemes():
        url = os.path.join(preset, path)
        subdir = preset.replace("://", "_").replace("-", "_").replace("/", "_")
        filename = os.path.basename(path)
        subdir = os.path.join(subdir, os.path.dirname(path))
        try:
            return copy_gfile_to_cache(
                filename,
                url,
                cache_subdir=os.path.join("models", subdir),
            )
        except (tf.errors.PermissionDeniedError, tf.errors.NotFoundError) as e:
            raise FileNotFoundError(not_found_message) from e

    elif scheme == HF_SCHEME:
        if huggingface_hub is None:
            raise ImportError(
                f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
                "Please install with `pip install huggingface_hub`."
            )
        # Case-insensitive strip, matching the scheme detection above.
        hf_handle = preset.split("://", 1)[1]
        try:
            return huggingface_hub.hf_hub_download(
                repo_id=hf_handle, filename=path
            )
        except HFValidationError as e:
            raise ValueError(
                "Unexpected Hugging Face preset. Hugging Face model handles "
                "should have the form 'hf://{org}/{model}'. For example, "
                f"'hf://username/bert_base_en'. Received: preset={preset}."
            ) from e
        except EntryNotFoundError as e:
            # `EntryNotFoundError` is, by definition, a missing file in the
            # repo, so always report it as such.
            raise FileNotFoundError(not_found_message) from e

    elif os.path.exists(preset):
        # Assume a local filepath.
        local_path = os.path.join(preset, path)
        if not os.path.exists(local_path):
            raise FileNotFoundError(not_found_message)
        return local_path

    else:
        raise ValueError(
            "Unknown preset identifier. A preset must be one of:\n"
            "1) a built-in preset identifier like `'bert_base_en'`\n"
            "2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n"
            "3) a Hugging Face handle like `'hf://username/bert_base_en'`\n"
            "4) a path to a local preset directory like `'./bert_base_en'`\n"
            "Use `print(cls.presets.keys())` to view all built-in presets for "
            "API symbol `cls`.\n"
            f"Received: preset='{preset}'"
        )
|
218
|
+
|
219
|
+
|
220
|
+
def copy_gfile_to_cache(filename, url, cache_subdir):
    """Copy a file from a `tf.io.gfile` URL into the local Keras cache.

    Much of this is adapted from `get_file` of Keras core.

    The cache root is `$KERAS_HOME` if set, otherwise `~/.keras`. It falls
    back to `/tmp/.keras` when the chosen root is not writable.

    Args:
        filename: Name of the cached file inside `cache_subdir`.
        url: Source URL, readable by `tf.io.gfile`.
        cache_subdir: Subdirectory of the cache root to store the file in.

    Returns:
        The local path of the cached file. If it already exists, nothing is
        downloaded.
    """
    if "KERAS_HOME" in os.environ:
        cachedir_base = os.path.expanduser(os.environ["KERAS_HOME"])
    else:
        cachedir_base = os.path.expanduser(os.path.join("~", ".keras"))
    if not os.access(cachedir_base, os.W_OK):
        cachedir_base = os.path.join("/tmp", ".keras")
    cachedir = os.path.join(cachedir_base, cache_subdir)
    os.makedirs(cachedir, exist_ok=True)

    fpath = os.path.join(cachedir, filename)
    if not os.path.exists(fpath):
        print_msg(f"Downloading data from {url}")
        # Copy to a temporary name and rename only on success. A failed or
        # interrupted copy then never leaves a partial file at `fpath` that
        # later calls would mistake for a valid cache entry.
        partial_path = fpath + ".incomplete"
        try:
            tf.io.gfile.copy(url, partial_path, overwrite=True)
            os.replace(partial_path, fpath)
        except BaseException:
            # gfile.copy may leave an empty/partial file behind on error, or
            # may fail before creating one at all. Only remove it if it
            # exists, so the original error is not masked.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise

    return fpath
|
243
|
+
|
244
|
+
|
245
|
+
def check_file_exists(preset, path):
    """Return whether `path` can be resolved inside `preset`.

    This delegates to `get_file`, so a remote file may be downloaded as a
    side effect. Only `FileNotFoundError` counts as "missing"; any other
    exception from `get_file` propagates.
    """
    try:
        get_file(preset, path)
    except FileNotFoundError:
        found = False
    else:
        found = True
    return found
|
251
|
+
|
252
|
+
|
253
|
+
def get_tokenizer(layer):
    """Get the tokenizer from any KerasHub model or layer.

    Lookup order:
    1. `layer` itself, if it is a `Tokenizer`.
    2. `layer.tokenizer`, if that attribute exists.
    3. `layer.preprocessor.tokenizer`.

    Returns `None` when none of these are available.
    """
    # Avoid circular import.
    from keras_hub.src.tokenizers.tokenizer import Tokenizer

    if isinstance(layer, Tokenizer):
        return layer
    if hasattr(layer, "tokenizer"):
        return layer.tokenizer
    # A missing or `None` preprocessor both resolve to `None` here.
    preprocessor = getattr(layer, "preprocessor", None)
    return getattr(preprocessor, "tokenizer", None)
|
265
|
+
|
266
|
+
|
267
|
+
def recursive_pop(config, key):
    """Remove `key` from a nested config object, in place.

    The key is removed from `config` and from every dict nested inside it,
    including dicts held in lists or tuples (Keras configs commonly nest
    sub-configs in lists). Values that are not dicts, lists, or tuples are
    left untouched.

    Args:
        config: A (possibly nested) config dict.
        key: The key to remove at every nesting level.
    """
    if isinstance(config, dict):
        config.pop(key, None)
        children = config.values()
    elif isinstance(config, (list, tuple)):
        children = config
    else:
        return
    for child in children:
        recursive_pop(child, key)
|
273
|
+
|
274
|
+
|
275
|
+
def make_preset_dir(preset):
    """Create the preset directory, and any missing parents, if needed."""
    os.makedirs(name=preset, exist_ok=True)
|
277
|
+
|
278
|
+
|
279
|
+
def save_tokenizer_assets(tokenizer, preset):
    """Save `tokenizer`'s assets into the preset's tokenizer asset directory.

    The directory (`preset`/`TOKENIZER_ASSET_DIR`) is created if needed.
    Nothing happens when `tokenizer` is falsy (e.g. `None`).
    """
    if not tokenizer:
        return
    target_dir = os.path.join(preset, TOKENIZER_ASSET_DIR)
    os.makedirs(target_dir, exist_ok=True)
    tokenizer.save_assets(target_dir)
|
284
|
+
|
285
|
+
|
286
|
+
def save_serialized_object(
    layer,
    preset,
    config_file=CONFIG_FILE,
    config_to_skip=None,
):
    """Serialize `layer` with Keras and write the config as JSON into `preset`.

    Args:
        layer: The Keras object to serialize.
        preset: Local preset directory. It is created if it does not exist.
        config_file: Name of the JSON file to write inside `preset`.
        config_to_skip: Optional iterable of config keys to strip, at every
            dict nesting level, before saving. `"compile_config"` and
            `"build_config"` are always stripped.
    """
    make_preset_dir(preset)
    config_path = os.path.join(preset, config_file)
    config = keras.saving.serialize_keras_object(layer)
    # Build a fresh list. Extending `config_to_skip` in place would mutate
    # the caller's list (and, with a list default, a shared default that
    # grows on every call).
    keys_to_skip = list(config_to_skip or [])
    keys_to_skip += ["compile_config", "build_config"]
    for key in keys_to_skip:
        recursive_pop(config, key)
    with open(config_path, "w") as out_file:
        out_file.write(json.dumps(config, indent=4))
|
300
|
+
|
301
|
+
|
302
|
+
def save_metadata(layer, preset):
    """Write version and size information about `layer` into `preset`.

    The JSON file (`METADATA_FILE`) records the Keras version (or `None` if
    `keras.version` is unavailable), the KerasHub version, the layer's
    parameter count, and the local save time.
    """
    from keras_hub.src.version_utils import __version__ as keras_hub_version

    keras_version = None
    if hasattr(keras, "version"):
        keras_version = keras.version()
    metadata = dict(
        keras_version=keras_version,
        keras_hub_version=keras_hub_version,
        parameter_count=layer.count_params(),
        date_saved=datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
    )
    target = os.path.join(preset, METADATA_FILE)
    with open(target, "w") as handle:
        handle.write(json.dumps(metadata, indent=4))
|
315
|
+
|
316
|
+
|
317
|
+
def _validate_tokenizer(preset, allow_incomplete=False):
    """Check that `preset` contains a loadable tokenizer config and assets.

    The tokenizer config is deserialized, each entry in its `file_assets` is
    resolved through `get_file`, and the assets are loaded into the
    tokenizer.

    Args:
        preset: A preset identifier accepted by `get_file`.
        allow_incomplete: If True, a missing tokenizer config only logs a
            warning and returns, instead of raising.

    Raises:
        FileNotFoundError: If the tokenizer config is missing (and
            `allow_incomplete` is False), or if a tokenizer asset is missing.
        ValueError: If the tokenizer config is not valid JSON, or if no
            tokenizer can be obtained from the deserialized object.
    """
    if not check_file_exists(preset, TOKENIZER_CONFIG_FILE):
        if allow_incomplete:
            logging.warning(
                f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`."
            )
            return
        else:
            raise FileNotFoundError(
                f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`. "
                "To upload the model without a tokenizer, "
                "set `allow_incomplete=True`."
            )
    config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
    try:
        with open(config_path, encoding="utf-8") as config_file:
            config = json.load(config_file)
    except Exception as e:
        raise ValueError(
            f"Tokenizer config file `{config_path}` is an invalid json file. "
            f"Error message: {e}"
        ) from e
    layer = keras.saving.deserialize_keras_object(config)

    for asset in layer.file_assets:
        asset_path = get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset))
        if not os.path.exists(asset_path):
            tokenizer_asset_dir = os.path.dirname(asset_path)
            raise FileNotFoundError(
                f"Asset `{asset}` doesn't exist in the tokenizer asset directory"
                f" `{tokenizer_asset_dir}`."
            )
    # Assets are loaded relative to wherever `get_file` placed the config
    # (the local preset dir, or a download location for remote presets).
    config_dir = os.path.dirname(config_path)
    asset_dir = os.path.join(config_dir, TOKENIZER_ASSET_DIR)

    tokenizer = get_tokenizer(layer)
    if not tokenizer:
        raise ValueError(f"Model or layer `{layer}` is missing tokenizer.")
    tokenizer.load_assets(asset_dir)
|
356
|
+
|
357
|
+
|
358
|
+
def _validate_backbone(preset):
    """Sanity-check a local preset directory's backbone files.

    Checks, in order, that the backbone config exists, that it parses as
    JSON, and that the model weights file exists.

    Raises:
        FileNotFoundError: If the config or the weights file is absent.
        ValueError: If the config cannot be parsed as JSON.
    """
    config_path = os.path.join(preset, CONFIG_FILE)
    if not os.path.exists(config_path):
        raise FileNotFoundError(
            f"`{CONFIG_FILE}` is missing from the preset directory `{preset}`."
        )

    try:
        with open(config_path, encoding="utf-8") as handle:
            json.load(handle)
    except Exception as e:
        raise ValueError(
            f"Config file `{config_path}` is an invalid json file. "
            f"Error message: {e}"
        )

    if not os.path.exists(os.path.join(preset, MODEL_WEIGHTS_FILE)):
        raise FileNotFoundError(
            f"The weights file is missing from the preset directory `{preset}`."
        )
|
378
|
+
|
379
|
+
|
380
|
+
def get_snake_case(name):
    """Convert a CamelCase identifier to snake_case.

    For example, ``"GemmaBackbone"`` becomes ``"gemma_backbone"``.
    """
    # Pass 1 splits before every capitalized word ("GPT2Backbone" ->
    # "GPT2_Backbone"); pass 2 splits a lowercase letter or digit from a
    # following capital ("getHTTP" -> "get_HTTP").
    for pattern in ("(.)([A-Z][a-z]+)", "([a-z0-9])([A-Z])"):
        name = re.sub(pattern, r"\1_\2", name)
    return name.lower()
|
383
|
+
|
384
|
+
|
385
|
+
def create_model_card(preset):
    """Write an auto-generated `README_FILE` model card into a local preset.

    The card has a YAML front-matter header (library name and, for the
    `CausalLM` / `Classifier` task types, a Hugging Face `pipeline_tag`). It
    then has a short description linking to the keras.io docs for the model,
    and a bullet list of the backbone config entries.

    Args:
        preset: Path to the local preset directory. `CONFIG_FILE` is read from
            it, and `TASK_CONFIG_FILE` too when present.
    """
    model_card_path = os.path.join(preset, README_FILE)

    config = load_config(preset, CONFIG_FILE)
    # e.g. "GemmaBackbone" -> "Gemma". `removesuffix` strips only the trailing
    # "Backbone", whereas `replace` would drop every occurrence.
    model_name = config["class_name"].removesuffix("Backbone")

    task_type = None
    if check_file_exists(preset, TASK_CONFIG_FILE):
        task_config = load_config(preset, TASK_CONFIG_FILE)
        # e.g. "GemmaCausalLM" -> "CausalLM"; names that do not start with
        # the model name are kept whole.
        task_type = task_config["class_name"].removeprefix(model_name)

    # YAML front matter.
    pipeline_tags = {
        "CausalLM": "text-generation",
        "Classifier": "text-classification",
    }
    markdown_content = "---\n"
    markdown_content += "library_name: keras-hub\n"
    if task_type in pipeline_tags:
        markdown_content += f"pipeline_tag: {pipeline_tags[task_type]}\n"
    markdown_content += "---\n"

    model_link = (
        f"https://keras.io/api/keras_hub/models/{get_snake_case(model_name)}"
    )
    markdown_content += (
        f"This is a [`{model_name}` model]({model_link}) "
        "uploaded using the KerasHub library and can be used with JAX, "
        "TensorFlow, and PyTorch backends.\n"
    )
    if task_type:
        markdown_content += f"This model is related to a `{task_type}` task.\n"
    # Blank line so the config list is rendered as its own Markdown block
    # instead of being merged into the description paragraph.
    markdown_content += "\n"

    markdown_content += "Model config:\n"
    for key, value in config["config"].items():
        markdown_content += f"* **{key}:** {value}\n"
    markdown_content += "\n"
    markdown_content += (
        "This model card has been generated automatically and should be completed "
        "by the model author. See [Model Cards documentation]"
        "(https://huggingface.co/docs/hub/model-cards) for more information.\n"
    )

    # Config values may contain non-ASCII text. Write UTF-8 explicitly, like
    # the JSON readers in this module, instead of the platform default.
    with open(model_card_path, "w", encoding="utf-8") as md_file:
        md_file.write(markdown_content)
|
440
|
+
|
441
|
+
|
442
|
+
def delete_model_card(preset):
    """Remove the `README_FILE` model card from a local preset directory.

    A card that does not exist is not treated as an error; a warning is
    logged instead.

    Args:
        preset: Path to the local preset directory.
    """
    card_path = os.path.join(preset, README_FILE)
    try:
        os.remove(card_path)
    except FileNotFoundError:
        message = (
            f"There was an attempt to delete file `{card_path}` but this file "
            "doesn't exist."
        )
        logging.warning(message)
|
451
|
+
|
452
|
+
|
453
|
+
@keras_hub_export("keras_hub.upload_preset")
def upload_preset(
    uri,
    preset,
    allow_incomplete=False,
):
    """Upload a preset directory to a model hub.

    When uploading to Hugging Face and neither the remote repository nor the
    local preset has a `README_FILE`, a basic model card is generated for the
    upload. It is removed from the local directory afterwards.

    Args:
        uri: The URI identifying model to upload to.
            URIs with format
            `kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`
            will be uploaded to Kaggle Hub while URIs with format
            `hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging
            Face Hub.
        preset: The path to the local model preset directory.
        allow_incomplete: If True, allows the upload of presets without
            a tokenizer configuration. Otherwise, a tokenizer
            is required.

    Raises:
        FileNotFoundError: If `preset` does not exist, or if preset
            validation finds a required file missing.
        ValueError: If a preset config file is invalid, if `uri` is not a
            valid Hugging Face handle, or if `uri` has an unknown prefix.
        ImportError: If the hub client needed for `uri` is not installed,
            or if `kagglehub` is older than `0.2.4`.
    """

    # Check if preset directory exists.
    if not os.path.exists(preset):
        raise FileNotFoundError(f"The preset directory {preset} doesn't exist.")

    _validate_backbone(preset)
    _validate_tokenizer(preset, allow_incomplete)

    if uri.startswith(KAGGLE_PREFIX):
        if kagglehub is None:
            raise ImportError(
                "Uploading a model to Kaggle Hub requires the `kagglehub` package. "
                "Please install with `pip install kagglehub`."
            )
        if parse(kagglehub.__version__) < parse("0.2.4"):
            raise ImportError(
                "Uploading a model to Kaggle Hub requires the `kagglehub` package version `0.2.4` or higher. "
                "Please upgrade with `pip install --upgrade kagglehub`."
            )
        kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
        kagglehub.model_upload(kaggle_handle, preset)
    elif uri.startswith(HF_PREFIX):
        if huggingface_hub is None:
            raise ImportError(
                f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. "
                "Please install with `pip install huggingface_hub`."
            )
        hf_handle = uri.removeprefix(HF_PREFIX)
        try:
            repo_url = huggingface_hub.create_repo(
                repo_id=hf_handle, exist_ok=True
            )
        except HFValidationError as e:
            raise ValueError(
                "Unexpected Hugging Face URI. Hugging Face model handles "
                "should have the form 'hf://[{org}/]{model}'. For example, "
                "'hf://username/bert_base_en' or 'hf://bert_base_en' to "
                f"implicitly upload to your user account. Received: URI={uri}."
            ) from e
        remote_has_model_card = huggingface_hub.file_exists(
            repo_id=repo_url.repo_id, filename=README_FILE
        )
        local_has_model_card = os.path.exists(os.path.join(preset, README_FILE))
        # Only auto-generate a card when neither side has one. Otherwise a
        # user-written local README would be overwritten by
        # `create_model_card` and then deleted during cleanup below.
        generate_model_card = not (remote_has_model_card or local_has_model_card)
        if generate_model_card:
            create_model_card(preset)
        try:
            huggingface_hub.upload_folder(
                repo_id=repo_url.repo_id, folder_path=preset
            )
        finally:
            if generate_model_card:
                # Clean up the generated card in case the user later uploads
                # the same preset directory to Kaggle Hub as well.
                delete_model_card(preset)
    else:
        raise ValueError(
            "Unknown URI. A URI must be one of:\n"
            "1) a Kaggle Model handle like `'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n"
            "2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n"
            f"Received: uri='{uri}'."
        )
|
534
|
+
|
535
|
+
|
536
|
+
def load_config(preset, config_file=CONFIG_FILE):
    """Read a JSON file from a preset and return its parsed content.

    Args:
        preset: Preset identifier (a local path or a hub handle) that is
            resolved through `get_file`.
        config_file: Name of the JSON file inside the preset.

    Returns:
        The value decoded by `json.load`.
    """
    resolved_path = get_file(preset, config_file)
    with open(resolved_path, encoding="utf-8") as handle:
        return json.load(handle)
|
541
|
+
|
542
|
+
|
543
|
+
def check_format(preset):
    """Detect which checkpoint format a preset uses.

    Returns:
        For presets that contain `SAFETENSOR_FILE` or
        `SAFETENSOR_CONFIG_FILE`, `"timm"` is returned when the preset handle
        contains `"hf://timm"` or `HF_CONFIG_FILE` has an `"architecture"`
        key, and `"transformers"` is returned otherwise. All other presets
        return `"keras"`.

    Raises:
        FileNotFoundError: If a non-safetensors preset has no `METADATA_FILE`.
        ValueError: If `METADATA_FILE` has no `keras_version` entry.
    """
    has_safetensors = check_file_exists(
        preset, SAFETENSOR_FILE
    ) or check_file_exists(preset, SAFETENSOR_CONFIG_FILE)
    if has_safetensors:
        # Determine the format by parsing the config file.
        hf_config = load_config(preset, HF_CONFIG_FILE)
        is_timm = "hf://timm" in preset or "architecture" in hf_config
        return "timm" if is_timm else "transformers"

    if not check_file_exists(preset, METADATA_FILE):
        raise FileNotFoundError(
            f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`, "
            "or you do not have access to it. This file is required to load a Keras model "
            "preset. Please verify that the model you are trying to load is a Keras model."
        )
    metadata = load_config(preset, METADATA_FILE)
    if "keras_version" in metadata:
        return "keras"
    raise ValueError(
        f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. "
        "Please verify that the model you are trying to load is a Keras model."
    )
|
566
|
+
|
567
|
+
|
568
|
+
def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
    """Deserialize the Keras object described by a preset config file.

    Args:
        preset: Preset identifier passed to `load_config`.
        config_file: Name of the JSON config file inside the preset.
        **kwargs: Entries merged over the object's `"config"` before
            deserialization. A `dtype` entry is not merged directly. It is
            routed through `set_dtype_in_config`, so it can also update a
            serialized `DTypePolicyMap`.

    Returns:
        The object produced by `keras.saving.deserialize_keras_object`.
    """
    # `**kwargs` is always a dict, so no `None` fallback is needed.
    config = load_config(preset, config_file)

    # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
    # Ensure that `dtype` is properly configured.
    dtype = kwargs.pop("dtype", None)
    config = set_dtype_in_config(config, dtype)

    config["config"] = {**config["config"], **kwargs}
    return keras.saving.deserialize_keras_object(config)
|
579
|
+
|
580
|
+
|
581
|
+
def check_config_class(
    preset,
    config_file=CONFIG_FILE,
):
    """Return the class registered for the object serialized in a preset config.

    This supports validating that a preset is being loaded on the correct
    class. It looks up the config's `registered_name` with
    `keras.saving.get_registered_object`; the comparison against the target
    class is left to the caller.

    Args:
        preset: Preset identifier passed to `load_config`.
        config_file: Name of the JSON config file inside the preset.

    Returns:
        The registered object for `config["registered_name"]`.

    Raises:
        KeyError: If the config has no `registered_name` entry.
    """
    # Reuse the shared loader instead of duplicating the open/parse logic
    # (which also shadowed the `config_file` parameter with the file handle).
    config = load_config(preset, config_file)
    return keras.saving.get_registered_object(config["registered_name"])
|
590
|
+
|
591
|
+
|
592
|
+
def jax_memory_cleanup(layer):
    """On the JAX backend, delete the arrays currently backing `layer`'s weights.

    This is a no-op on any other backend. Weights without a `_value` (or
    whose `_value` is `None`) are skipped. Presumably this is called right
    before new weight values are assigned; confirm at call sites.
    """
    # For jax, delete all previous allocated memory to avoid temporarily
    # duplicating variable allocations. torch and tensorflow have stateful
    # variable types and do not need this fix.
    if keras.config.backend() != "jax":
        return
    for variable in layer.weights:
        backing_value = getattr(variable, "_value", None)
        if backing_value is not None:
            backing_value.delete()
|
600
|
+
|
601
|
+
|
602
|
+
def set_dtype_in_config(config, dtype=None):
    """Return a serialized config with `dtype` forwarded into it.

    Args:
        config: A serialized Keras object config: a dict holding a nested
            `"config"` dict.
        dtype: The dtype to apply. When `None`, nothing is changed.

    Returns:
        `config` itself when `dtype` is `None`. Otherwise, an updated deep
        copy that leaves the caller's dict untouched:
        - If `config["config"]` has no `"dtype"` entry, `dtype` is added.
        - If it holds a serialized `DTypePolicyMap`, `dtype` becomes the map's
          `default_policy` and the `source_name` of every policy in the map.
        - Any other existing `"dtype"` value is left as is.
    """
    if dtype is None:
        return config

    import copy

    # A shallow `config.copy()` would still share the nested `"config"` dict,
    # so the updates below would leak into the caller's object.
    config = copy.deepcopy(config)
    inner_config = config["config"]
    if "dtype" not in inner_config:
        # Forward `dtype` to the config.
        inner_config["dtype"] = dtype
    elif isinstance(inner_config["dtype"], dict) and "DTypePolicyMap" in (
        inner_config["dtype"].get("class_name", "")
    ):
        # If it is `DTypePolicyMap` in `config`, forward `dtype` as its default
        # policy and as the source dtype of every mapped policy.
        policy_map_config = inner_config["dtype"]["config"]
        policy_map_config["default_policy"] = dtype
        for policy in policy_map_config["policy_map"].values():
            policy["config"]["source_name"] = dtype
    return config
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Utilities with miscellaneous python extensions."""
|
15
|
+
|
16
|
+
|
17
|
+
class classproperty(property):
    """Define a class level property.

    The getter is called with the owning class rather than an instance. The
    value can therefore be read on the class itself (``SomeClass.attr``) as
    well as on its instances.
    """

    def __get__(self, instance, owner_cls=None):
        # The descriptor protocol allows `owner` to be omitted when `__get__`
        # is called directly; fall back to the instance's type in that case.
        if owner_cls is None:
            owner_cls = type(instance)
        return self.fget(owner_cls)
|