keras-hub-nightly 0.15.0.dev20240823171555-py3-none-any.whl → 0.16.0.dev2024092017-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +2 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/api/utils/__init__.py +22 -0
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +46 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +230 -68
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
keras_hub/src/models/backbone.py
CHANGED
````diff
@@ -20,18 +20,11 @@ from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.utils.keras_utils import assert_quantization_support
 from keras_hub.src.utils.preset_utils import CONFIG_FILE
 from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
-from keras_hub.src.utils.preset_utils import check_config_class
-from keras_hub.src.utils.preset_utils import check_format
-from keras_hub.src.utils.preset_utils import get_file
-from keras_hub.src.utils.preset_utils import jax_memory_cleanup
-from keras_hub.src.utils.preset_utils import list_presets
-from keras_hub.src.utils.preset_utils import list_subclasses
-from keras_hub.src.utils.preset_utils import load_serialized_object
+from keras_hub.src.utils.preset_utils import builtin_presets
+from keras_hub.src.utils.preset_utils import get_preset_loader
 from keras_hub.src.utils.preset_utils import save_metadata
 from keras_hub.src.utils.preset_utils import save_serialized_object
 from keras_hub.src.utils.python_utils import classproperty
-from keras_hub.src.utils.timm.convert import load_timm_backbone
-from keras_hub.src.utils.transformers.convert import load_transformers_backbone
 
 
 @keras_hub_export("keras_hub.models.Backbone")
@@ -147,11 +140,8 @@ class Backbone(keras.Model):
 
     @classproperty
     def presets(cls):
-        """List built-in presets for a `Backbone` subclass."""
-        presets = list_presets(cls)
-        for subclass in list_subclasses(cls):
-            presets.update(subclass.presets)
-        return presets
+        """List built-in presets for a `Backbone` subclass."""
+        return builtin_presets(cls)
 
     @classmethod
     def from_preset(
@@ -166,7 +156,7 @@ class Backbone(keras.Model):
         to save and load a pre-trained model. The `preset` can be passed as a
         one of:
 
-        1. a built in preset identifier like `'bert_base_en'`
+        1. a built-in preset identifier like `'bert_base_en'`
         2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
         3. a Hugging Face handle like `'hf://user/bert_base_en'`
         4. a path to a local preset directory like `'./bert_base_en'`
@@ -181,7 +171,7 @@ class Backbone(keras.Model):
         all built-in presets available on the class.
 
         Args:
-            preset: string. A built in preset identifier, a Kaggle Models
+            preset: string. A built-in preset identifier, a Kaggle Models
                 handle, a Hugging Face handle, or a path to a local directory.
             load_weights: bool. If `True`, the weights will be loaded into the
                 model architecture. If `False`, the weights will be randomly
@@ -201,27 +191,15 @@ class Backbone(keras.Model):
         )
         ```
         """
-        format = check_format(preset)
-
-        if format == "transformers":
-            return load_transformers_backbone(cls, preset, load_weights)
-        elif format == "timm":
-            return load_timm_backbone(cls, preset, load_weights, **kwargs)
-
-        preset_cls = check_config_class(preset)
-        if not issubclass(preset_cls, cls):
+        loader = get_preset_loader(preset)
+        backbone_cls = loader.check_backbone_class()
+        if not issubclass(backbone_cls, cls):
             raise ValueError(
-                f"Preset has type `{preset_cls.__name__}` which is not "
+                f"Saved preset has type `{backbone_cls.__name__}` which is not "
                 f"a subclass of calling class `{cls.__name__}`. Call "
-                f"`from_preset` directly on `{preset_cls.__name__}` instead."
+                f"`from_preset` directly on `{backbone_cls.__name__}` instead."
             )
-
-        backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
-        if load_weights:
-            jax_memory_cleanup(backbone)
-            backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
-
-        return backbone
+        return loader.load_backbone(backbone_cls, load_weights, **kwargs)
 
     def save_to_preset(self, preset_dir):
         """Save backbone to a preset directory.
````
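The net effect of this change is that `Backbone.from_preset` no longer special-cases `timm` and `transformers` checkpoints inline; format detection and loading are delegated to a single loader object returned by `get_preset_loader` (see the new `timm/preset_loader.py` and `transformers/preset_loader.py` files in the list above). A minimal sketch of the calling convention, using the preset names from the docstring above:

```python
import keras_hub

# Resolves the preset source (built-in, Kaggle, Hugging Face, local dir)
# behind a single loader object.
backbone = keras_hub.models.Backbone.from_preset("bert_base_en")

# Calling from_preset() on a subclass type-checks the saved preset and
# raises a ValueError if it is not a matching backbone class.
backbone = keras_hub.models.BertBackbone.from_preset(
    "bert_base_en",
    load_weights=True,
)
```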
keras_hub/src/models/bart/__init__.py
CHANGED
````diff
@@ -14,7 +14,6 @@
 
 from keras_hub.src.models.bart.bart_backbone import BartBackbone
 from keras_hub.src.models.bart.bart_presets import backbone_presets
-from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
 from keras_hub.src.utils.preset_utils import register_presets
 
-register_presets(backbone_presets, (BartBackbone, BartTokenizer))
+register_presets(backbone_presets, BartBackbone)
````
keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py
CHANGED
````diff
@@ -13,24 +13,15 @@
 # limitations under the License.
 
 
-import keras
-from absl import logging
-
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.models.bart.bart_preprocessor import BartPreprocessor
-from keras_hub.src.utils.keras_utils import (
-    convert_inputs_to_list_of_tensor_segments,
-)
-from keras_hub.src.utils.tensor_utils import strip_to_ragged
-
-try:
-    import tensorflow as tf
-except ImportError:
-    tf = None
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.bart.bart_backbone import BartBackbone
+from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
+from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
 
 
 @keras_hub_export("keras_hub.models.BartSeq2SeqLMPreprocessor")
-class BartSeq2SeqLMPreprocessor(BartPreprocessor):
+class BartSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor):
     """BART Seq2Seq LM preprocessor.
 
     This layer is used as preprocessor for seq2seq tasks using the BART model.
@@ -125,138 +116,20 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor):
     ```
     """
 
-    def call(
-        self,
-        x,
-        y=None,
-        sample_weight=None,
-        *,
-        encoder_sequence_length=None,
-        # `sequence_length` is an alias for `decoder_sequence_length`
-        decoder_sequence_length=None,
-        sequence_length=None,
-    ):
-        if y is not None or sample_weight is not None:
-            logging.warning(
-                ...
-                "These values will be ignored."
-            )
-
-        if encoder_sequence_length is None:
-            encoder_sequence_length = self.encoder_sequence_length
-        decoder_sequence_length = decoder_sequence_length or sequence_length
-        if decoder_sequence_length is None:
-            decoder_sequence_length = self.decoder_sequence_length
-
-        x = super().call(
-            x,
-            encoder_sequence_length=encoder_sequence_length,
-            decoder_sequence_length=decoder_sequence_length + 1,
-        )
-        decoder_token_ids = x.pop("decoder_token_ids")
-        decoder_padding_mask = x.pop("decoder_padding_mask")
-
-        # The last token does not have a next token. Hence, we truncate it.
-        x = {
-            **x,
-            "decoder_token_ids": decoder_token_ids[..., :-1],
-            "decoder_padding_mask": decoder_padding_mask[..., :-1],
-        }
-        # Target `y` will be the decoder input sequence shifted one step to the
-        # left (i.e., the next token).
-        y = decoder_token_ids[..., 1:]
-        sample_weight = decoder_padding_mask[..., 1:]
-        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
-
-    def generate_preprocess(
-        self,
-        x,
-        *,
-        encoder_sequence_length=None,
-        # `sequence_length` is an alias for `decoder_sequence_length`
-        decoder_sequence_length=None,
-        sequence_length=None,
-    ):
-        """Convert encoder and decoder input strings to integer token inputs for generation.
-
-        Similar to calling the layer for training, this method takes in a dict
-        containing `"encoder_text"` and `"decoder_text"`, with strings or tensor
-        strings for values, tokenizes and packs the input, and computes a
-        padding mask masking all inputs not filled in with a padded value.
-
-        Unlike calling the layer for training, this method does not compute
-        labels and will never append a tokenizer.end_token_id to the end of
-        the decoder sequence (as generation is expected to continue at the end
-        of the inputted decoder prompt).
-        """
-        if not self.built:
-            self.build(None)
-
-        if isinstance(x, dict):
-            encoder_text = x["encoder_text"]
-            decoder_text = x["decoder_text"]
-        else:
-            encoder_text = x
-            # Initialize empty prompt for the decoder.
-            decoder_text = tf.fill((tf.shape(encoder_text)[0],), "")
-
-        if encoder_sequence_length is None:
-            encoder_sequence_length = self.encoder_sequence_length
-        decoder_sequence_length = decoder_sequence_length or sequence_length
-        if decoder_sequence_length is None:
-            decoder_sequence_length = self.decoder_sequence_length
-
-        # Tokenize and pack the encoder inputs.
-        # TODO: Remove `[0]` once we have shifted to `MultiSegmentPacker`.
-        encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)[
-            0
-        ]
-        encoder_token_ids = self.tokenizer(encoder_text)
-        encoder_token_ids, encoder_padding_mask = self.encoder_packer(
-            encoder_token_ids,
-            sequence_length=encoder_sequence_length,
-        )
-
-        # Tokenize and pack the decoder inputs.
-        decoder_text = convert_inputs_to_list_of_tensor_segments(decoder_text)[
-            0
-        ]
-        decoder_token_ids = self.tokenizer(decoder_text)
-        decoder_token_ids, decoder_padding_mask = self.decoder_packer(
-            decoder_token_ids,
-            sequence_length=decoder_sequence_length,
-            add_end_value=False,
-        )
-
-        return {
-            "encoder_token_ids": encoder_token_ids,
-            "encoder_padding_mask": encoder_padding_mask,
-            "decoder_token_ids": decoder_token_ids,
-            "decoder_padding_mask": decoder_padding_mask,
-        }
-
-    def generate_postprocess(
-        self,
-        x,
-    ):
-        """Convert integer token output to strings for generation.
-
-        This method reverses `generate_preprocess()`, by first removing all
-        padding and start/end tokens, and then converting the integer sequence
-        back to a string.
-        """
-        if not self.built:
-            self.build(None)
-
-        token_ids, padding_mask = (
-            x["decoder_token_ids"],
-            x["decoder_padding_mask"],
-        )
-        ids_to_strip = (
-            self.tokenizer.start_token_id,
-            self.tokenizer.end_token_id,
+    backbone_cls = BartBackbone
+    tokenizer_cls = BartTokenizer
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        # The decoder is packed a bit differently; the format is as follows:
+        # `[end_token_id, start_token_id, tokens..., end_token_id, padding...]`.
+        self.decoder_packer = StartEndPacker(
+            start_value=[
+                self.tokenizer.end_token_id,
+                self.tokenizer.start_token_id,
+            ],
+            end_value=self.tokenizer.end_token_id,
+            pad_value=self.tokenizer.pad_token_id,
+            sequence_length=self.decoder_sequence_length,
+            return_padding_mask=True,
         )
-        token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
-        return self.tokenizer.detokenize(token_ids)
````
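After this rewrite the BART-specific class only pins `backbone_cls`/`tokenizer_cls` and overrides `build()` for BART's unusual decoder packing; the `call`, `generate_preprocess`, and `generate_postprocess` logic moves to the shared `Seq2SeqLMPreprocessor` base (added above with +269 lines). A hedged usage sketch; the preset name is assumed for illustration, and the `(x, y, sample_weight)` contract follows the removed code rather than anything shown in the new base class:

```python
import keras_hub

preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",  # assumed preset name, for illustration
)
# Training-style call: tokenizes and packs both segments, then derives the
# labels as the decoder ids shifted one step left (next-token targets).
x, y, sample_weight = preprocessor(
    {
        "encoder_text": "The fox was sleeping.",
        "decoder_text": "The fox was awake.",
    }
)
```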
keras_hub/src/models/bart/bart_tokenizer.py
CHANGED
````diff
@@ -14,10 +14,16 @@
 
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.bart.bart_backbone import BartBackbone
 from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
 
 
-@keras_hub_export("keras_hub.models.BartTokenizer")
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.BartTokenizer",
+        "keras_hub.models.BartTokenizer",
+    ]
+)
 class BartTokenizer(BytePairTokenizer):
     """A BART tokenizer using Byte-Pair Encoding subword segmentation.
 
@@ -73,52 +79,19 @@ class BartTokenizer(BytePairTokenizer):
     ```
     """
 
+    backbone_cls = BartBackbone
+
     def __init__(
         self,
         vocabulary=None,
         merges=None,
         **kwargs,
     ):
-        self.start_token = "<s>"
-        self.pad_token = "<pad>"
-        self.end_token = "</s>"
-
+        self._add_special_token("<s>", "start_token")
+        self._add_special_token("</s>", "end_token")
+        self._add_special_token("<pad>", "pad_token")
         super().__init__(
             vocabulary=vocabulary,
             merges=merges,
-            unsplittable_tokens=[
-                self.start_token,
-                self.pad_token,
-                self.end_token,
-            ],
             **kwargs,
         )
-
-    def set_vocabulary_and_merges(self, vocabulary, merges):
-        super().set_vocabulary_and_merges(vocabulary, merges)
-
-        if vocabulary is not None:
-            # Check for necessary special tokens.
-            for token in [self.start_token, self.pad_token, self.end_token]:
-                if token not in self.vocabulary:
-                    raise ValueError(
-                        f"Cannot find token `'{token}'` in the provided "
-                        f"`vocabulary`. Please provide `'{token}'` in your "
-                        "`vocabulary` or use a pretrained `vocabulary` name."
-                    )
-
-            self.start_token_id = self.token_to_id(self.start_token)
-            self.pad_token_id = self.token_to_id(self.pad_token)
-            self.end_token_id = self.token_to_id(self.end_token)
-        else:
-            self.start_token_id = None
-            self.pad_token_id = None
-            self.end_token_id = None
-
-    def get_config(self):
-        config = super().get_config()
-        # In the constructor, we pass the list of special tokens to the
-        # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
-        # delete it from the config here.
-        del config["unsplittable_tokens"]
-        return config
````
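Two patterns from the broader refactor show up here: tokenizers now register their special tokens through the base class's `_add_special_token` helper (replacing hand-rolled vocabulary checks and `get_config` fixups), and each tokenizer is exported under the new `keras_hub.tokenizers` namespace alongside the old `keras_hub.models` path. A small sketch, assuming a BART preset name:

```python
import keras_hub

# Both export paths point at the same class after this change.
tokenizer = keras_hub.tokenizers.BartTokenizer.from_preset("bart_base_en")

# Special-token ids are resolved by the BytePairTokenizer base once the
# vocabulary is loaded, instead of by a local set_vocabulary_and_merges.
token_ids = tokenizer("The quick brown fox.")
print(tokenizer.start_token_id, tokenizer.end_token_id, tokenizer.pad_token_id)
```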
keras_hub/src/models/bert/__init__.py
CHANGED
````diff
@@ -13,11 +13,7 @@
 # limitations under the License.
 
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
-from keras_hub.src.models.bert.bert_classifier import BertClassifier
 from keras_hub.src.models.bert.bert_presets import backbone_presets
-from keras_hub.src.models.bert.bert_presets import classifier_presets
-from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
 from keras_hub.src.utils.preset_utils import register_presets
 
-register_presets(backbone_presets, (BertBackbone, BertTokenizer))
-register_presets(classifier_presets, (BertClassifier, BertTokenizer))
+register_presets(backbone_presets, BertBackbone)
````
keras_hub/src/models/bert/bert_masked_lm_preprocessor.py
CHANGED
````diff
@@ -12,18 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import keras
-from absl import logging
-
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import (
-    MaskedLMMaskGenerator,
-)
-from keras_hub.src.models.bert.bert_preprocessor import BertPreprocessor
+from keras_hub.src.models.bert.bert_backbone import BertBackbone
+from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
+from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor
 
 
 @keras_hub_export("keras_hub.models.BertMaskedLMPreprocessor")
-class BertMaskedLMPreprocessor(BertPreprocessor):
+class BertMaskedLMPreprocessor(MaskedLMPreprocessor):
     """BERT preprocessing for the masked language modeling task.
 
     This preprocessing layer will prepare inputs for a masked language modeling
@@ -117,82 +113,5 @@ class BertMaskedLMPreprocessor(BertPreprocessor):
     ```
     """
 
-    def __init__(
-        self,
-        tokenizer,
-        sequence_length=512,
-        truncate="round_robin",
-        mask_selection_rate=0.15,
-        mask_selection_length=96,
-        mask_token_rate=0.8,
-        random_token_rate=0.1,
-        **kwargs,
-    ):
-        super().__init__(
-            tokenizer,
-            sequence_length=sequence_length,
-            truncate=truncate,
-            **kwargs,
-        )
-        self.mask_selection_rate = mask_selection_rate
-        self.mask_selection_length = mask_selection_length
-        self.mask_token_rate = mask_token_rate
-        self.random_token_rate = random_token_rate
-        self.masker = None
-
-    def build(self, input_shape):
-        super().build(input_shape)
-        # Defer masker creation to `build()` so that we can be sure tokenizer
-        # assets have loaded when restoring a saved model.
-        self.masker = MaskedLMMaskGenerator(
-            mask_selection_rate=self.mask_selection_rate,
-            mask_selection_length=self.mask_selection_length,
-            mask_token_rate=self.mask_token_rate,
-            random_token_rate=self.random_token_rate,
-            vocabulary_size=self.tokenizer.vocabulary_size(),
-            mask_token_id=self.tokenizer.mask_token_id,
-            unselectable_token_ids=[
-                self.tokenizer.cls_token_id,
-                self.tokenizer.sep_token_id,
-                self.tokenizer.pad_token_id,
-            ],
-        )
-
-    def call(self, x, y=None, sample_weight=None):
-        if y is not None or sample_weight is not None:
-            logging.warning(
-                f"{self.__class__.__name__} generates `y` and `sample_weight` "
-                "based on your input data, but your data already contains `y` "
-                "or `sample_weight`. Your `y` and `sample_weight` will be "
-                "ignored."
-            )
-
-        x = super().call(x)
-
-        token_ids, padding_mask, segment_ids = (
-            x["token_ids"],
-            x["padding_mask"],
-            x["segment_ids"],
-        )
-        masker_outputs = self.masker(token_ids)
-        x = {
-            "token_ids": masker_outputs["token_ids"],
-            "padding_mask": padding_mask,
-            "segment_ids": segment_ids,
-            "mask_positions": masker_outputs["mask_positions"],
-        }
-        y = masker_outputs["mask_ids"]
-        sample_weight = masker_outputs["mask_weights"]
-        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "mask_selection_rate": self.mask_selection_rate,
-                "mask_selection_length": self.mask_selection_length,
-                "mask_token_rate": self.mask_token_rate,
-                "random_token_rate": self.random_token_rate,
-            }
-        )
-        return config
+    backbone_cls = BertBackbone
+    tokenizer_cls = BertTokenizer
````
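The masking machinery (the deferred `MaskedLMMaskGenerator`, the `call()` that swaps in mask tokens and emits `(x, y, sample_weight)`, and the matching `get_config`) moves wholesale into the new `MaskedLMPreprocessor` base (+156 lines in the list above); the BERT subclass is reduced to binding its backbone and tokenizer classes. A sketch assuming the constructor arguments kept their old names and defaults in the base class:

```python
import keras_hub

preprocessor = keras_hub.models.BertMaskedLMPreprocessor.from_preset(
    "bert_base_en_uncased",
    sequence_length=256,
    mask_selection_rate=0.15,  # assumed to carry over to the base class
)
# x carries the masked token ids and mask positions; y holds the original
# ids at the masked positions, weighted by sample_weight.
x, y, sample_weight = preprocessor("The quick brown fox jumped.")
```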
keras_hub/src/models/bert/bert_presets.py
CHANGED
````diff
@@ -129,9 +129,6 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/bert/keras/bert_large_en/2",
     },
-}
-
-classifier_presets = {
     "bert_tiny_en_uncased_sst2": {
         "metadata": {
             "description": (
@@ -143,5 +140,5 @@ classifier_presets = {
            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
        },
        "kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/4",
-    }
+    },
 }
````
keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py}
CHANGED
````diff
@@ -17,12 +17,19 @@ import keras
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_backbone import bert_kernel_initializer
-from keras_hub.src.models.bert.bert_preprocessor import BertPreprocessor
-from keras_hub.src.models.classifier import Classifier
-
-
-@keras_hub_export("keras_hub.models.BertClassifier")
-class BertClassifier(Classifier):
+from keras_hub.src.models.bert.bert_text_classifier_preprocessor import (
+    BertTextClassifierPreprocessor,
+)
+from keras_hub.src.models.text_classifier import TextClassifier
+
+
+@keras_hub_export(
+    [
+        "keras_hub.models.BertTextClassifier",
+        "keras_hub.models.BertClassifier",
+    ]
+)
+class BertTextClassifier(TextClassifier):
     """An end-to-end BERT model for classification tasks.
 
     This model attaches a classification head to a
@@ -41,7 +48,7 @@ class BertClassifier(Classifier):
     Args:
         backbone: A `keras_hub.models.BertBackbone` instance.
         num_classes: int. Number of classes to predict.
-        preprocessor: A `keras_hub.models.BertPreprocessor` or `None`. If
+        preprocessor: A `keras_hub.models.BertTextClassifierPreprocessor` or `None`. If
             `None`, this model will not apply preprocessing, and inputs should
             be preprocessed before calling the model.
         activation: Optional `str` or callable. The
@@ -59,7 +66,7 @@ class BertClassifier(Classifier):
     labels = [0, 3]
 
     # Pretrained classifier.
-    classifier = keras_hub.models.BertClassifier.from_preset(
+    classifier = keras_hub.models.BertTextClassifier.from_preset(
         "bert_base_en_uncased",
         num_classes=4,
     )
@@ -88,7 +95,7 @@ class BertClassifier(Classifier):
     labels = [0, 3]
 
     # Pretrained classifier without preprocessing.
-    classifier = keras_hub.models.BertClassifier.from_preset(
+    classifier = keras_hub.models.BertTextClassifier.from_preset(
         "bert_base_en_uncased",
         num_classes=4,
         preprocessor=None,
@@ -106,7 +113,7 @@ class BertClassifier(Classifier):
     tokenizer = keras_hub.models.BertTokenizer(
         vocabulary=vocab,
     )
-    preprocessor = keras_hub.models.BertPreprocessor(
+    preprocessor = keras_hub.models.BertTextClassifierPreprocessor(
         tokenizer=tokenizer,
         sequence_length=128,
     )
@@ -118,7 +125,7 @@ class BertClassifier(Classifier):
         intermediate_dim=512,
         max_sequence_length=128,
     )
-    classifier = keras_hub.models.BertClassifier(
+    classifier = keras_hub.models.BertTextClassifier(
         backbone=backbone,
         preprocessor=preprocessor,
         num_classes=4,
@@ -128,7 +135,7 @@ class BertClassifier(Classifier):
     """
 
     backbone_cls = BertBackbone
-    preprocessor_cls = BertPreprocessor
+    preprocessor_cls = BertTextClassifierPreprocessor
 
     def __init__(
         self,
````