keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +2 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/api/utils/__init__.py +22 -0
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +46 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +230 -68
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
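The headline change in this release is an API-wide rename: the `*Classifier` task and preprocessor classes become `*TextClassifier` and `*TextClassifierPreprocessor` (see the file moves above), the per-model `*_preprocessor.py` modules are folded into shared base classes (`causal_lm_preprocessor.py`, `masked_lm_preprocessor.py`, `text_classifier_preprocessor.py`, `seq_2_seq_lm_preprocessor.py`), and new audio/image converter layers are introduced. A minimal before/after sketch of the rename, assuming the new classes are exported under `keras_hub.models` as the file moves suggest (the preset name is illustrative):

import keras_hub

# 0.15 nightlies (old name):
# classifier = keras_hub.models.BertClassifier.from_preset("bert_base_en", num_classes=2)

# 0.16 nightlies (renamed):
classifier = keras_hub.models.BertTextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
classifier.predict(["What an amazing movie!"])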
keras_hub/src/utils/keras_utils.py

@@ -18,8 +18,6 @@ import keras
 from absl import logging
 from packaging.version import parse

-from keras_hub.src.utils.tensor_utils import is_tensor_type
-
 try:
     import tensorflow as tf
 except ImportError:
@@ -39,54 +37,6 @@ def clone_initializer(initializer):
     return initializer.__class__.from_config(config)


-def convert_inputs_to_list_of_tensor_segments(x):
-    """Converts user inputs to a list of a tensor segments.
-
-    For models and layers which accept lists of string tensors to pack together,
-    this method converts user inputs to a uniform format in a way that can be
-    considered canonical for the library.
-
-    We handle the following:
-
-    - A single string will be converted to a tensor and wrapped in a list.
-    - A list of strings will be converted to a tensor and wrapped in a list.
-    - A single tensor will be wrapped in a list.
-    - A list of tensors will be passed through unaltered.
-
-    All other inputs will result in an error. This effectively means that users
-    who would like to pack multiple segments together should convert those
-    segments to tensors before calling the layer. This removes any ambiguity
-    in the input for those cases.
-    """
-    # Check the input type.
-    is_string = isinstance(x, (str, bytes))
-    is_tensor = is_tensor_type(x)
-    is_string_list = (
-        isinstance(x, (list, tuple)) and x and isinstance(x[0], (str, bytes))
-    )
-    is_tensor_list = isinstance(x, (list, tuple)) and x and is_tensor_type(x[0])
-
-    if is_string or is_string_list:
-        # Automatically convert raw strings or string lists to tensors.
-        # Wrap this input as a single (possibly batched) segment.
-        x = [tf.convert_to_tensor(x)]
-    elif is_tensor:
-        # Automatically wrap a single tensor as a single segment.
-        x = [x]
-    elif is_tensor_list:
-        # Pass lists of tensors though unaltered.
-        x = x
-    else:
-        # Error for all other input.
-        raise ValueError(
-            f"Unsupported input for `x`. `x` should be a string, a list of "
-            "strings, or a list of tensors. If passing multiple segments "
-            "which should packed together, please convert your inputs to a "
-            f"list of tensors. Received `x={x}`"
-        )
-    return x
-
-
 def print_msg(message, line_break=True):
     """Print the message to absl logging or stdout."""
     # Copied from core Keras.
keras_hub/src/utils/preset_utils.py

@@ -60,6 +60,8 @@ TOKENIZER_ASSET_DIR = "assets/tokenizer"
 # Config file names.
 CONFIG_FILE = "config.json"
 TOKENIZER_CONFIG_FILE = "tokenizer.json"
+AUDIO_CONVERTER_CONFIG_FILE = "audio_converter.json"
+IMAGE_CONVERTER_CONFIG_FILE = "image_converter.json"
 TASK_CONFIG_FILE = "task.json"
 PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
 METADATA_FILE = "metadata.json"
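The two new constants add audio and image converter configs to the preset format. A hypothetical preset directory exercising every file name defined in this block (layout inferred from the constants, not normative):

my_preset/
├── config.json            # CONFIG_FILE (backbone)
├── tokenizer.json         # TOKENIZER_CONFIG_FILE
├── audio_converter.json   # new: AUDIO_CONVERTER_CONFIG_FILE
├── image_converter.json   # new: IMAGE_CONVERTER_CONFIG_FILE
├── task.json              # TASK_CONFIG_FILE
├── preprocessor.json      # PREPROCESSOR_CONFIG_FILE
├── metadata.json          # METADATA_FILE
└── assets/tokenizer/      # TOKENIZER_ASSET_DIR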
@@ -77,10 +79,10 @@ SAFETENSOR_FILE = "model.safetensors"

 # Global state for preset registry.
 BUILTIN_PRESETS = {}
-BUILTIN_PRESETS_FOR_CLASS = collections.defaultdict(dict)
+BUILTIN_PRESETS_FOR_BACKBONE = collections.defaultdict(dict)


-def register_presets(presets, classes):
+def register_presets(presets, backbone_cls):
     """Register built-in presets for a set of classes.

     Note that this is intended only for models and presets shipped in the
@@ -88,18 +90,26 @@ def register_presets(presets, classes):
     """
     for preset in presets:
         BUILTIN_PRESETS[preset] = presets[preset]
-        for cls in classes:
-            BUILTIN_PRESETS_FOR_CLASS[cls][preset] = presets[preset]
+        BUILTIN_PRESETS_FOR_BACKBONE[backbone_cls][preset] = presets[preset]


-def list_presets(cls):
+def builtin_presets(cls):
     """Find all registered built-in presets for a class."""
-    return dict(BUILTIN_PRESETS_FOR_CLASS[cls])
+    presets = {}
+    if cls in BUILTIN_PRESETS_FOR_BACKBONE:
+        presets.update(BUILTIN_PRESETS_FOR_BACKBONE[cls])
+    backbone_cls = getattr(cls, "backbone_cls", None)
+    if backbone_cls:
+        presets.update(builtin_presets(backbone_cls))
+    for subclass in list_subclasses(cls):
+        presets.update(builtin_presets(subclass))
+    return presets


 def list_subclasses(cls):
     """Find all registered subclasses of a class."""
-    custom_objects = keras.saving.get_custom_objects().values()
+    # Deduplicate the lists, since we have to register object twice for compat.
+    custom_objects = set(keras.saving.get_custom_objects().values())
     subclasses = []
     for x in custom_objects:
         if inspect.isclass(x) and x != cls and issubclass(x, cls):
@@ -107,6 +117,26 @@ def list_subclasses(cls):
     return subclasses


+def find_subclass(preset, cls, backbone_cls):
+    """Find a subclass that is compatible with backbone_cls."""
+    subclasses = list_subclasses(cls)
+    subclasses = filter(lambda x: x.backbone_cls == backbone_cls, subclasses)
+    subclasses = list(subclasses)
+    if not subclasses:
+        raise ValueError(
+            f"Unable to find a subclass of {cls.__name__} that is compatible "
+            f"with {backbone_cls.__name__} found in preset '{preset}'."
+        )
+    # If we find multiple subclasses, try to filter to direct subclasses of
+    # the class we are trying to instantiate.
+    if len(subclasses) > 1:
+        directs = list(filter(lambda x: cls in x.__bases__, subclasses))
+        if len(directs) > 1:
+            subclasses = directs
+    # Return the subclass that was registered first (prefer built-in classes).
+    return subclasses[0]
+
+
 def get_file(preset, path):
     """Download a preset file in necessary and return the local path."""
     # TODO: Add tests for FileNotFound exceptions.
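Taken together, the two hunks above rework preset resolution: built-in presets are now keyed by backbone class, and concrete task classes are resolved from a backbone. A runnable toy of the two lookups; all class names are stand-ins, and the real implementation walks Keras's custom-object registry via `list_subclasses` rather than `__subclasses__`:

import collections

BUILTIN_PRESETS_FOR_BACKBONE = collections.defaultdict(dict)


class BertBackbone:
    pass


class TextClassifier:
    backbone_cls = None


class BertTextClassifier(TextClassifier):
    backbone_cls = BertBackbone


# register_presets(...): presets are stored once, keyed by the backbone class.
BUILTIN_PRESETS_FOR_BACKBONE[BertBackbone]["bert_base_en"] = {"kaggle_handle": "..."}


def builtin_presets(cls):
    # Direct hits, plus anything registered for the task's backbone.
    presets = dict(BUILTIN_PRESETS_FOR_BACKBONE.get(cls, {}))
    backbone_cls = getattr(cls, "backbone_cls", None)
    if backbone_cls:
        presets.update(builtin_presets(backbone_cls))
    return presets


def find_subclass(cls, backbone_cls):
    # Pick the registered task subclass whose backbone_cls matches.
    matches = [s for s in cls.__subclasses__() if s.backbone_cls is backbone_cls]
    if not matches:
        raise ValueError(f"No {cls.__name__} is compatible with {backbone_cls.__name__}.")
    return matches[0]


assert "bert_base_en" in builtin_presets(BertTextClassifier)
assert find_subclass(TextClassifier, BertBackbone) is BertTextClassifier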
@@ -197,7 +227,7 @@ def get_file(preset, path):
         else:
             raise ValueError(message)
     elif os.path.exists(preset):
-        # Assume a local filepath.
+        # Assume a local filepath.pyth
         local_path = os.path.join(preset, path)
         if not os.path.exists(local_path):
             raise FileNotFoundError(
@@ -272,6 +302,7 @@ def recursive_pop(config, key):
         recursive_pop(value, key)


+# TODO: refactor saving routines into a PresetSaver class?
 def make_preset_dir(preset):
     os.makedirs(preset, exist_ok=True)

@@ -314,19 +345,9 @@ def save_metadata(layer, preset):
         metadata_file.write(json.dumps(metadata, indent=4))


-def _validate_tokenizer(preset, allow_incomplete=False):
+def _validate_tokenizer(preset):
     if not check_file_exists(preset, TOKENIZER_CONFIG_FILE):
-        if allow_incomplete:
-            logging.warning(
-                f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`."
-            )
-            return
-        else:
-            raise FileNotFoundError(
-                f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`. "
-                "To upload the model without a tokenizer, "
-                "set `allow_incomplete=True`."
-            )
+        return
     config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
     try:
         with open(config_path, encoding="utf-8") as config_file:
@@ -377,7 +398,7 @@ def _validate_backbone(preset):
         )


-def
+def to_snake_case(name):
     name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
     return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()

@@ -386,7 +407,7 @@ def create_model_card(preset):
     model_card_path = os.path.join(preset, README_FILE)
     markdown_content = ""

-    config = load_config(preset, CONFIG_FILE)
+    config = load_json(preset, CONFIG_FILE)
     model_name = (
         config["class_name"].replace("Backbone", "")
         if config["class_name"].endswith("Backbone")
@@ -395,7 +416,7 @@ def create_model_card(preset):

     task_type = None
     if check_file_exists(preset, TASK_CONFIG_FILE):
-        task_config = load_config(preset, TASK_CONFIG_FILE)
+        task_config = load_json(preset, TASK_CONFIG_FILE)
         task_type = (
             task_config["class_name"].replace(model_name, "")
             if task_config["class_name"].startswith(model_name)
@@ -407,12 +428,12 @@ def create_model_card(preset):
     markdown_content += "library_name: keras-hub\n"
     if task_type == "CausalLM":
         markdown_content += "pipeline_tag: text-generation\n"
-    elif task_type == "Classifier":
+    elif task_type == "TextClassifier":
         markdown_content += "pipeline_tag: text-classification\n"
     markdown_content += "---\n"

     model_link = (
-        f"https://keras.io/api/keras_hub/models/{
+        f"https://keras.io/api/keras_hub/models/{to_snake_case(model_name)}"
     )
     markdown_content += (
         f"This is a [`{model_name}` model]({model_link}) "
@@ -454,7 +475,6 @@ def delete_model_card(preset):
 def upload_preset(
     uri,
     preset,
-    allow_incomplete=False,
 ):
     """Upload a preset directory to a model hub.

@@ -466,9 +486,6 @@ def upload_preset(
             `hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging
             Face Hub.
         preset: The path to the local model preset directory.
-        allow_incomplete: If True, allows the upload of presets without
-            a tokenizer configuration. Otherwise, a tokenizer
-            is required.
     """

     # Check if preset directory exists.
@@ -476,7 +493,7 @@ def upload_preset(
         raise FileNotFoundError(f"The preset directory {preset} doesn't exist.")

     _validate_backbone(preset)
-    _validate_tokenizer(preset, allow_incomplete)
+    _validate_tokenizer(preset)

     if uri.startswith(KAGGLE_PREFIX):
         if kagglehub is None:
@@ -533,42 +550,14 @@ def upload_preset(
         )


-def load_config(preset, config_file=CONFIG_FILE):
+def load_json(preset, config_file=CONFIG_FILE):
     config_path = get_file(preset, config_file)
     with open(config_path, encoding="utf-8") as config_file:
         config = json.load(config_file)
     return config


-def check_format(preset):
-    if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists(
-        preset, SAFETENSOR_CONFIG_FILE
-    ):
-        # Determine the format by parsing the config file.
-        config = load_config(preset, HF_CONFIG_FILE)
-        if "hf://timm" in preset or "architecture" in config:
-            return "timm"
-        return "transformers"
-
-    if not check_file_exists(preset, METADATA_FILE):
-        raise FileNotFoundError(
-            f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`, "
-            "or you do not have access to it. This file is required to load a Keras model "
-            "preset. Please verify that the model you are trying to load is a Keras model."
-        )
-    metadata = load_config(preset, METADATA_FILE)
-    if "keras_version" not in metadata:
-        raise ValueError(
-            f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. "
-            "Please verify that the model you are trying to load is a Keras model."
-        )
-    return "keras"
-
-
-def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
-    kwargs = kwargs or {}
-    config = load_config(preset, config_file)
-
+def load_serialized_object(config, **kwargs):
     # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
     # Ensure that `dtype` is properly configured.
     dtype = kwargs.pop("dtype", None)
@@ -578,15 +567,18 @@ def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
     return keras.saving.deserialize_keras_object(config)


-def check_config_class(
-    preset,
-    config_file=CONFIG_FILE,
-):
+def check_config_class(config):
     """Validate a preset is being loaded on the correct class."""
-    config_path = get_file(preset, config_file)
-    with open(config_path, encoding="utf-8") as config_file:
-        config = json.load(config_file)
-    return keras.saving.get_registered_object(config["registered_name"])
+    registered_name = config["registered_name"]
+    cls = keras.saving.get_registered_object(registered_name)
+    if cls is None:
+        raise ValueError(
+            f"Attempting to load class {registered_name} with "
+            "`from_preset()`, but there is no class registered with Keras "
+            f"for {registered_name}. Make sure to register any custom "
+            "classes with `register_keras_serializable()`."
+        )
+    return cls


 def jax_memory_cleanup(layer):
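The new `check_config_class` resolves a class purely from the serialized config, without touching the filesystem. A small demonstration using Keras's standard registration machinery (the `demo` package name and `MyBackbone` class are arbitrary):

import keras


@keras.saving.register_keras_serializable(package="demo")
class MyBackbone(keras.layers.Layer):
    pass


config = {"registered_name": "demo>MyBackbone"}
cls = keras.saving.get_registered_object(config["registered_name"])
assert cls is MyBackbone  # an unregistered name would return None, triggering the error above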
@@ -619,3 +611,173 @@ def set_dtype_in_config(config, dtype=None):
     for k in policy_map_config["policy_map"].keys():
         policy_map_config["policy_map"][k]["config"]["source_name"] = dtype
     return config
+
+
+def get_preset_loader(preset):
+    if not check_file_exists(preset, CONFIG_FILE):
+        raise ValueError(
+            f"Preset {preset} has no {CONFIG_FILE}. Make sure the URI or "
+            "directory you are trying to load is a valid KerasHub preset and "
+            "and that you have permissions to read/download from this location."
+        )
+    # We currently assume all formats we support have a `config.json`, this is
+    # true, for Keras, Transformers, and timm. We infer the on disk format by
+    # inspecting the `config.json` file.
+    config = load_json(preset, CONFIG_FILE)
+    if "registered_name" in config:
+        # If we see registered_name, we assume a serialized Keras object.
+        return KerasPresetLoader(preset, config)
+    elif "model_type" in config:
+        # Avoid circular import.
+        from keras_hub.src.utils.transformers.preset_loader import (
+            TransformersPresetLoader,
+        )
+
+        # If we see model_type, we assume a Transformers style config.
+        return TransformersPresetLoader(preset, config)
+    elif "architecture" in config:
+        # Avoid circular import.
+        from keras_hub.src.utils.timm.preset_loader import TimmPresetLoader
+
+        # If we see "architecture", we assume a timm config. We could make this
+        # more robust later on if we need to.
+        return TimmPresetLoader(preset, config)
+    else:
+        contents = json.dumps(config, indent=4)
+        raise ValueError(
+            f"Unrecognized format for {CONFIG_FILE} in {preset}. "
+            "Create a preset with the `save_to_preset` utility on KerasHub "
+            f"models. Contents of {CONFIG_FILE}:\n{contents}"
+        )
+
+
+class PresetLoader:
+    def __init__(self, preset, config):
+        self.config = config
+        self.preset = preset
+
+    def check_backbone_class(self):
+        """Infer the backbone architecture."""
+        raise NotImplementedError
+
+    def load_backbone(self, cls, load_weights, **kwargs):
+        """Load the backbone model from the preset."""
+        raise NotImplementedError
+
+    def load_tokenizer(self, cls, **kwargs):
+        """Load a tokenizer layer from the preset."""
+        raise NotImplementedError
+
+    def load_audio_converter(self, cls, **kwargs):
+        """Load an audio converter layer from the preset."""
+        raise NotImplementedError
+
+    def load_image_converter(self, cls, **kwargs):
+        """Load an image converter layer from the preset."""
+        raise NotImplementedError
+
+    def load_task(self, cls, load_weights, load_task_weights, **kwargs):
+        """Load a task model from the preset.
+
+        By default, we create a task from a backbone and preprocessor with
+        default arguments. This means
+        """
+        if "backbone" not in kwargs:
+            backbone_class = cls.backbone_cls
+            # Forward dtype to backbone.
+            backbone_kwargs = {"dtype": kwargs.pop("dtype", None)}
+            kwargs["backbone"] = self.load_backbone(
+                backbone_class, load_weights, **backbone_kwargs
+            )
+        if "preprocessor" not in kwargs and cls.preprocessor_cls:
+            kwargs["preprocessor"] = self.load_preprocessor(
+                cls.preprocessor_cls,
+            )
+        return cls(**kwargs)
+
+    def load_preprocessor(self, cls, **kwargs):
+        """Load a prepocessor layer from the preset.
+
+        By default, we create a preprocessor from a tokenizer with default
+        arguments. This allow us to support transformers checkpoints by
+        only converting the backbone and tokenizer.
+        """
+        if "tokenizer" not in kwargs and cls.tokenizer_cls:
+            kwargs["tokenizer"] = self.load_tokenizer(cls.tokenizer_cls)
+        if "audio_converter" not in kwargs and cls.audio_converter_cls:
+            kwargs["audio_converter"] = self.load_audio_converter(
+                cls.audio_converter_cls
+            )
+        if "image_converter" not in kwargs and cls.image_converter_cls:
+            kwargs["image_converter"] = self.load_image_converter(
+                cls.image_converter_cls
+            )
+        return cls(**kwargs)
+
+
+class KerasPresetLoader(PresetLoader):
+    def check_backbone_class(self):
+        return check_config_class(self.config)
+
+    def load_backbone(self, cls, load_weights, **kwargs):
+        backbone = load_serialized_object(self.config, **kwargs)
+        if load_weights:
+            jax_memory_cleanup(backbone)
+            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
+        return backbone
+
+    def load_tokenizer(self, cls, **kwargs):
+        tokenizer_config = load_json(self.preset, TOKENIZER_CONFIG_FILE)
+        tokenizer = load_serialized_object(tokenizer_config, **kwargs)
+        tokenizer.load_preset_assets(self.preset)
+        return tokenizer
+
+    def load_audio_converter(self, cls, **kwargs):
+        converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE)
+        return load_serialized_object(converter_config, **kwargs)
+
+    def load_image_converter(self, cls, **kwargs):
+        converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
+        return load_serialized_object(converter_config, **kwargs)
+
+    def load_task(self, cls, load_weights, load_task_weights, **kwargs):
+        # If there is no `task.json` or it's for the wrong class delegate to the
+        # super class loader.
+        if not check_file_exists(self.preset, TASK_CONFIG_FILE):
+            return super().load_task(
+                cls, load_weights, load_task_weights, **kwargs
+            )
+        task_config = load_json(self.preset, TASK_CONFIG_FILE)
+        if not issubclass(check_config_class(task_config), cls):
+            return super().load_task(
+                cls, load_weights, load_task_weights, **kwargs
+            )
+        # We found a `task.json` with a complete config for our class.
+        task = load_serialized_object(task_config, **kwargs)
+        if task.preprocessor and task.preprocessor.tokenizer:
+            task.preprocessor.tokenizer.load_preset_assets(self.preset)
+        if load_weights:
+            has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE)
+            if has_task_weights and load_task_weights:
+                jax_memory_cleanup(task)
+                task_weights = get_file(self.preset, TASK_WEIGHTS_FILE)
+                task.load_task_weights(task_weights)
+            else:
+                jax_memory_cleanup(task.backbone)
+                backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
+                task.backbone.load_weights(backbone_weights)
+        return task
+
+    def load_preprocessor(self, cls, **kwargs):
+        # If there is no `preprocessing.json` or it's for the wrong class,
+        # delegate to the super class loader.
+        if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
+            return super().load_preprocessor(cls, **kwargs)
+        preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
+        if not issubclass(check_config_class(preprocessor_json), cls):
+            return super().load_preprocessor(cls, **kwargs)
+        # We found a `preprocessing.json` with a complete config for our class.
+        preprocessor = load_serialized_object(preprocessor_json, **kwargs)
+        preprocessor.tokenizer.load_preset_assets(self.preset)
+        return preprocessor
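The practical upshot of the loader classes above: a single `from_preset` call can dispatch to Keras, Transformers, or timm checkpoints based on which key `config.json` contains. A usage sketch; the preset handles below are examples of each format, not a list of what this release guarantees:

import keras_hub

# "registered_name" in config.json -> KerasPresetLoader.
causal_lm = keras_hub.models.CausalLM.from_preset("gemma_2b_en")

# "model_type" in config.json -> TransformersPresetLoader (converted on load).
mistral_lm = keras_hub.models.CausalLM.from_preset("hf://mistralai/Mistral-7B-v0.1")

# "architecture" in config.json -> TimmPresetLoader.
image_classifier = keras_hub.models.ImageClassifier.from_preset("hf://timm/resnet50.a1_in1k")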