keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/utils/preset_utils.py

@@ -7,19 +7,10 @@ import re
 
 import keras
 from absl import logging
-from packaging.version import parse
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.utils.keras_utils import print_msg
 
-try:
-    import tensorflow as tf
-except ImportError:
-    raise ImportError(
-        "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
-        "The TensorFlow package is required for data preprocessing with any backend."
-    )
-
 try:
     import kagglehub
     from kagglehub.exceptions import KaggleApiHTTPError
@@ -172,26 +163,13 @@ def get_file(preset, path):
                 )
             else:
                 raise ValueError(message)
-
-    elif scheme in tf.io.gfile.get_registered_schemes():
-        url = os.path.join(preset, path)
-        subdir = preset.replace("://", "_").replace("-", "_").replace("/", "_")
-        filename = os.path.basename(path)
-        subdir = os.path.join(subdir, os.path.dirname(path))
-        try:
-            return copy_gfile_to_cache(
-                filename,
-                url,
-                cache_subdir=os.path.join("models", subdir),
-            )
-        except (tf.errors.PermissionDeniedError, tf.errors.NotFoundError) as e:
-            raise FileNotFoundError(
-                f"`{path}` doesn't exist in preset directory `{preset}`.",
-            ) from e
+    elif scheme in tf_registered_schemes():
+        return tf_copy_gfile_to_cache(preset, path)
     elif scheme == HF_SCHEME:
         if huggingface_hub is None:
             raise ImportError(
-
+                "`from_preset()` requires the `huggingface_hub` package to "
+                "load from '{preset}'. "
                 "Please install with `pip install huggingface_hub`."
             )
         hf_handle = preset.removeprefix(HF_SCHEME + "://")
@@ -225,7 +203,8 @@ def get_file(preset, path):
         raise ValueError(
             "Unknown preset identifier. A preset must be a one of:\n"
             "1) a built-in preset identifier like `'bert_base_en'`\n"
-            "2) a Kaggle Models handle like
+            "2) a Kaggle Models handle like "
+            "`'kaggle://keras/bert/keras/bert_base_en'`\n"
             "3) a Hugging Face handle like `'hf://username/bert_base_en'`\n"
             "4) a path to a local preset directory like `'./bert_base_en`\n"
             "Use `print(cls.presets.keys())` to view all built-in presets for "
@@ -234,29 +213,48 @@ def get_file(preset, path):
         )
 
 
-def
+def tf_registered_schemes():
+    try:
+        import tensorflow as tf
+
+        return tf.io.gfile.get_registered_schemes()
+    except ImportError:
+        return []
+
+
+def tf_copy_gfile_to_cache(preset, path):
     """Much of this is adapted from get_file of keras core."""
     if "KERAS_HOME" in os.environ:
-
+        base_dir = os.environ.get("KERAS_HOME")
     else:
-
-        if not os.access(
-
-
-    os.
-
-
-
+        base_dir = os.path.expanduser(os.path.join("~", ".keras"))
+        if not os.access(base_dir, os.W_OK):
+            base_dir = os.path.join("/tmp", ".keras")
+
+    url = os.path.join(preset, path)
+    model_dir = preset.replace("://", "_").replace("-", "_").replace("/", "_")
+    local_path = os.path.join(base_dir, "models", model_dir, path)
+
+    if not os.path.exists(local_path):
         print_msg(f"Downloading data from {url}")
         try:
-            tf
+            import tensorflow as tf
+
+            os.make_dirs(os.path.dirname(local_path), exist_ok=True)
+            tf.io.gfile.copy(url, local_path)
         except Exception as e:
             # gfile.copy will leave an empty file after an error.
             # Work around this bug.
-            os.remove(
+            os.remove(local_path)
+            if isinstance(
+                e, tf.errors.PermissionDeniedError, tf.errors.NotFoundError
+            ):
+                raise FileNotFoundError(
+                    f"`{path}` doesn't exist in preset directory `{preset}`.",
+                ) from e
             raise e
 
-    return
+    return local_path
 
 
 def check_file_exists(preset, path):
@@ -267,64 +265,6 @@ def check_file_exists(preset, path):
     return True
 
 
-def get_tokenizer(layer):
-    """Get the tokenizer from any KerasHub model or layer."""
-    # Avoid circular import.
-    from keras_hub.src.tokenizers.tokenizer import Tokenizer
-
-    if isinstance(layer, Tokenizer):
-        return layer
-    if hasattr(layer, "tokenizer"):
-        return layer.tokenizer
-    if hasattr(layer, "preprocessor"):
-        return getattr(layer.preprocessor, "tokenizer", None)
-    return None
-
-
-def recursive_pop(config, key):
-    """Remove a key from a nested config object"""
-    config.pop(key, None)
-    for value in config.values():
-        if isinstance(value, dict):
-            recursive_pop(value, key)
-
-
-# TODO: refactor saving routines into a PresetSaver class?
-def make_preset_dir(preset):
-    os.makedirs(preset, exist_ok=True)
-
-
-def save_serialized_object(
-    layer,
-    preset,
-    config_file=CONFIG_FILE,
-    config_to_skip=[],
-):
-    make_preset_dir(preset)
-    config_path = os.path.join(preset, config_file)
-    config = keras.saving.serialize_keras_object(layer)
-    config_to_skip += ["compile_config", "build_config"]
-    for c in config_to_skip:
-        recursive_pop(config, c)
-    with open(config_path, "w") as config_file:
-        config_file.write(json.dumps(config, indent=4))
-
-
-def save_metadata(layer, preset):
-    from keras_hub.src.version_utils import __version__ as keras_hub_version
-
-    keras_version = keras.version() if hasattr(keras, "version") else None
-    metadata = {
-        "keras_version": keras_version,
-        "keras_hub_version": keras_hub_version,
-        "parameter_count": layer.count_params(),
-        "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
-    }
-    metadata_path = os.path.join(preset, METADATA_FILE)
-    with open(metadata_path, "w") as metadata_file:
-        metadata_file.write(json.dumps(metadata, indent=4))
-
-
 def _validate_backbone(preset):
     config_path = os.path.join(preset, CONFIG_FILE)
     if not os.path.exists(config_path):
@@ -400,8 +340,8 @@ def create_model_card(preset):
         markdown_content += f"* **{k}:** {v}\n"
     markdown_content += "\n"
     markdown_content += (
-        "This model card has been generated automatically and should be
-        "by the model author. See [Model Cards documentation]"
+        "This model card has been generated automatically and should be "
+        "completed by the model author. See [Model Cards documentation]"
         "(https://huggingface.co/docs/hub/model-cards) for more information.\n"
     )
 
@@ -446,20 +386,16 @@ def upload_preset(
     if uri.startswith(KAGGLE_PREFIX):
         if kagglehub is None:
             raise ImportError(
-                "Uploading a model to Kaggle Hub requires the `kagglehub`
-                "Please install with `pip install kagglehub`."
-            )
-        if parse(kagglehub.__version__) < parse("0.2.4"):
-            raise ImportError(
-                "Uploading a model to Kaggle Hub requires the `kagglehub` package version `0.2.4` or higher. "
-                "Please upgrade with `pip install --upgrade kagglehub`."
+                "Uploading a model to Kaggle Hub requires the `kagglehub` "
+                "package. Please install with `pip install kagglehub`."
             )
         kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
         kagglehub.model_upload(kaggle_handle, preset)
     elif uri.startswith(HF_PREFIX):
         if huggingface_hub is None:
             raise ImportError(
-                f"`upload_preset()` requires the `huggingface_hub` package
+                f"`upload_preset()` requires the `huggingface_hub` package "
+                f"to upload to '{uri}'. "
                 "Please install with `pip install huggingface_hub`."
             )
         hf_handle = uri.removeprefix(HF_PREFIX)
@@ -471,14 +407,15 @@ def upload_preset(
             raise ValueError(
                 "Unexpected Hugging Face URI. Hugging Face model handles "
                 "should have the form 'hf://[{org}/]{model}'. For example, "
-                "'hf://username/bert_base_en' or 'hf://bert_case_en' to
-                f"upload to your user account. Received: URI={uri}."
+                "'hf://username/bert_base_en' or 'hf://bert_case_en' to "
+                f"implicitly upload to your user account. Received: URI={uri}."
             ) from e
         has_model_card = huggingface_hub.file_exists(
             repo_id=repo_url.repo_id, filename=README_FILE
         )
         if not has_model_card:
-            # Remote repo doesn't have a model card so a basic model card is
+            # Remote repo doesn't have a model card so a basic model card is
+            # automatically generated.
             create_model_card(preset)
         try:
             huggingface_hub.upload_folder(
@@ -486,13 +423,14 @@ def upload_preset(
             )
         finally:
             if not has_model_card:
-                # Clean up the preset directory in case user attempts to upload
-                # preset directory into Kaggle hub as well.
+                # Clean up the preset directory in case user attempts to upload
+                # the preset directory into Kaggle hub as well.
                 delete_model_card(preset)
     else:
         raise ValueError(
             "Unknown URI. An URI must be a one of:\n"
-            "1) a Kaggle Model handle like
+            "1) a Kaggle Model handle like "
+            "`'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n"
             "2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n"
             f"Received: uri='{uri}'."
         )
@@ -505,19 +443,11 @@ def load_json(preset, config_file=CONFIG_FILE):
     return config
 
 
-def load_serialized_object(config, **kwargs):
-    # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
-    # Ensure that `dtype` is properly configured.
-    dtype = kwargs.pop("dtype", None)
-    config = set_dtype_in_config(config, dtype)
-
-    config["config"] = {**config["config"], **kwargs}
-    return keras.saving.deserialize_keras_object(config)
-
-
 def check_config_class(config):
     """Validate a preset is being loaded on the correct class."""
     registered_name = config["registered_name"]
+    if registered_name in ("Functional", "Sequential"):
+        return keras.Model
     cls = keras.saving.get_registered_object(registered_name)
     if cls is None:
         raise ValueError(
@@ -600,6 +530,13 @@ def get_preset_loader(preset):
         )
 
 
+def get_preset_saver(preset):
+    # Unlike loading, we only support one form of saving; Keras serialized
+    # configs and saved weights. We keep the rough API structure as loading
+    # just for simplicity.
+    return KerasPresetSaver(preset)
+
+
 class PresetLoader:
     def __init__(self, preset, config):
         self.config = config
@@ -612,10 +549,8 @@ class PresetLoader:
         backbone_kwargs["dtype"] = kwargs.pop("dtype", None)
 
         # Forward `height` and `width` to backbone when using `TextToImage`.
-        if "
-            backbone_kwargs["
-        if "width" in kwargs:
-            backbone_kwargs["width"] = kwargs.pop("width", None)
+        if "image_shape" in kwargs:
+            backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None)
 
         return backbone_kwargs, kwargs
 
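The replacement above means a single `image_shape` argument passed to a task-level `from_preset()` call is now forwarded into the backbone config, instead of separate `height` and `width` arguments. A minimal sketch of that usage; the preset name and shape are illustrative and not taken from this diff:

```python
import keras_hub

# Illustrative only: `image_shape` is popped from the task kwargs by
# PresetLoader.get_backbone_kwargs() and merged into the backbone config.
text_to_image = keras_hub.models.TextToImage.from_preset(
    "stable_diffusion_3_medium",  # assumed preset name
    image_shape=(512, 512, 3),
)
```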
@@ -627,7 +562,7 @@ class PresetLoader:
         """Load the backbone model from the preset."""
         raise NotImplementedError
 
-    def load_tokenizer(self, cls,
+    def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
         """Load a tokenizer layer from the preset."""
         raise NotImplementedError
 
@@ -658,7 +593,7 @@ class PresetLoader:
         return cls(**kwargs)
 
     def load_preprocessor(
-        self, cls,
+        self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs
     ):
         """Load a prepocessor layer from the preset.
 
@@ -675,25 +610,26 @@ class KerasPresetLoader(PresetLoader):
         return check_config_class(self.config)
 
     def load_backbone(self, cls, load_weights, **kwargs):
-        backbone =
+        backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
             backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
         return backbone
 
-    def load_tokenizer(self, cls,
-        tokenizer_config = load_json(self.preset,
-        tokenizer =
-        tokenizer
+    def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
+        tokenizer_config = load_json(self.preset, config_file)
+        tokenizer = self._load_serialized_object(tokenizer_config, **kwargs)
+        if hasattr(tokenizer, "load_preset_assets"):
+            tokenizer.load_preset_assets(self.preset)
         return tokenizer
 
     def load_audio_converter(self, cls, **kwargs):
         converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE)
-        return
+        return self._load_serialized_object(converter_config, **kwargs)
 
     def load_image_converter(self, cls, **kwargs):
         converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
-        return
+        return self._load_serialized_object(converter_config, **kwargs)
 
     def load_task(self, cls, load_weights, load_task_weights, **kwargs):
         # If there is no `task.json` or it's for the wrong class delegate to the
@@ -708,8 +644,16 @@ class KerasPresetLoader(PresetLoader):
                 cls, load_weights, load_task_weights, **kwargs
             )
         # We found a `task.json` with a complete config for our class.
-
-
+        # Forward backbone args.
+        backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
+        if "backbone" in task_config["config"]:
+            backbone_config = task_config["config"]["backbone"]["config"]
+            backbone_config = {**backbone_config, **backbone_kwargs}
+            task_config["config"]["backbone"]["config"] = backbone_config
+        task = self._load_serialized_object(task_config, **kwargs)
+        if task.preprocessor and hasattr(
+            task.preprocessor, "load_preset_assets"
+        ):
             task.preprocessor.load_preset_assets(self.preset)
         if load_weights:
             has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE)
@@ -724,16 +668,124 @@ class KerasPresetLoader(PresetLoader):
         return task
 
     def load_preprocessor(
-        self, cls,
+        self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs
     ):
         # If there is no `preprocessing.json` or it's for the wrong class,
         # delegate to the super class loader.
-        if not check_file_exists(self.preset,
+        if not check_file_exists(self.preset, config_file):
             return super().load_preprocessor(cls, **kwargs)
-        preprocessor_json = load_json(self.preset,
+        preprocessor_json = load_json(self.preset, config_file)
         if not issubclass(check_config_class(preprocessor_json), cls):
             return super().load_preprocessor(cls, **kwargs)
         # We found a `preprocessing.json` with a complete config for our class.
-        preprocessor =
-        preprocessor
+        preprocessor = self._load_serialized_object(preprocessor_json, **kwargs)
+        if hasattr(preprocessor, "load_preset_assets"):
+            preprocessor.load_preset_assets(self.preset)
         return preprocessor
+
+    def _load_serialized_object(self, config, **kwargs):
+        # `dtype` in config might be a serialized `DTypePolicy` or
+        # `DTypePolicyMap`. Ensure that `dtype` is properly configured.
+        dtype = kwargs.pop("dtype", None)
+        config = set_dtype_in_config(config, dtype)
+
+        config["config"] = {**config["config"], **kwargs}
+        return keras.saving.deserialize_keras_object(config)
+
+
+class KerasPresetSaver:
+    def __init__(self, preset_dir):
+        os.makedirs(preset_dir, exist_ok=True)
+        self.preset_dir = preset_dir
+
+    def save_backbone(self, backbone):
+        self._save_serialized_object(backbone, config_file=CONFIG_FILE)
+        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
+        backbone.save_weights(backbone_weight_path)
+        self._save_metadata(backbone)
+
+    def save_tokenizer(self, tokenizer):
+        config_file = TOKENIZER_CONFIG_FILE
+        if hasattr(tokenizer, "config_file"):
+            config_file = tokenizer.config_file
+        self._save_serialized_object(tokenizer, config_file)
+        # Save assets.
+        subdir = config_file.split(".")[0]
+        asset_dir = os.path.join(self.preset_dir, ASSET_DIR, subdir)
+        os.makedirs(asset_dir, exist_ok=True)
+        tokenizer.save_assets(asset_dir)
+
+    def save_audio_converter(self, converter):
+        self._save_serialized_object(converter, AUDIO_CONVERTER_CONFIG_FILE)
+
+    def save_image_converter(self, converter):
+        self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)
+
+    def save_task(self, task):
+        # Save task specific config and weights.
+        self._save_serialized_object(task, TASK_CONFIG_FILE)
+        if task.has_task_weights():
+            task_weight_path = os.path.join(self.preset_dir, TASK_WEIGHTS_FILE)
+            task.save_task_weights(task_weight_path)
+        # Save backbone.
+        if hasattr(task.backbone, "save_to_preset"):
+            task.backbone.save_to_preset(self.preset_dir)
+        else:
+            # Allow saving a `keras.Model` that is not a backbone subclass.
+            self.save_backbone(task.backbone)
+        # Save preprocessor.
+        if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
+            task.preprocessor.save_to_preset(self.preset_dir)
+        else:
+            # Allow saving a `keras.Layer` that is not a preprocessor subclass.
+            self.save_preprocessor(task.preprocessor)
+
+    def save_preprocessor(self, preprocessor):
+        config_file = PREPROCESSOR_CONFIG_FILE
+        if hasattr(preprocessor, "config_file"):
+            config_file = preprocessor.config_file
+        self._save_serialized_object(preprocessor, config_file)
+        for layer in preprocessor._flatten_layers(include_self=False):
+            if hasattr(layer, "save_to_preset"):
+                layer.save_to_preset(self.preset_dir)
+
+    def _recursive_pop(self, config, key):
+        """Remove a key from a nested config object"""
+        config.pop(key, None)
+        for value in config.values():
+            if isinstance(value, dict):
+                self._recursive_pop(value, key)
+
+    def _save_serialized_object(self, layer, config_file):
+        config_path = os.path.join(self.preset_dir, config_file)
+        config = keras.saving.serialize_keras_object(layer)
+        config_to_skip = ["compile_config", "build_config"]
+        for key in config_to_skip:
+            self._recursive_pop(config, key)
+        with open(config_path, "w") as config_file:
+            config_file.write(json.dumps(config, indent=4))
+
+    def _save_metadata(self, layer):
+        from keras_hub.src.models.task import Task
+        from keras_hub.src.version_utils import __version__ as keras_hub_version
+
+        # Find all tasks that are compatible with the backbone.
+        # E.g. for `BertBackbone` we would have `TextClassifier` and `MaskedLM`.
+        # For `ResNetBackbone` we would have `ImageClassifier`.
+        tasks = list_subclasses(Task)
+        tasks = filter(lambda x: x.backbone_cls is type(layer), tasks)
+        tasks = [task.__base__.__name__ for task in tasks]
+        # Keep task list alphabetical.
+        tasks = sorted(tasks)
+
+        keras_version = keras.version() if hasattr(keras, "version") else None
+        metadata = {
+            "keras_version": keras_version,
+            "keras_hub_version": keras_hub_version,
+            "parameter_count": layer.count_params(),
+            "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
+            "tasks": tasks,
+        }
+        metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
+        with open(metadata_path, "w") as metadata_file:
+            metadata_file.write(json.dumps(metadata, indent=4))
keras_hub/src/utils/tensor_utils.py

@@ -293,10 +293,10 @@ def any_equal(inputs, values, padding_mask):
 
     Args:
         inputs: Input tensor.
-        values: List or iterable of tensors shaped like `inputs` or
-            by bit operators.
-        padding_mask: Tensor with shape compatible with inputs that will
-            output.
+        values: List or iterable of tensors shaped like `inputs` or
+            broadcastable by bit operators.
+        padding_mask: Tensor with shape compatible with inputs that will
+            condition output.
 
     Returns:
         A tensor with `inputs` shape where each position is True if it contains
keras_hub/src/utils/timm/convert_densenet.py

@@ -59,9 +59,11 @@ def convert_weights(backbone, loader, timm_config):
     num_stacks = len(backbone.stackwise_num_repeats)
     for stack_index in range(num_stacks):
         for block_idx in range(backbone.stackwise_num_repeats[stack_index]):
-            keras_name = f"stack{stack_index+1}_block{block_idx+1}"
+            keras_name = f"stack{stack_index + 1}_block{block_idx + 1}"
             hf_name = (
-
+                "features."
+                f"denseblock{stack_index + 1}"
+                f".denselayer{block_idx + 1}"
             )
             port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.norm1")
             port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
|
|
69
71
|
port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")
|
70
72
|
|
71
73
|
for stack_index in range(num_stacks - 1):
|
72
|
-
keras_transition_name = f"transition{stack_index+1}"
|
73
|
-
hf_transition_name = f"features.transition{stack_index+1}"
|
74
|
+
keras_transition_name = f"transition{stack_index + 1}"
|
75
|
+
hf_transition_name = f"features.transition{stack_index + 1}"
|
74
76
|
port_batch_normalization(
|
75
77
|
f"{keras_transition_name}_bn", f"{hf_transition_name}.norm"
|
76
78
|
)
|