keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +2 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/api/utils/__init__.py +22 -0
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +46 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +230 -68
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,13 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
import contextlib
|
16
|
+
import functools
|
17
|
+
import inspect
|
18
|
+
import threading
|
19
|
+
|
15
20
|
import keras
|
21
|
+
import numpy as np
|
16
22
|
from keras import ops
|
17
23
|
|
18
24
|
try:
|
@@ -23,6 +29,181 @@ except ImportError:
|
|
23
29
|
tf_text = None
|
24
30
|
|
25
31
|
|
32
|
+
# Per-thread nesting depth of `no_convert_scope`; each thread sees its own
# `count` attribute because this is a `threading.local`.
NO_CONVERT_COUNTER = threading.local()


def _no_convert_depth():
    """Return the calling thread's current `no_convert_scope` depth (0 if unset)."""
    return getattr(NO_CONVERT_COUNTER, "count", 0)


@contextlib.contextmanager
def no_convert_scope():
    """Suppress preprocessing input/output conversion inside this context.

    Scopes may be nested; the depth counter is decremented in `finally`, so
    it is restored even if the body raises.
    """
    try:
        NO_CONVERT_COUNTER.count = _no_convert_depth() + 1
        yield
    finally:
        NO_CONVERT_COUNTER.count = _no_convert_depth() - 1


def in_no_convert_scope():
    """Return True if the calling thread is inside any `no_convert_scope`."""
    return _no_convert_depth() > 0
|
46
|
+
|
47
|
+
|
48
|
+
def preprocessing_function(fn):
    """Wraps a preprocessing function to handle tf tensor conversion.

    The returned wrapper runs `fn` under `tf.device("cpu")`. It converts the
    inputs with `convert_preprocessing_inputs`, calls `fn` inside a
    `no_convert_scope` so nested preprocessing does not convert again, and
    converts the result with `convert_preprocessing_outputs`.

    If `fn` declares all of `x`, `y` and `sample_weight`, the wrapper accepts
    and converts all three. Otherwise only `x` is converted. When TensorFlow
    is unavailable (`tf is None`), `fn` is returned unchanged.
    """
    if tf is None:
        return fn

    signature_params = inspect.signature(fn).parameters
    takes_labels = all(
        name in signature_params for name in ("x", "y", "sample_weight")
    )

    if takes_labels:

        @functools.wraps(fn)
        def wrapper(self, x, y=None, sample_weight=None, **kwargs):
            with tf.device("cpu"):
                converted = convert_preprocessing_inputs((x, y, sample_weight))
                x, y, sample_weight = converted
                with no_convert_scope():
                    outputs = fn(
                        self, x, y=y, sample_weight=sample_weight, **kwargs
                    )
                return convert_preprocessing_outputs(outputs)

        return wrapper

    @functools.wraps(fn)
    def wrapper(self, x, **kwargs):
        with tf.device("cpu"):
            converted = convert_preprocessing_inputs(x)
            with no_convert_scope():
                outputs = fn(self, converted, **kwargs)
            return convert_preprocessing_outputs(outputs)

    return wrapper
|
78
|
+
|
79
|
+
|
80
|
+
def _to_cpu_backend_tensor(x):
    """Convert `x` to a backend tensor, moving it to CPU on the torch backend."""
    x = ops.convert_to_tensor(x)
    # Torch will complain about device placement for GPU tensors.
    if keras.config.backend() == "torch":
        x = x.cpu()
    return x


def convert_preprocessing_inputs(x):
    """Convert raw inputs for preprocessing.

    This function is used to convert raw inputs (strings, lists, `np.ndarray`s,
    `jax.Array`s, `torch.Tensor`s, etc) to a canonical format for
    preprocessing layers. All inputs will be converted to backend tensors if
    possible, except ragged inputs and string inputs, which will be converted
    to tf tensors regardless of backend.

    `tuple` and `list` elements are handled differently by this function. A
    `tuple` is assumed to enumerate separate inputs, and a `list` is assumed to
    enumerate elements in a single array-like input. This makes it possible to
    represent ragged and string inputs in a multi-backend format, as shown in
    the examples below.

    Examples:
    ```python
    # Two ragged arrays of token ids.
    x = ([[1, 2, 3], [4, 5]], [[1, 2], [3, 4, 5]])
    keras_hub.utils.convert_preprocessing_inputs(x)

    # A batch of three samples each with two string segments.
    x = (["hi", "hello", "hey"], ["bye", "later", "so long"])
    keras_hub.utils.convert_preprocessing_inputs(x)

    # A batch of features in a dictionary.
    x = {
        "text": ["hi", "hello", "hey"],
        "images": np.ones((3, 64, 64, 3)),
        "labels": [1, 0, 1],
    }
    keras_hub.utils.convert_preprocessing_inputs(x)
    ```

    Args:
        x: A (possibly nested, via `dict`/`tuple`) structure of raw inputs.

    Returns:
        The same structure, with string and ragged data as tf tensors and
        other array-like data as backend tensors. Inputs are returned
        unchanged when not executing eagerly or inside `no_convert_scope`.

    Raises:
        ValueError: If a `list` can be converted neither to a dense numpy
            array nor to a `tf.RaggedTensor`.
    """
    if not tf.executing_eagerly() or in_no_convert_scope():
        return x

    if isinstance(x, dict):
        return {k: convert_preprocessing_inputs(v) for k, v in x.items()}
    if isinstance(x, tuple):
        return tuple(convert_preprocessing_inputs(v) for v in x)
    if isinstance(x, (str, bytes)):
        return tf.constant(x)
    if isinstance(x, list):
        try:
            numpy_x = np.array(x)
        except ValueError as e:
            # If numpy conversion failed, try converting to a ragged array.
            try:
                return tf.ragged.constant(x)
            except ValueError:
                # If ragged conversion failed return to the numpy error.
                raise e
        # If we have a string input, use tf.tensor.
        if numpy_x.dtype.type in (np.str_, np.bytes_):
            return tf.convert_to_tensor(x)
        # Numpy will default to int64, int32 works with more ops.
        if numpy_x.dtype == np.int64:
            numpy_x = numpy_x.astype(np.int32)
        # We have non-ragged, non-string input. Use backend type.
        return _to_cpu_backend_tensor(numpy_x)
    if is_tensor_type(x):
        # String or ragged types we keep as tf.
        if isinstance(x, tf.RaggedTensor) or x.dtype == tf.string:
            return x
        # String ndarrays (unicode or bytes) go to tf, matching the list
        # branch above; other backends have no string tensor type.
        if isinstance(x, np.ndarray) and x.dtype.type in (np.str_, np.bytes_):
            return tf.convert_to_tensor(x)
        return _to_cpu_backend_tensor(x)
    return x
|
158
|
+
|
159
|
+
|
160
|
+
def convert_preprocessing_outputs(x):
    """Convert outputs after preprocessing to a backend agnostic format.

    This function is used to convert `tf.Tensor` and `tf.RaggedTensor` output
    from preprocessing layers to either:

    - The correct tensor type for the Keras backend framework.
    - Python lists, in the case of ragged and string data.

    This will automatically be called on the output of preprocessing
    layers or `keras_hub.models.Task`s with preprocessing included. It could be
    used directly to convert a `tf.data.Dataset` output to a backend agnostic
    type.

    Examples:
    ```python
    # Two ragged arrays of token ids.
    x = tf.ragged.constant([[1, 2, 3], [4, 5]])
    keras_hub.utils.convert_preprocessing_outputs(x)

    # A batch of three samples each with two string segments.
    x = (tf.constant(["hi", "yo", "hey"]), tf.constant(["bye", "ciao", ""]))
    keras_hub.utils.convert_preprocessing_outputs(x)

    # A batch of features in a dictionary.
    x = {
        "text": tf.constant(["hi", "hello", "hey"]),
        "images": tf.ones((3, 64, 64, 3)),
        "labels": tf.constant([1, 0, 1]),
    }
    keras_hub.utils.convert_preprocessing_outputs(x)
    ```

    Args:
        x: A nested structure of preprocessing outputs.

    Returns:
        The same structure with ragged/string leaves as Python lists, `None`
        leaves unchanged, and all other leaves as backend tensors. `x` is
        returned as-is when not executing eagerly or inside
        `no_convert_scope`.
    """
    if not tf.executing_eagerly() or in_no_convert_scope():
        return x

    def convert(value):
        if value is None:
            return value
        # Leaves without a `dtype` (e.g. Python scalars) are allowed; let
        # `ops.convert_to_tensor` infer their dtype instead of crashing.
        dtype = getattr(value, "dtype", None)
        if isinstance(value, tf.RaggedTensor) or (
            dtype is not None and dtype == tf.string
        ):
            return tensor_to_list(value)
        if dtype is not None:
            dtype = keras.backend.standardize_dtype(dtype)
        return ops.convert_to_tensor(value, dtype=dtype)

    return keras.tree.map_structure(convert, x)
|
205
|
+
|
206
|
+
|
26
207
|
def _decode_strings_to_utf8(inputs):
|
27
208
|
"""Recursively decodes to list of strings with 'utf-8' encoding."""
|
28
209
|
if isinstance(inputs, bytes):
|
@@ -52,75 +233,15 @@ def tensor_to_list(inputs):
|
|
52
233
|
return list_outputs
|
53
234
|
|
54
235
|
|
55
|
-
def convert_to_backend_tensor_or_python_list(x):
|
56
|
-
"""
|
57
|
-
Convert a tensor to the backend friendly representation of the data.
|
58
|
-
|
59
|
-
This wraps `ops.convert_to_tensor` to account for the fact that torch and
|
60
|
-
jax both lack native types for ragged and string data.
|
61
|
-
|
62
|
-
If we encounter one of these types in torch or jax, we will instead covert
|
63
|
-
the tensor to simple pythonic types (lists of strings).
|
64
|
-
"""
|
65
|
-
if isinstance(x, tf.RaggedTensor) or getattr(x, "dtype", None) == tf.string:
|
66
|
-
return tensor_to_list(x)
|
67
|
-
dtype = getattr(x, "dtype", "float32")
|
68
|
-
dtype = keras.backend.standardize_dtype(dtype)
|
69
|
-
return ops.convert_to_tensor(x, dtype=dtype)
|
70
|
-
|
71
|
-
|
72
236
|
def convert_to_ragged_batch(inputs):
    """Ensure a tf.Tensor is a ragged rank 2 tensor.

    Non-tf inputs are first passed through `tf.convert_to_tensor`.

    Returns:
        An `(inputs, unbatched, rectangular)` tuple. `inputs` is a
        `tf.RaggedTensor`, with a leading batch axis added if the input was
        rank 1. `unbatched` is True if the input was rank 1. `rectangular`
        is True if the input was a dense `tf.Tensor` rather than a
        `tf.RaggedTensor`.
    """
    if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
        inputs = tf.convert_to_tensor(inputs)
    is_dense = isinstance(inputs, tf.Tensor)
    unbatched = inputs.shape.rank == 1
    if unbatched:
        inputs = tf.expand_dims(inputs, 0)
    if is_dense:
        inputs = tf.RaggedTensor.from_tensor(inputs)
    return inputs, unbatched, is_dense
|
126
247
|
|
@@ -135,10 +256,7 @@ def truncate_at_token(inputs, token, mask):
|
|
135
256
|
|
136
257
|
def strip_to_ragged(token_ids, mask, ids_to_strip):
    """Remove masked and special tokens from a sequence before detokenizing.

    A position is kept only where `mask` is truthy and the token id differs
    from every id in `ids_to_strip`. The result is a `tf.RaggedTensor` built
    with `tf.ragged.boolean_mask`.
    """
    keep = tf.cast(mask, "bool")
    for special_id in ids_to_strip:
        keep = keep & (token_ids != special_id)
    return tf.ragged.boolean_mask(token_ids, keep)
|
@@ -13,10 +13,9 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
import numpy as np
|
15
15
|
|
16
|
-
from keras_hub.src.
|
17
|
-
|
18
|
-
|
19
|
-
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
16
|
+
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
|
17
|
+
|
18
|
+
# The Keras backbone class this converter targets; `TimmPresetLoader`
# returns it from `check_backbone_class` via `self.converter.backbone_cls`.
backbone_cls = ResNetBackbone
|
20
19
|
|
21
20
|
|
22
21
|
def convert_backbone_config(timm_config):
|
@@ -56,6 +55,8 @@ def convert_backbone_config(timm_config):
|
|
56
55
|
stackwise_num_strides=[1, 2, 2, 2],
|
57
56
|
block_type=block_type,
|
58
57
|
use_pre_activation=use_pre_activation,
|
58
|
+
input_conv_filters=[64],
|
59
|
+
input_conv_kernel_sizes=[7],
|
59
60
|
)
|
60
61
|
|
61
62
|
|
@@ -100,10 +101,10 @@ def convert_weights(backbone, loader, timm_config):
|
|
100
101
|
for stack_index in range(num_stacks):
|
101
102
|
for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
|
102
103
|
if version == "v1":
|
103
|
-
keras_name = f"
|
104
|
+
keras_name = f"stack{stack_index}_block{block_idx}"
|
104
105
|
hf_name = f"layer{stack_index+1}.{block_idx}"
|
105
106
|
else:
|
106
|
-
keras_name = f"
|
107
|
+
keras_name = f"stack{stack_index}_block{block_idx}"
|
107
108
|
hf_name = f"stages.{stack_index}.blocks.{block_idx}"
|
108
109
|
|
109
110
|
if version == "v1":
|
@@ -159,13 +160,15 @@ def convert_weights(backbone, loader, timm_config):
|
|
159
160
|
normalization_layer.build(normalization_layer._build_input_shape)
|
160
161
|
|
161
162
|
|
162
|
-
def
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
163
|
+
def convert_head(task, loader, timm_config):
    """Port a timm ResNet classification head into `task.output_dense`.

    Architectures containing "resnetv2_" read weights under the `head.fc.`
    prefix; all others read them under `fc.`.

    Args:
        task: A task exposing an `output_dense` layer with `kernel`/`bias`.
        loader: A weight loader with a `port_weight` method.
        timm_config: The timm config dict; only "architecture" is read.
    """
    v2 = "resnetv2_" in timm_config["architecture"]
    prefix = "head.fc." if v2 else "fc."

    def _kernel_hook(x, _):
        # Collapse to 2-D (num_classes, in_features) and transpose to the
        # Keras Dense layout (in_features, num_classes). An explicit reshape
        # is used instead of `np.squeeze`, which would also drop the class
        # axis of a single-class head. Presumably v2 heads carry trailing
        # singleton (1x1 conv) dims — confirm against the checkpoint.
        return np.transpose(np.reshape(x, (x.shape[0], x.shape[1])))

    loader.port_weight(
        task.output_dense.kernel,
        hf_weight_key=prefix + "weight",
        hook_fn=_kernel_hook,
    )
    loader.port_weight(
        task.output_dense.bias,
        hf_weight_key=prefix + "bias",
    )
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Convert timm models to KerasHub."""
|
15
|
+
|
16
|
+
from keras_hub.src.models.image_classifier import ImageClassifier
|
17
|
+
from keras_hub.src.utils.preset_utils import PresetLoader
|
18
|
+
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
|
19
|
+
from keras_hub.src.utils.timm import convert_resnet
|
20
|
+
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
21
|
+
|
22
|
+
|
23
|
+
class TimmPresetLoader(PresetLoader):
    """Preset loader that builds KerasHub models from timm checkpoints.

    The converter module is chosen from `config["architecture"]`; currently
    only architectures containing "resnet" are supported.
    """

    def __init__(self, preset, config):
        super().__init__(preset, config)
        architecture = self.config["architecture"]
        if "resnet" in architecture:
            self.converter = convert_resnet
        else:
            raise ValueError(
                "KerasHub has no converter for timm models "
                f"with architecture `'{architecture}'`."
            )

    def check_backbone_class(self):
        return self.converter.backbone_cls

    def load_backbone(self, cls, load_weights, **kwargs):
        keras_config = self.converter.convert_backbone_config(self.config)
        # User kwargs override values derived from the timm config.
        backbone = cls(**{**keras_config, **kwargs})
        if load_weights:
            jax_memory_cleanup(backbone)
            # Use prefix="" to avoid using `get_prefixed_key`.
            with SafetensorLoader(self.preset, prefix="") as loader:
                self.converter.convert_weights(backbone, loader, self.config)
        return backbone

    def load_task(self, cls, load_weights, load_task_weights, **kwargs):
        if not load_task_weights or not issubclass(cls, ImageClassifier):
            return super().load_task(
                cls, load_weights, load_task_weights, **kwargs
            )
        # Support loading the classification head for classifier models.
        # `load_task_weights` is known to be truthy past the guard above.
        kwargs["num_classes"] = self.config["num_classes"]
        task = super().load_task(cls, load_weights, load_task_weights, **kwargs)
        with SafetensorLoader(self.preset, prefix="") as loader:
            self.converter.convert_head(task, loader, self.config)
        return task

    def load_image_converter(self, cls, **kwargs):
        pretrained_cfg = self.config.get("pretrained_cfg", None)
        if not pretrained_cfg or "input_size" not in pretrained_cfg:
            return None
        # timm stores `input_size` channels-first as (channels, height,
        # width), so index 1 is the height and index 2 is the width.
        input_size = pretrained_cfg["input_size"]
        return cls(width=input_size[2], height=input_size[1])
|
@@ -0,0 +1,193 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import numpy as np
|
15
|
+
|
16
|
+
from keras_hub.src.models.albert.albert_backbone import AlbertBackbone
|
17
|
+
from keras_hub.src.utils.preset_utils import get_file
|
18
|
+
|
19
|
+
# The Keras backbone class this converter targets. Presumably read by the
# transformers preset loader (not shown here) as `converter.backbone_cls`,
# mirroring the timm loader — confirm against that loader.
backbone_cls = AlbertBackbone
|
20
|
+
|
21
|
+
|
22
|
+
def convert_backbone_config(transformers_config):
    """Translate a Hugging Face ALBERT config dict into backbone kwargs.

    Each KerasHub argument name is paired with the transformers config key it
    is read from; a missing key raises `KeyError`.
    """
    keras_to_hf = (
        ("vocabulary_size", "vocab_size"),
        ("num_layers", "num_hidden_layers"),
        ("num_heads", "num_attention_heads"),
        ("embedding_dim", "embedding_size"),
        ("hidden_dim", "hidden_size"),
        ("intermediate_dim", "intermediate_size"),
        ("num_groups", "num_hidden_groups"),
        ("num_inner_repetitions", "inner_group_num"),
        ("dropout", "attention_probs_dropout_prob"),
        ("max_sequence_length", "max_position_embeddings"),
        ("num_segments", "type_vocab_size"),
    )
    return {
        keras_name: transformers_config[hf_name]
        for keras_name, hf_name in keras_to_hf
    }
|
36
|
+
|
37
|
+
|
38
|
+
def convert_weights(backbone, loader, transformers_config):
    """Port every ALBERT checkpoint tensor into `backbone` through `loader`.

    Each copy is a call to ``loader.port_weight(keras_variable=...,
    hf_weight_key=...)``, optionally with ``hook_fn=...``. When given,
    ``hook_fn`` is a callable ``(hf_tensor, keras_shape)`` that transposes
    and/or reshapes the Hugging Face tensor. NOTE(review): presumably the
    loader applies it before assigning the value — confirm in the loader.

    Tensors are ported in this fixed order:

    1. token, position and segment embeddings, then the embedding layer norm;
    2. the embedding-to-hidden projection;
    3. every ``group_{g}_inner_layer_{i}`` encoder layer, for ``g`` in
       ``range(backbone.num_groups)`` and ``i`` in
       ``range(backbone.num_inner_repetitions)``;
    4. the pooler.

    Args:
        backbone: Target backbone exposing the embedding, projection and
            pooler sub-layers referenced below, plus ``num_groups``,
            ``num_inner_repetitions`` and ``get_layer(name)``.
        loader: Object providing the ``port_weight`` method described above.
        transformers_config: Not read by this function. NOTE(review):
            presumably kept so every converter shares one signature.
    """

    def swap_axes(hf_tensor, keras_shape):
        # Plain 2-D transpose. Presumably turns a torch Linear weight laid
        # out as (out, in) into a Keras kernel laid out as (in, out).
        return np.transpose(hf_tensor, axes=(1, 0))

    def swap_axes_then_reshape(hf_tensor, keras_shape):
        # Transpose, then fold into the target variable's shape
        # (used for the attention projection kernels).
        return np.reshape(np.transpose(hf_tensor), keras_shape)

    def reshape_only(hf_tensor, keras_shape):
        # Reshape into the target variable's shape (attention biases).
        return np.reshape(hf_tensor, keras_shape)

    def port(owner, attr_path, hf_key, hook=None):
        # Resolve a dotted attribute path such as "query_dense.kernel".
        variable = owner
        for attr_name in attr_path.split("."):
            variable = getattr(variable, attr_name)
        if hook is None:
            # No re-layout needed: leave hook_fn to the loader's default.
            loader.port_weight(keras_variable=variable, hf_weight_key=hf_key)
        else:
            loader.port_weight(
                keras_variable=variable, hf_weight_key=hf_key, hook_fn=hook
            )

    # Per-encoder-layer table, in porting order:
    # (attribute path on the Keras layer, HF key suffix, hook or None).
    layer_weights = []
    for keras_proj, hf_proj in (
        ("query_dense", "query"),
        ("key_dense", "key"),
        ("value_dense", "value"),
        ("output_dense", "dense"),
    ):
        attr_root = f"_self_attention_layer.{keras_proj}"
        layer_weights.append(
            (
                f"{attr_root}.kernel",
                f"attention.{hf_proj}.weight",
                swap_axes_then_reshape,
            )
        )
        layer_weights.append(
            (f"{attr_root}.bias", f"attention.{hf_proj}.bias", reshape_only)
        )
    layer_weights += [
        ("_self_attention_layer_norm.gamma", "attention.LayerNorm.weight", None),
        ("_self_attention_layer_norm.beta", "attention.LayerNorm.bias", None),
        ("_feedforward_intermediate_dense.kernel", "ffn.weight", swap_axes),
        ("_feedforward_intermediate_dense.bias", "ffn.bias", None),
        ("_feedforward_output_dense.kernel", "ffn_output.weight", swap_axes),
        ("_feedforward_output_dense.bias", "ffn_output.bias", None),
        ("_feedforward_layer_norm.gamma", "full_layer_layer_norm.weight", None),
        ("_feedforward_layer_norm.beta", "full_layer_layer_norm.bias", None),
    ]

    # Embedding tables, then the layer norm applied to their sum.
    for attr_path, hf_name in (
        ("token_embedding.embeddings", "word_embeddings.weight"),
        ("position_embedding.position_embeddings", "position_embeddings.weight"),
        ("segment_embedding.embeddings", "token_type_embeddings.weight"),
        ("embeddings_layer_norm.gamma", "LayerNorm.weight"),
        ("embeddings_layer_norm.beta", "LayerNorm.bias"),
    ):
        port(backbone, attr_path, f"albert.embeddings.{hf_name}")

    # Embedding-to-hidden projection (HF `embedding_hidden_mapping_in`).
    mapping_key = "albert.encoder.embedding_hidden_mapping_in"
    port(
        backbone,
        "embeddings_projection.kernel",
        f"{mapping_key}.weight",
        swap_axes,
    )
    port(backbone, "embeddings_projection.bias", f"{mapping_key}.bias")

    # Encoder layers, addressed by group and inner-repetition index.
    for group_idx in range(backbone.num_groups):
        for inner_idx in range(backbone.num_inner_repetitions):
            encoder_layer = backbone.get_layer(
                f"group_{group_idx}_inner_layer_{inner_idx}"
            )
            hf_prefix = (
                f"albert.encoder.albert_layer_groups.{group_idx}"
                f".albert_layers.{inner_idx}."
            )
            for attr_path, hf_suffix, hook in layer_weights:
                port(encoder_layer, attr_path, hf_prefix + hf_suffix, hook)

    # Pooler dense layer.
    port(backbone, "pooled_dense.kernel", "albert.pooler.weight", swap_axes)
    port(backbone, "pooled_dense.bias", "albert.pooler.bias")
|
190
|
+
|
191
|
+
|
192
|
+
def convert_tokenizer(cls, preset, **kwargs):
    """Build tokenizer class `cls` from the preset's `spiece.model` file.

    Args:
        cls: Tokenizer class to construct. It receives whatever `get_file`
            returns for `spiece.model` (presumably a local path to a
            SentencePiece proto — confirm) as its first positional argument.
        preset: Preset identifier, passed straight through to `get_file`.
        **kwargs: Forwarded unchanged to `cls`.

    Returns:
        The newly constructed `cls` instance.
    """
    spiece_file = get_file(preset, "spiece.model")
    return cls(spiece_file, **kwargs)
|