keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,206 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
try:
|
19
|
+
import tensorflow as tf
|
20
|
+
import tensorflow_text as tf_text
|
21
|
+
except ImportError:
|
22
|
+
tf = None
|
23
|
+
tf_text = None
|
24
|
+
|
25
|
+
|
26
|
+
def _decode_strings_to_utf8(inputs):
|
27
|
+
"""Recursively decodes to list of strings with 'utf-8' encoding."""
|
28
|
+
if isinstance(inputs, bytes):
|
29
|
+
# Handles the case when the input is a scalar string.
|
30
|
+
return inputs.decode("utf-8", errors="ignore")
|
31
|
+
else:
|
32
|
+
# Recursively iterate when input is a list.
|
33
|
+
return [_decode_strings_to_utf8(x) for x in inputs]
|
34
|
+
|
35
|
+
|
36
|
+
def tensor_to_list(inputs):
    """Convert a tensor into plain python values.

    Args:
        inputs: A `tf.Tensor`, a `tf.RaggedTensor`, or any other value that
            `tf.convert_to_tensor` accepts.

    Returns:
        For ragged tensors, the result of `to_list()`. For dense tensors,
        the result of `.numpy()`, converted with `.tolist()` unless the
        tensor is rank 0. String tensors additionally have their bytes
        decoded to `str` as UTF-8.
    """
    tensor = inputs
    if not isinstance(tensor, (tf.RaggedTensor, tf.Tensor)):
        tensor = tf.convert_to_tensor(tensor)

    if isinstance(tensor, tf.RaggedTensor):
        values = tensor.to_list()
    elif isinstance(tensor, tf.Tensor):
        values = tensor.numpy()
        # Rank-0 tensors keep the scalar returned by `.numpy()`.
        if tensor.shape.rank != 0:
            values = values.tolist()

    if tensor.dtype == tf.string:
        # TensorFlow hands strings back as `bytes`; callers want `str`.
        values = _decode_strings_to_utf8(values)
    return values
|
53
|
+
|
54
|
+
|
55
|
+
def convert_to_backend_tensor_or_python_list(x):
    """Convert `x` to a representation the active Keras backend can hold.

    This wraps `ops.convert_to_tensor` to account for the fact that torch and
    jax both lack native types for ragged and string data.

    Ragged tensors and anything whose `dtype` is `tf.string` are converted to
    simple pythonic types (nested lists) via `tensor_to_list`. Everything
    else becomes a backend tensor, keeping the standardized form of
    `x.dtype` when `x` has one and falling back to `"float32"` otherwise.
    """
    is_ragged = isinstance(x, tf.RaggedTensor)
    if is_ragged or getattr(x, "dtype", None) == tf.string:
        return tensor_to_list(x)

    target_dtype = keras.backend.standardize_dtype(
        getattr(x, "dtype", "float32")
    )
    return ops.convert_to_tensor(x, dtype=target_dtype)
|
70
|
+
|
71
|
+
|
72
|
+
def convert_to_ragged_batch(inputs):
    """Convert pythonic or numpy-like input to a 2-D `tf.RaggedTensor`.

    This is useful for text preprocessing layers which deal with already
    tokenized or split text.

    Args:
        inputs: A pythonic or numpy-like input to convert. This input should
            represent a possibly batched list of token sequences.

    Returns:
        An `(inputs, unbatched, rectangular)` tuple, where `inputs` is a
        2-D `tf.RaggedTensor`, `unbatched` is `True` if the inputs were
        originally rank 1, and `rectangular` is `True` if the inputs rows are
        all of equal lengths.

    Raises:
        ValueError: If `inputs` is not a list, tuple, tensor, ragged tensor
            or object exposing `__array__`, or if its rank is not 1 or 2.
    """
    # `tf.keras.layers.Layer` does a weird conversion in __call__, where a list
    # of lists of ints will become a list of list of scalar tensors. We could
    # clean this up if we no longer need to care about that case.
    if isinstance(inputs, (list, tuple)):
        if isinstance(inputs[0], (list, tuple)):
            # Rows may differ in length, so stack them as a ragged tensor
            # and remember whether they all happened to match.
            rectangular = len({len(row) for row in inputs}) == 1
            rows = [
                tf.convert_to_tensor(row, dtype_hint="int32") for row in inputs
            ]
            inputs = tf.ragged.stack(rows).with_row_splits_dtype("int64")
        else:
            inputs = tf.convert_to_tensor(inputs)
            rectangular = True
    elif isinstance(inputs, tf.Tensor):
        rectangular = True
    elif isinstance(inputs, tf.RaggedTensor):
        rectangular = False
    elif hasattr(inputs, "__array__"):
        inputs = tf.convert_to_tensor(ops.convert_to_numpy(inputs))
        rectangular = True
    else:
        raise ValueError(
            "Unknown tensor type. Tensor input can be passed as "
            "tensors, numpy arrays, or python lists. Received: "
            f"`type(inputs)={type(inputs)}`"
        )

    rank = inputs.shape.rank
    if rank < 1 or rank > 2:
        # Report the shape of `inputs`; the previous message formatted the
        # builtin `input`, which raised AttributeError instead of this error.
        raise ValueError(
            "Tokenized tensor input should be rank 1 (unbatched) or "
            f"rank 2 (batched). Received: `inputs.shape={inputs.shape}`"
        )

    unbatched = rank == 1
    # A single sequence is trivially rectangular.
    rectangular = rectangular or unbatched
    if unbatched:
        inputs = tf.expand_dims(inputs, 0)
    if isinstance(inputs, tf.Tensor):
        inputs = tf.RaggedTensor.from_tensor(inputs)
    return inputs, unbatched, rectangular
|
126
|
+
|
127
|
+
|
128
|
+
def truncate_at_token(inputs, token, mask):
    """Cut each row at the first unmasked occurrence of `token`.

    Positions where `mask` is true are ignored when searching for `token`.
    Rows whose computed cut index is 0 keep their full length.

    NOTE(review): a match at index 0 and "no match" both yield index 0
    here, so both keep the full row - confirm callers always mask index 0.

    Returns:
        A `tf.RaggedTensor` built from `inputs` with per-row lengths.
    """
    hits = (inputs == token) & (~mask)
    first_hit = tf.cast(tf.math.argmax(hits, -1), "int32")
    full_length = tf.shape(inputs)[-1]
    row_lengths = tf.where(first_hit == 0, full_length, first_hit)
    return tf.RaggedTensor.from_tensor(inputs, lengths=row_lengths)
|
134
|
+
|
135
|
+
|
136
|
+
def strip_to_ragged(token_ids, mask, ids_to_strip):
    """Drop masked-out and special tokens ahead of detokenization.

    Args:
        token_ids: Token ids; converted to an int32 numpy array.
        mask: Positions to keep; converted to a boolean numpy array.
        ids_to_strip: Iterable of token ids that are removed wherever they
            appear.

    Returns:
        The result of `tf.ragged.boolean_mask` over the surviving tokens.
    """
    ids = ops.convert_to_numpy(token_ids).astype("int32")
    keep = ops.convert_to_numpy(mask).astype("bool")
    for special_id in ids_to_strip:
        # Clear the keep flag everywhere this special id occurs.
        keep = keep & (ids != special_id)
    return tf.ragged.boolean_mask(ids, keep)
|
145
|
+
|
146
|
+
|
147
|
+
def assert_tf_libs_installed(symbol_name):
    """Raise if `tensorflow` or `tensorflow-text` failed to import.

    Args:
        symbol_name: Name of the symbol that needs the libraries; used as
            the prefix of the error message.

    Raises:
        ImportError: If either module-level `tf` or `tf_text` is `None`.
    """
    if tf is not None and tf_text is not None:
        return
    raise ImportError(
        f"{symbol_name} requires `tensorflow` and `tensorflow-text` for "
        "text processing. Run `pip install tensorflow-text` to install "
        "both packages or visit https://www.tensorflow.org/install\n\n"
        "If `tensorflow-text` is already installed, try importing it "
        "in a clean python session. Your installation may have errors.\n\n"
        "KerasHub uses `tf.data` and `tensorflow-text` to preprocess text "
        "on all Keras backends. If you are running on Jax or Torch, this "
        "installation does not need GPU support."
    )
|
159
|
+
|
160
|
+
|
161
|
+
def assert_tf_backend(symbol_name):
    """Raise unless Keras is configured with the `tensorflow` backend.

    Args:
        symbol_name: Name of the symbol that needs the TensorFlow backend;
            used as the prefix of the error message.

    Raises:
        RuntimeError: If `keras.config.backend()` is not `"tensorflow"`.
    """
    if keras.config.backend() == "tensorflow":
        return
    raise RuntimeError(
        f"{symbol_name} requires the `tensorflow` backend. "
        "Please set `KERAS_BACKEND=tensorflow` when running your program."
    )
|
167
|
+
|
168
|
+
|
169
|
+
def is_tensor_type(x):
    """Return `True` if `x` exposes an `__array__` attribute."""
    exposes_array = hasattr(x, "__array__")
    return exposes_array
|
171
|
+
|
172
|
+
|
173
|
+
def is_float_dtype(dtype):
    """Return `True` if the standardized form of `dtype` contains "float"."""
    dtype_name = keras.backend.standardize_dtype(dtype)
    return "float" in dtype_name
|
175
|
+
|
176
|
+
|
177
|
+
def is_int_dtype(dtype):
    """Return `True` if the standardized form of `dtype` contains "int".

    This is a substring test, so any standardized name containing "int"
    matches.
    """
    dtype_name = keras.backend.standardize_dtype(dtype)
    return "int" in dtype_name
|
179
|
+
|
180
|
+
|
181
|
+
def is_string_dtype(dtype):
    """Return `True` if the standardized form of `dtype` contains "string"."""
    dtype_name = keras.backend.standardize_dtype(dtype)
    return "string" in dtype_name
|
183
|
+
|
184
|
+
|
185
|
+
def any_equal(inputs, values, padding_mask):
    """Build a mask marking where `inputs` equals any entry of `values`.

    The combined equality mask is AND-ed with `padding_mask` before it is
    returned.

    Args:
        inputs: Input tensor.
        values: Non-empty sequence of values to compare against. It must
            support indexing and slicing; each entry is compared to
            `inputs` with `ops.equal`.
        padding_mask: Tensor with shape compatible with `inputs` that
            conditions the output.

    Returns:
        A tensor that is True where `inputs` matches at least one entry of
        `values` and `padding_mask` is also True.
    """
    first, remaining = values[0], values[1:]
    matched = ops.equal(inputs, first)
    for candidate in remaining:
        matched = ops.logical_or(matched, ops.equal(inputs, candidate))
    return ops.logical_and(matched, padding_mask)
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Convert timm models to KerasHub."""
|
15
|
+
|
16
|
+
from keras_hub.src.utils.timm.convert_resnet import load_resnet_backbone
|
17
|
+
|
18
|
+
|
19
|
+
def load_timm_backbone(cls, preset, load_weights, **kwargs):
    """Build a KerasHub backbone from a timm model config and weights.

    Only classes named `"ResNetBackbone"` are dispatched; the work is
    delegated to `load_resnet_backbone`.

    Args:
        cls (class): Keras model class.
        preset (str): Preset configuration name.
        load_weights (bool): Whether to load the weights.
        **kwargs: Forwarded unchanged to the architecture-specific loader.

    Returns:
        backbone: Initialized Keras model backbone.

    Raises:
        ValueError: If `cls` is `None`, or if no loader exists for `cls`.
    """
    if cls is None:
        raise ValueError("Backbone class is None")
    if cls.__name__ != "ResNetBackbone":
        raise ValueError(
            f"{cls} has not been ported from the Hugging Face format yet. "
            "Please check Hugging Face Hub for the Keras model. "
        )
    return load_resnet_backbone(cls, preset, load_weights, **kwargs)
|
@@ -0,0 +1,171 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import numpy as np
|
15
|
+
|
16
|
+
from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
|
17
|
+
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
|
18
|
+
from keras_hub.src.utils.preset_utils import load_config
|
19
|
+
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
20
|
+
|
21
|
+
|
22
|
+
def convert_backbone_config(timm_config):
    """Translate a timm ResNet config into backbone keyword arguments.

    Args:
        timm_config: Mapping with an `"architecture"` entry naming the timm
            model, e.g. `"resnet50"` or `"resnetv2_101"`.

    Returns:
        A dict with `stackwise_num_filters`, `stackwise_num_blocks`,
        `stackwise_num_strides`, `block_type` and `use_pre_activation`.
        `use_pre_activation` is `True` exactly when the architecture name
        contains `"resnetv2_"`.

    Raises:
        ValueError: If the architecture is not one of the supported names.
    """
    architecture = timm_config["architecture"]
    pre_activation = "resnetv2_" in architecture

    # Each entry: (accepted architecture names, blocks per stack, block type).
    layouts = (
        (("resnet18",), (2, 2, 2, 2), "basic_block"),
        (("resnet26",), (2, 2, 2, 2), "bottleneck_block"),
        (("resnet34",), (3, 4, 6, 3), "basic_block"),
        (("resnet50", "resnetv2_50"), (3, 4, 6, 3), "bottleneck_block"),
        (("resnet101", "resnetv2_101"), (3, 4, 23, 3), "bottleneck_block"),
        (("resnet152", "resnetv2_152"), (3, 8, 36, 3), "bottleneck_block"),
    )
    for names, num_blocks, kind in layouts:
        if architecture in names:
            return {
                "stackwise_num_filters": [64, 128, 256, 512],
                "stackwise_num_blocks": list(num_blocks),
                "stackwise_num_strides": [1, 2, 2, 2],
                "block_type": kind,
                "use_pre_activation": pre_activation,
            }

    raise ValueError(
        f"Currently, the architecture {architecture} is not supported."
    )
|
60
|
+
|
61
|
+
|
62
|
+
def convert_weights(backbone, loader, timm_config):
    """Port timm ResNet weights into a KerasHub ResNet backbone.

    Every weight is transferred through `loader.port_weight`, pairing a
    variable looked up with `backbone.get_layer(...)` with a timm weight key.
    Afterwards the `"normalization"` layer receives the mean and variance
    taken from `timm_config["pretrained_cfg"]` and is built again.

    Args:
        backbone: The backbone to fill. This function reads its
            `use_pre_activation`, `block_type`, `stackwise_num_filters` and
            `stackwise_num_blocks` attributes and calls `get_layer`.
        loader: Object exposing
            `port_weight(variable, hf_weight_key=..., hook_fn=...)`.
        timm_config: dict whose `"pretrained_cfg"` entry holds `"mean"` and
            `"std"` sequences.
    """
    # Keras batch-norm attribute -> timm key suffix, in porting order.
    bn_weight_names = (
        ("gamma", "weight"),
        ("beta", "bias"),
        ("moving_mean", "running_mean"),
        ("moving_variance", "running_var"),
    )

    def port_conv2d(keras_layer_name, hf_weight_prefix):
        # The kernel axes are permuted with (2, 3, 1, 0) while loading;
        # presumably (out, in, h, w) -> (h, w, in, out) -- TODO confirm.
        loader.port_weight(
            backbone.get_layer(keras_layer_name).kernel,
            hf_weight_key=f"{hf_weight_prefix}.weight",
            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
        )

    def port_batch_normalization(keras_layer_name, hf_weight_prefix):
        for keras_attr, hf_suffix in bn_weight_names:
            loader.port_weight(
                getattr(backbone.get_layer(keras_layer_name), keras_attr),
                hf_weight_key=f"{hf_weight_prefix}.{hf_suffix}",
            )

    is_v2 = bool(backbone.use_pre_activation)
    is_bottleneck = backbone.block_type == "bottleneck_block"

    def port_v1_block(keras_name, hf_name, has_shortcut_conv):
        # v1 blocks: optional shortcut conv + bn, then conv/bn pairs.
        if has_shortcut_conv:
            port_conv2d(f"{keras_name}_0_conv", f"{hf_name}.downsample.0")
            port_batch_normalization(
                f"{keras_name}_0_bn", f"{hf_name}.downsample.1"
            )
        # Bottleneck blocks carry a third conv/bn pair.
        for position in (1, 2, 3) if is_bottleneck else (1, 2):
            port_conv2d(
                f"{keras_name}_{position}_conv", f"{hf_name}.conv{position}"
            )
            port_batch_normalization(
                f"{keras_name}_{position}_bn", f"{hf_name}.bn{position}"
            )

    def port_v2_block(keras_name, hf_name, has_shortcut_conv):
        # v2 blocks: each norm is ported before the conv that follows it,
        # and the shortcut has a conv only.
        if has_shortcut_conv:
            port_conv2d(f"{keras_name}_0_conv", f"{hf_name}.downsample.conv")
        port_batch_normalization(
            f"{keras_name}_pre_activation_bn", f"{hf_name}.norm1"
        )
        port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
        port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.norm2")
        port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")
        if is_bottleneck:
            port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.norm3")
            port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3")

    # Stem
    if is_v2:
        port_conv2d("conv1_conv", "stem.conv")
    else:
        port_conv2d("conv1_conv", "conv1")
        port_batch_normalization("conv1_bn", "bn1")

    # Stages
    for stack_index, _ in enumerate(backbone.stackwise_num_filters):
        for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
            # A shortcut conv is ported only for the first block of a stack,
            # and for basic blocks not in the very first stack.
            has_shortcut_conv = block_idx == 0 and (
                is_bottleneck or stack_index > 0
            )
            if is_v2:
                port_v2_block(
                    f"v2_stack{stack_index}_block{block_idx}",
                    f"stages.{stack_index}.blocks.{block_idx}",
                    has_shortcut_conv,
                )
            else:
                port_v1_block(
                    f"v1_stack{stack_index}_block{block_idx}",
                    f"layer{stack_index+1}.{block_idx}",
                    has_shortcut_conv,
                )

    # Post
    if is_v2:
        port_batch_normalization("post_bn", "norm")

    # Rebuild normalization layer with pretrained mean & std
    pretrained_cfg = timm_config["pretrained_cfg"]
    mean = pretrained_cfg["mean"]
    std = pretrained_cfg["std"]
    normalization_layer = backbone.get_layer("normalization")
    normalization_layer.input_mean = mean
    # The layer stores variance, so each std entry is squared.
    normalization_layer.input_variance = [s**2 for s in std]
    normalization_layer.build(normalization_layer._build_input_shape)
|
160
|
+
|
161
|
+
|
162
|
+
def load_resnet_backbone(cls, preset, load_weights, **kwargs):
    """Instantiate a ResNet backbone from a timm preset.

    The preset's config is read with `load_config`, translated by
    `convert_backbone_config`, and used to construct `cls`. When
    `load_weights` is truthy the weights are then ported from the preset's
    safetensors through `convert_weights`.

    Args:
        cls: Backbone class to instantiate.
        preset: Preset handle passed to `load_config` and `SafetensorLoader`.
        load_weights: Whether to port the pretrained weights.
        **kwargs: Extra keyword arguments forwarded to `cls`.

    Returns:
        The constructed backbone.
    """
    timm_config = load_config(preset, HF_CONFIG_FILE)
    backbone = cls(**convert_backbone_config(timm_config), **kwargs)
    if not load_weights:
        return backbone

    jax_memory_cleanup(backbone)
    # Use prefix="" to avoid using `get_prefixed_key`.
    with SafetensorLoader(preset, prefix="") as loader:
        convert_weights(backbone, loader, timm_config)
    return backbone
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
@@ -0,0 +1,101 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Convert huggingface models to KerasHub."""
|
15
|
+
|
16
|
+
|
17
|
+
from keras_hub.src.utils.transformers.convert_bert import load_bert_backbone
|
18
|
+
from keras_hub.src.utils.transformers.convert_bert import load_bert_tokenizer
|
19
|
+
from keras_hub.src.utils.transformers.convert_distilbert import (
|
20
|
+
load_distilbert_backbone,
|
21
|
+
)
|
22
|
+
from keras_hub.src.utils.transformers.convert_distilbert import (
|
23
|
+
load_distilbert_tokenizer,
|
24
|
+
)
|
25
|
+
from keras_hub.src.utils.transformers.convert_gemma import load_gemma_backbone
|
26
|
+
from keras_hub.src.utils.transformers.convert_gemma import load_gemma_tokenizer
|
27
|
+
from keras_hub.src.utils.transformers.convert_gpt2 import load_gpt2_backbone
|
28
|
+
from keras_hub.src.utils.transformers.convert_gpt2 import load_gpt2_tokenizer
|
29
|
+
from keras_hub.src.utils.transformers.convert_llama3 import load_llama3_backbone
|
30
|
+
from keras_hub.src.utils.transformers.convert_llama3 import (
|
31
|
+
load_llama3_tokenizer,
|
32
|
+
)
|
33
|
+
from keras_hub.src.utils.transformers.convert_pali_gemma import (
|
34
|
+
load_pali_gemma_backbone,
|
35
|
+
)
|
36
|
+
from keras_hub.src.utils.transformers.convert_pali_gemma import (
|
37
|
+
load_pali_gemma_tokenizer,
|
38
|
+
)
|
39
|
+
|
40
|
+
|
41
|
+
def load_transformers_backbone(cls, preset, load_weights):
    """
    Load a Transformers model config and weights as a KerasHub backbone.

    The converter is chosen from the name of `cls`.

    Args:
        cls (class): Keras model class.
        preset (str): Preset configuration name.
        load_weights (bool): Whether to load the weights.

    Returns:
        backbone: Initialized Keras model backbone.

    Raises:
        ValueError: If `cls` is None or no converter handles its name.
    """
    if cls is None:
        raise ValueError("Backbone class is None")

    # Select the converter first; it is invoked once at the end.
    class_name = cls.__name__
    if class_name == "BertBackbone":
        load_backbone = load_bert_backbone
    elif class_name == "GemmaBackbone":
        load_backbone = load_gemma_backbone
    elif class_name == "Llama3Backbone":
        load_backbone = load_llama3_backbone
    elif class_name == "PaliGemmaBackbone":
        load_backbone = load_pali_gemma_backbone
    elif class_name == "GPT2Backbone":
        load_backbone = load_gpt2_backbone
    elif class_name == "DistilBertBackbone":
        load_backbone = load_distilbert_backbone
    else:
        raise ValueError(
            f"{cls} has not been ported from the Hugging Face format yet. "
            "Please check Hugging Face Hub for the Keras model. "
        )
    return load_backbone(cls, preset, load_weights)
|
71
|
+
|
72
|
+
|
73
|
+
def load_transformers_tokenizer(cls, preset):
    """
    Load Transformers tokenizer assets as a KerasHub tokenizer.

    The converter is chosen from the name of `cls`.

    Args:
        cls (class): Tokenizer class.
        preset (str): Preset configuration name.

    Returns:
        tokenizer: Initialized tokenizer.

    Raises:
        ValueError: If `cls` is None or no converter handles its name.
    """
    if cls is None:
        raise ValueError("Tokenizer class is None")

    # Select the converter first; it is invoked once at the end.
    class_name = cls.__name__
    if class_name == "BertTokenizer":
        load_tokenizer = load_bert_tokenizer
    elif class_name == "GemmaTokenizer":
        load_tokenizer = load_gemma_tokenizer
    elif class_name == "Llama3Tokenizer":
        load_tokenizer = load_llama3_tokenizer
    elif class_name == "PaliGemmaTokenizer":
        load_tokenizer = load_pali_gemma_tokenizer
    elif class_name == "GPT2Tokenizer":
        load_tokenizer = load_gpt2_tokenizer
    elif class_name == "DistilBertTokenizer":
        load_tokenizer = load_distilbert_tokenizer
    else:
        raise ValueError(
            f"{cls} has not been ported from the Hugging Face format yet. "
            "Please check Hugging Face Hub for the Keras model. "
        )
    return load_tokenizer(cls, preset)
|