keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev20240915160609__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/__init__.py +1 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +24 -3
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +38 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +29 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +19 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +33 -47
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +220 -67
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/METADATA +1 -2
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/RECORD +173 -143
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/top_level.txt +0 -0
keras_hub/src/utils/transformers/convert_mistral.py
ADDED
@@ -0,0 +1,129 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import numpy as np
|
15
|
+
|
16
|
+
from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone
|
17
|
+
from keras_hub.src.utils.preset_utils import get_file
|
18
|
+
|
19
|
+
# Backbone class targeted by this converter module; the transformers preset
# loader returns it from `check_backbone_class()`.
backbone_cls = MistralBackbone
|
20
|
+
|
21
|
+
|
22
|
+
def convert_backbone_config(transformers_config):
    """Map a transformers Mistral config dict to backbone constructor kwargs.

    Args:
        transformers_config: dict of transformers model config fields.

    Returns:
        A new dict keyed by KerasHub argument names, in a fixed order.
        Any field not listed below is ignored. A `KeyError` propagates
        if a required transformers field is missing.
    """
    # (KerasHub argument name, transformers config field), in output order.
    field_pairs = (
        ("vocabulary_size", "vocab_size"),
        ("num_layers", "num_hidden_layers"),
        ("num_query_heads", "num_attention_heads"),
        ("hidden_dim", "hidden_size"),
        ("intermediate_dim", "intermediate_size"),
        ("num_key_value_heads", "num_key_value_heads"),
        ("rope_max_wavelength", "rope_theta"),
        ("layer_norm_epsilon", "rms_norm_eps"),
        ("sliding_window", "sliding_window"),
    )
    return {
        keras_name: transformers_config[hf_name]
        for keras_name, hf_name in field_pairs
    }
|
34
|
+
|
35
|
+
|
36
|
+
def convert_weights(backbone, loader, transformers_config):
    """Port Mistral weights from a transformers checkpoint into `backbone`.

    Every tensor is cast to `np.float16` before assignment. The reverse
    embedding and MLP kernels are also transposed. The attention kernels
    are transposed and then reshaped to the Keras variable's shape.

    Args:
        backbone: the backbone whose variables receive the weights.
        loader: object exposing `port_weight(keras_variable=...,
            hf_weight_key=..., hook_fn=...)`. Judging by the hooks below,
            `hook_fn` is called as `hook_fn(hf_tensor, keras_shape)`.
        transformers_config: not read here; kept so every converter
            shares the same signature.
    """

    def as_half(hf_tensor, _):
        # Plain cast; used for embeddings and norm scales.
        return hf_tensor.astype(np.float16)

    def as_half_transposed(hf_tensor, _):
        # Cast, then swap the two axes.
        return np.transpose(hf_tensor.astype(np.float16), axes=(1, 0))

    def as_half_attention_kernel(hf_tensor, keras_shape):
        # Cast, transpose, then fold into the Keras kernel's shape.
        return np.reshape(
            np.transpose(hf_tensor.astype(np.float16)), keras_shape
        )

    # Token embeddings and the output projection.
    embedding = backbone.token_embedding
    loader.port_weight(
        keras_variable=embedding.embeddings,
        hf_weight_key="model.embed_tokens.weight",
        hook_fn=as_half,
    )
    loader.port_weight(
        keras_variable=embedding.reverse_embeddings,
        hf_weight_key="lm_head.weight",
        hook_fn=as_half_transposed,
    )

    for layer_index in range(backbone.num_layers):
        block = backbone.transformer_layers[layer_index]
        prefix = f"model.layers.{layer_index}"

        # Norm layers.
        loader.port_weight(
            keras_variable=block._self_attention_layernorm.scale,
            hf_weight_key=f"{prefix}.input_layernorm.weight",
            hook_fn=as_half,
        )
        loader.port_weight(
            keras_variable=block._feedforward_layernorm.scale,
            hf_weight_key=f"{prefix}.post_attention_layernorm.weight",
            hook_fn=as_half,
        )

        # Attention projections, ported in q, k, v, o order.
        attention = block._self_attention_layer
        for hf_name, dense in (
            ("q_proj", attention._query_dense),
            ("k_proj", attention._key_dense),
            ("v_proj", attention._value_dense),
            ("o_proj", attention._output_dense),
        ):
            loader.port_weight(
                keras_variable=dense.kernel,
                hf_weight_key=f"{prefix}.self_attn.{hf_name}.weight",
                hook_fn=as_half_attention_kernel,
            )

        # MLP projections, ported in gate, up, down order.
        for hf_name, dense in (
            ("gate_proj", block._feedforward_gate_dense),
            ("up_proj", block._feedforward_intermediate_dense),
            ("down_proj", block._feedforward_output_dense),
        ):
            loader.port_weight(
                keras_variable=dense.kernel,
                hf_weight_key=f"{prefix}.mlp.{hf_name}.weight",
                hook_fn=as_half_transposed,
            )

    # Final normalization.
    loader.port_weight(
        keras_variable=backbone.layer_norm.scale,
        hf_weight_key="model.norm.weight",
        hook_fn=as_half,
    )
|
126
|
+
|
127
|
+
|
128
|
+
def convert_tokenizer(cls, preset, **kwargs):
    """Build a `cls` tokenizer from the preset's `tokenizer.model` file.

    Extra keyword arguments are forwarded unchanged to `cls`.
    """
    model_file = get_file(preset, "tokenizer.model")
    return cls(model_file, **kwargs)
|
keras_hub/src/utils/transformers/convert_pali_gemma.py
CHANGED
@@ -13,11 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
import numpy as np
|
15
15
|
|
16
|
-
from keras_hub.src.
|
16
|
+
from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
|
17
|
+
PaliGemmaBackbone,
|
18
|
+
)
|
17
19
|
from keras_hub.src.utils.preset_utils import get_file
|
18
|
-
|
19
|
-
|
20
|
-
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
20
|
+
|
21
|
+
# Backbone class targeted by this converter module; the transformers preset
# loader returns it from `check_backbone_class()`.
backbone_cls = PaliGemmaBackbone
|
21
22
|
|
22
23
|
|
23
24
|
def convert_backbone_config(transformers_config):
|
@@ -275,29 +276,6 @@ def convert_weights(backbone, loader, transformers_config):
|
|
275
276
|
hook_fn=lambda hf_tensor, keras_shape: hf_tensor[: keras_shape[0]],
|
276
277
|
)
|
277
278
|
|
278
|
-
return backbone
|
279
|
-
|
280
|
-
|
281
|
-
def load_pali_gemma_backbone(cls, preset, load_weights):
|
282
|
-
transformers_config = load_config(preset, HF_CONFIG_FILE)
|
283
|
-
keras_config = convert_backbone_config(transformers_config)
|
284
|
-
backbone = cls(**keras_config)
|
285
|
-
if load_weights:
|
286
|
-
jax_memory_cleanup(backbone)
|
287
|
-
with SafetensorLoader(preset) as loader:
|
288
|
-
convert_weights(backbone, loader, transformers_config)
|
289
|
-
return backbone
|
290
|
-
|
291
|
-
|
292
|
-
def load_pali_gemma_tokenizer(cls, preset):
|
293
|
-
"""
|
294
|
-
Load the Gemma tokenizer.
|
295
|
-
|
296
|
-
Args:
|
297
|
-
cls (class): Tokenizer class.
|
298
|
-
preset (str): Preset configuration name.
|
299
279
|
|
300
|
-
|
301
|
-
|
302
|
-
"""
|
303
|
-
return cls(get_file(preset, "tokenizer.model"))
|
280
|
+
def convert_tokenizer(cls, preset, **kwargs):
    """Build a `cls` tokenizer from the preset's `tokenizer.model` file.

    Extra keyword arguments are forwarded unchanged to `cls`.
    """
    model_file = get_file(preset, "tokenizer.model")
    return cls(model_file, **kwargs)
|
keras_hub/src/utils/transformers/preset_loader.py
ADDED
@@ -0,0 +1,77 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Convert huggingface models to KerasHub."""
|
15
|
+
|
16
|
+
|
17
|
+
from keras_hub.src.utils.preset_utils import PresetLoader
|
18
|
+
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
|
19
|
+
from keras_hub.src.utils.transformers import convert_albert
|
20
|
+
from keras_hub.src.utils.transformers import convert_bart
|
21
|
+
from keras_hub.src.utils.transformers import convert_bert
|
22
|
+
from keras_hub.src.utils.transformers import convert_distilbert
|
23
|
+
from keras_hub.src.utils.transformers import convert_gemma
|
24
|
+
from keras_hub.src.utils.transformers import convert_gpt2
|
25
|
+
from keras_hub.src.utils.transformers import convert_llama3
|
26
|
+
from keras_hub.src.utils.transformers import convert_mistral
|
27
|
+
from keras_hub.src.utils.transformers import convert_pali_gemma
|
28
|
+
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
29
|
+
|
30
|
+
|
31
|
+
class TransformersPresetLoader(PresetLoader):
    """Preset loader for huggingface/transformers checkpoints.

    Selects a per-architecture converter module from the config's
    `model_type`. Backbone, weight and tokenizer construction are
    delegated to that module.
    """

    def __init__(self, preset, config):
        super().__init__(preset, config)
        model_type = self.config["model_type"]
        # (accepted `model_type` values, converter module), checked in order.
        candidates = (
            (("albert",), convert_albert),
            (("bart",), convert_bart),
            (("bert",), convert_bert),
            (("distilbert",), convert_distilbert),
            (("gemma", "gemma2"), convert_gemma),
            (("gpt2",), convert_gpt2),
            # TODO: handle other llama versions.
            (("llama",), convert_llama3),
            (("mistral",), convert_mistral),
            (("paligemma",), convert_pali_gemma),
        )
        for accepted_types, module in candidates:
            if model_type in accepted_types:
                self.converter = module
                break
        else:
            raise ValueError(
                "KerasHub has no converter for huggingface/transformers models "
                f"with model type `'{model_type}'`."
            )

    def check_backbone_class(self):
        """Return the backbone class declared by the selected converter."""
        return self.converter.backbone_cls

    def load_backbone(self, cls, load_weights, **kwargs):
        """Instantiate `cls` from the converted config.

        `kwargs` take precedence over values from the converted config.
        When `load_weights` is true, checkpoint weights are ported into the
        new backbone through a `SafetensorLoader`.
        """
        init_kwargs = dict(self.converter.convert_backbone_config(self.config))
        init_kwargs.update(kwargs)
        backbone = cls(**init_kwargs)
        if not load_weights:
            return backbone
        jax_memory_cleanup(backbone)
        with SafetensorLoader(self.preset) as loader:
            self.converter.convert_weights(backbone, loader, self.config)
        return backbone

    def load_tokenizer(self, cls, **kwargs):
        """Delegate tokenizer construction to the selected converter."""
        return self.converter.convert_tokenizer(cls, self.preset, **kwargs)

    def load_image_converter(self, cls, **kwargs):
        """Return `None`; no image converter is built for these presets yet."""
        # TODO: set image size for pali gemma checkpoints.
        return None
|
keras_hub/src/utils/transformers/safetensor_utils.py
CHANGED
@@ -17,7 +17,7 @@ from keras_hub.src.utils.preset_utils import SAFETENSOR_CONFIG_FILE
|
|
17
17
|
from keras_hub.src.utils.preset_utils import SAFETENSOR_FILE
|
18
18
|
from keras_hub.src.utils.preset_utils import check_file_exists
|
19
19
|
from keras_hub.src.utils.preset_utils import get_file
|
20
|
-
from keras_hub.src.utils.preset_utils import
|
20
|
+
from keras_hub.src.utils.preset_utils import load_json
|
21
21
|
|
22
22
|
try:
|
23
23
|
import safetensors
|
@@ -38,7 +38,7 @@ class SafetensorLoader(contextlib.ExitStack):
|
|
38
38
|
|
39
39
|
self.preset = preset
|
40
40
|
if check_file_exists(preset, SAFETENSOR_CONFIG_FILE):
|
41
|
-
self.safetensor_config =
|
41
|
+
self.safetensor_config = load_json(preset, SAFETENSOR_CONFIG_FILE)
|
42
42
|
else:
|
43
43
|
self.safetensor_config = None
|
44
44
|
self.safetensor_files = {}
|
{keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: keras-hub-nightly
|
3
|
-
Version: 0.15.0.dev20240823171555
|
3
|
+
Version: 0.16.0.dev20240915160609
|
4
4
|
Summary: 🚧🚧🚧 Work in progress. 🚧🚧🚧 More details soon!
|
5
5
|
Home-page: https://github.com/keras-team/keras-hub
|
6
6
|
Author: Keras team
|
@@ -8,7 +8,6 @@ Author-email: keras-hub@google.com
|
|
8
8
|
License: Apache License 2.0
|
9
9
|
Classifier: Development Status :: 3 - Alpha
|
10
10
|
Classifier: Programming Language :: Python :: 3
|
11
|
-
Classifier: Programming Language :: Python :: 3.8
|
12
11
|
Classifier: Programming Language :: Python :: 3.9
|
13
12
|
Classifier: Programming Language :: Python :: 3.10
|
14
13
|
Classifier: Programming Language :: Python :: 3.11
|