keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +2 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/api/utils/__init__.py +22 -0
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +46 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +230 -68
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
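The listing above shows two systematic renames across the library: task classes move from `*Classifier` to `*TextClassifier`, and the monolithic per-model `*_preprocessor.py` modules are replaced by task-specific preprocessors. A minimal migration sketch, assuming the 0.16 nightly keeps the old names as exported aliases (as the DeBERTa diff below does):

```python
import keras_hub

# 0.15 naming:
#   classifier = keras_hub.models.BertClassifier.from_preset("bert_base_en")
# 0.16 naming (the old symbol remains where both names are exported):
classifier = keras_hub.models.BertTextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
classifier.predict(["What an amazing movie!"])
```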
--- a/keras_hub/src/models/csp_darknet/csp_darknet_backbone.py
+++ b/keras_hub/src/models/csp_darknet/csp_darknet_backbone.py
@@ -15,11 +15,11 @@ import keras
 from keras import layers
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
 
 
 @keras_hub_export("keras_hub.models.CSPDarkNetBackbone")
-class CSPDarkNetBackbone(Backbone):
+class CSPDarkNetBackbone(FeaturePyramidBackbone):
     """This class represents Keras Backbone of CSPDarkNet model.
 
     This class implements a CSPDarkNet backbone as described in
@@ -65,12 +65,15 @@ class CSPDarkNetBackbone(Backbone):
         self,
         stackwise_num_filters,
         stackwise_depth,
-        include_rescaling,
+        include_rescaling=True,
         block_type="basic_block",
-        image_shape=(224, 224, 3),
+        image_shape=(None, None, 3),
         **kwargs,
     ):
         # === Functional Model ===
+        channel_axis = (
+            -1 if keras.config.image_data_format() == "channels_last" else 1
+        )
         apply_ConvBlock = (
             apply_darknet_conv_block_depthwise
             if block_type == "depthwise_block"
@@ -83,15 +86,22 @@ class CSPDarkNetBackbone(Backbone):
         if include_rescaling:
             x = layers.Rescaling(scale=1 / 255.0)(x)
 
-        x = apply_focus(name="stem_focus")(x)
+        x = apply_focus(channel_axis, name="stem_focus")(x)
         x = apply_darknet_conv_block(
-            base_channels, kernel_size=3, strides=1, name="stem_conv"
+            base_channels,
+            channel_axis,
+            kernel_size=3,
+            strides=1,
+            name="stem_conv",
         )(x)
+
+        pyramid_outputs = {}
         for index, (channels, depth) in enumerate(
             zip(stackwise_num_filters, stackwise_depth)
         ):
             x = apply_ConvBlock(
                 channels,
+                channel_axis,
                 kernel_size=3,
                 strides=2,
                 name=f"dark{index + 2}_conv",
@@ -100,17 +110,20 @@ class CSPDarkNetBackbone(Backbone):
             if index == len(stackwise_depth) - 1:
                 x = apply_spatial_pyramid_pooling_bottleneck(
                     channels,
+                    channel_axis,
                     hidden_filters=channels // 2,
                     name=f"dark{index + 2}_spp",
                 )(x)
 
             x = apply_cross_stage_partial(
                 channels,
+                channel_axis,
                 num_bottlenecks=depth,
                 block_type="basic_block",
                 residual=(index != len(stackwise_depth) - 1),
                 name=f"dark{index + 2}_csp",
             )(x)
+            pyramid_outputs[f"P{index + 2}"] = x
 
         super().__init__(inputs=image_input, outputs=x, **kwargs)
 
@@ -120,6 +133,7 @@ class CSPDarkNetBackbone(Backbone):
         self.include_rescaling = include_rescaling
         self.block_type = block_type
         self.image_shape = image_shape
+        self.pyramid_outputs = pyramid_outputs
 
     def get_config(self):
         config = super().get_config()
@@ -135,7 +149,7 @@ class CSPDarkNetBackbone(Backbone):
         return config
 
 
-def apply_focus(name=None):
+def apply_focus(channel_axis, name=None):
    """A block used in CSPDarknet to focus information into channels of the
    image.
 
@@ -151,7 +165,7 @@ def apply_focus(name=None):
    """
 
    def apply(x):
-        return layers.Concatenate(name=name)(
+        return layers.Concatenate(axis=channel_axis, name=name)(
            [
                x[..., ::2, ::2, :],
                x[..., 1::2, ::2, :],
@@ -164,7 +178,13 @@ def apply_focus(name=None):
 
 
 def apply_darknet_conv_block(
-    filters, kernel_size, strides, use_bias=False, activation="silu", name=None
+    filters,
+    channel_axis,
+    kernel_size,
+    strides,
+    use_bias=False,
+    activation="silu",
+    name=None,
 ):
    """
    The basic conv block used in Darknet. Applies Conv2D followed by a
@@ -193,11 +213,12 @@ def apply_darknet_conv_block(
            kernel_size,
            strides,
            padding="same",
+            data_format=keras.config.image_data_format(),
            use_bias=use_bias,
            name=name + "_conv",
        )(inputs)
 
-        x = layers.BatchNormalization(name=name + "_bn")(x)
+        x = layers.BatchNormalization(axis=channel_axis, name=name + "_bn")(x)
 
        if activation == "silu":
            x = layers.Lambda(lambda x: keras.activations.silu(x))(x)
@@ -212,7 +233,7 @@ def apply_darknet_conv_block(
 
 
 def apply_darknet_conv_block_depthwise(
-    filters, kernel_size, strides, activation="silu", name=None
+    filters, channel_axis, kernel_size, strides, activation="silu", name=None
 ):
    """
    The depthwise conv block used in CSPDarknet.
@@ -236,9 +257,13 @@ def apply_darknet_conv_block_depthwise(
 
    def apply(inputs):
        x = layers.DepthwiseConv2D(
-            kernel_size, strides, padding="same", use_bias=False
+            kernel_size,
+            strides,
+            padding="same",
+            data_format=keras.config.image_data_format(),
+            use_bias=False,
        )(inputs)
-        x = layers.BatchNormalization()(x)
+        x = layers.BatchNormalization(axis=channel_axis)(x)
 
        if activation == "silu":
            x = layers.Lambda(lambda x: keras.activations.swish(x))(x)
@@ -248,7 +273,11 @@ def apply_darknet_conv_block_depthwise(
        x = layers.LeakyReLU(0.1)(x)
 
        x = apply_darknet_conv_block(
-            filters, kernel_size=1, strides=1, activation=activation
+            filters,
+            channel_axis,
+            kernel_size=1,
+            strides=1,
+            activation=activation,
        )(x)
 
        return x
@@ -258,6 +287,7 @@ def apply_darknet_conv_block_depthwise(
 
 def apply_spatial_pyramid_pooling_bottleneck(
    filters,
+    channel_axis,
    hidden_filters=None,
    kernel_sizes=(5, 9, 13),
    activation="silu",
@@ -291,6 +321,7 @@ def apply_spatial_pyramid_pooling_bottleneck(
    def apply(x):
        x = apply_darknet_conv_block(
            hidden_filters,
+            channel_axis,
            kernel_size=1,
            strides=1,
            activation=activation,
@@ -304,13 +335,15 @@ def apply_spatial_pyramid_pooling_bottleneck(
                    kernel_size,
                    strides=1,
                    padding="same",
+                    data_format=keras.config.image_data_format(),
                    name=f"{name}_maxpool_{kernel_size}",
                )(x[0])
            )
 
-        x = layers.Concatenate(name=f"{name}_concat")(x)
+        x = layers.Concatenate(axis=channel_axis, name=f"{name}_concat")(x)
        x = apply_darknet_conv_block(
            filters,
+            channel_axis,
            kernel_size=1,
            strides=1,
            activation=activation,
@@ -324,6 +357,7 @@ def apply_spatial_pyramid_pooling_bottleneck(
 
 def apply_cross_stage_partial(
    filters,
+    channel_axis,
    num_bottlenecks,
    residual=True,
    block_type="basic_block",
@@ -361,6 +395,7 @@ def apply_cross_stage_partial(
 
        x1 = apply_darknet_conv_block(
            hidden_channels,
+            channel_axis,
            kernel_size=1,
            strides=1,
            activation=activation,
@@ -369,6 +404,7 @@ def apply_cross_stage_partial(
 
        x2 = apply_darknet_conv_block(
            hidden_channels,
+            channel_axis,
            kernel_size=1,
            strides=1,
            activation=activation,
@@ -379,6 +415,7 @@ def apply_cross_stage_partial(
            residual_x = x1
            x1 = apply_darknet_conv_block(
                hidden_channels,
+                channel_axis,
                kernel_size=1,
                strides=1,
                activation=activation,
@@ -386,6 +423,7 @@ def apply_cross_stage_partial(
            )(x1)
            x1 = ConvBlock(
                hidden_channels,
+                channel_axis,
                kernel_size=3,
                strides=1,
                activation=activation,
@@ -399,6 +437,7 @@ def apply_cross_stage_partial(
        x = layers.Concatenate(name=f"{name}_concat")([x1, x2])
        x = apply_darknet_conv_block(
            filters,
+            channel_axis,
            kernel_size=1,
            strides=1,
            activation=activation,
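Net effect of the CSPDarkNet changes: the backbone becomes data-format aware (every block takes a `channel_axis`) and now exposes per-stage feature maps through `pyramid_outputs`, the `FeaturePyramidBackbone` contract. A minimal sketch of tapping those features; the filter and depth values here are illustrative, not a shipped preset:

```python
import numpy as np
import keras
import keras_hub

backbone = keras_hub.models.CSPDarkNetBackbone(
    stackwise_num_filters=[32, 64, 128, 256],
    stackwise_depth=[1, 3, 3, 1],
    image_shape=(None, None, 3),  # new default: no fixed spatial size
)
# `pyramid_outputs` maps level names ("P2"..."P5" here) to the feature
# tensors recorded at the end of each stage in the diff above, which a
# detector head can consume as a feature pyramid.
extractor = keras.Model(backbone.inputs, backbone.pyramid_outputs)
features = extractor.predict(np.ones((1, 224, 224, 3)))
```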
--- a/keras_hub/src/models/deberta_v3/__init__.py
+++ b/keras_hub/src/models/deberta_v3/__init__.py
@@ -16,9 +16,6 @@ from keras_hub.src.models.deberta_v3.deberta_v3_backbone import (
     DebertaV3Backbone,
 )
 from keras_hub.src.models.deberta_v3.deberta_v3_presets import backbone_presets
-from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import (
-    DebertaV3Tokenizer,
-)
 from keras_hub.src.utils.preset_utils import register_presets
 
-register_presets(backbone_presets, (DebertaV3Backbone, DebertaV3Tokenizer))
+register_presets(backbone_presets, DebertaV3Backbone)
--- a/keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py
+++ b/keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py
@@ -13,19 +13,20 @@
 # limitations under the License.
 
 import keras
-from absl import logging
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import (
-    MaskedLMMaskGenerator,
+from keras_hub.src.models.deberta_v3.deberta_v3_backbone import (
+    DebertaV3Backbone,
 )
-from keras_hub.src.models.deberta_v3.deberta_v3_preprocessor import (
-    DebertaV3Preprocessor,
+from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import (
+    DebertaV3Tokenizer,
 )
+from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor
+from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 
 @keras_hub_export("keras_hub.models.DebertaV3MaskedLMPreprocessor")
-class DebertaV3MaskedLMPreprocessor(DebertaV3Preprocessor):
+class DebertaV3MaskedLMPreprocessor(MaskedLMPreprocessor):
     """DeBERTa preprocessing for the masked language modeling task.
 
     This preprocessing layer will prepare inputs for a masked language modeling
@@ -115,77 +116,13 @@ class DebertaV3MaskedLMPreprocessor(DebertaV3Preprocessor):
     ```
     """
 
-    def __init__(
-        self,
-        tokenizer,
-        sequence_length=512,
-        truncate="round_robin",
-        mask_selection_rate=0.15,
-        mask_selection_length=96,
-        mask_token_rate=0.8,
-        random_token_rate=0.1,
-        **kwargs,
-    ):
-        super().__init__(
-            tokenizer,
-            sequence_length=sequence_length,
-            truncate=truncate,
-            **kwargs,
-        )
-
-        self.mask_selection_rate = mask_selection_rate
-        self.mask_selection_length = mask_selection_length
-        self.mask_token_rate = mask_token_rate
-        self.random_token_rate = random_token_rate
-        self.masker = None
-
-    def build(self, input_shape):
-        super().build(input_shape)
-        # Defer masker creation to `build()` so that we can be sure tokenizer
-        # assets have loaded when restoring a saved model.
-        self.masker = MaskedLMMaskGenerator(
-            mask_selection_rate=self.mask_selection_rate,
-            mask_selection_length=self.mask_selection_length,
-            mask_token_rate=self.mask_token_rate,
-            random_token_rate=self.random_token_rate,
-            vocabulary_size=self.tokenizer.vocabulary_size(),
-            mask_token_id=self.tokenizer.mask_token_id,
-            unselectable_token_ids=[
-                self.tokenizer.cls_token_id,
-                self.tokenizer.sep_token_id,
-                self.tokenizer.pad_token_id,
-            ],
-        )
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "mask_selection_rate": self.mask_selection_rate,
-                "mask_selection_length": self.mask_selection_length,
-                "mask_token_rate": self.mask_token_rate,
-                "random_token_rate": self.random_token_rate,
-            }
-        )
-        return config
+    backbone_cls = DebertaV3Backbone
+    tokenizer_cls = DebertaV3Tokenizer
 
+    @preprocessing_function
     def call(self, x, y=None, sample_weight=None):
-        if y is not None or sample_weight is not None:
-            logging.warning(
-                f"{self.__class__.__name__} generates `y` and `sample_weight` "
-                "based on your input data, but your data already contains `y` "
-                "or `sample_weight`. Your `y` and `sample_weight` will be "
-                "ignored."
-            )
-
-        x = super().call(x)
-        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
-        masker_outputs = self.masker(token_ids)
-        x = {
-            "token_ids": masker_outputs["token_ids"],
-            "padding_mask": padding_mask,
-            "mask_positions": masker_outputs["mask_positions"],
-        }
-        y = masker_outputs["mask_ids"]
-        sample_weight = masker_outputs["mask_weights"]
+        output = super().call(x, y=y, sample_weight=sample_weight)
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output)
+        # Backbone has no segment ID input.
+        del x["segment_ids"]
         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
--- a/keras_hub/src/models/deberta_v3/deberta_v3_classifier.py
+++ b/keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py
@@ -16,20 +16,25 @@
 import keras
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.models.classifier import Classifier
 from keras_hub.src.models.deberta_v3.deberta_v3_backbone import (
     DebertaV3Backbone,
 )
 from keras_hub.src.models.deberta_v3.deberta_v3_backbone import (
     deberta_kernel_initializer,
 )
-from keras_hub.src.models.deberta_v3.deberta_v3_preprocessor import (
-    DebertaV3Preprocessor,
+from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import (
+    DebertaV3TextClassifierPreprocessor,
 )
+from keras_hub.src.models.text_classifier import TextClassifier
 
 
-@keras_hub_export("keras_hub.models.DebertaV3Classifier")
-class DebertaV3Classifier(Classifier):
+@keras_hub_export(
+    [
+        "keras_hub.models.DebertaV3TextClassifier",
+        "keras_hub.models.DebertaV3Classifier",
+    ]
+)
+class DebertaV3TextClassifier(TextClassifier):
     """An end-to-end DeBERTa model for classification tasks.
 
     This model attaches a classification head to a
@@ -53,7 +58,7 @@ class DebertaV3Classifier(Classifier):
     Args:
         backbone: A `keras_hub.models.DebertaV3` instance.
         num_classes: int. Number of classes to predict.
-        preprocessor: A `keras_hub.models.DebertaV3Preprocessor` or `None`. If
+        preprocessor: A `keras_hub.models.DebertaV3TextClassifierPreprocessor` or `None`. If
             `None`, this model will not apply preprocessing, and inputs should
             be preprocessed before calling the model.
         activation: Optional `str` or callable. The
@@ -72,7 +77,7 @@ class DebertaV3Classifier(Classifier):
     labels = [0, 3]
 
     # Pretrained classifier.
-    classifier = keras_hub.models.DebertaV3Classifier.from_preset(
+    classifier = keras_hub.models.DebertaV3TextClassifier.from_preset(
         "deberta_v3_base_en",
         num_classes=4,
     )
@@ -100,7 +105,7 @@ class DebertaV3Classifier(Classifier):
     labels = [0, 3]
 
     # Pretrained classifier without preprocessing.
-    classifier = keras_hub.models.DebertaV3Classifier.from_preset(
+    classifier = keras_hub.models.DebertaV3TextClassifier.from_preset(
         "deberta_v3_base_en",
         num_classes=4,
         preprocessor=None,
@@ -132,7 +137,7 @@ class DebertaV3Classifier(Classifier):
     tokenizer = keras_hub.models.DebertaV3Tokenizer(
         proto=bytes_io.getvalue(),
     )
-    preprocessor = keras_hub.models.DebertaV3Preprocessor(
+    preprocessor = keras_hub.models.DebertaV3TextClassifierPreprocessor(
         tokenizer=tokenizer,
         sequence_length=128,
     )
@@ -144,7 +149,7 @@ class DebertaV3Classifier(Classifier):
         intermediate_dim=512,
         max_sequence_length=128,
     )
-    classifier = keras_hub.models.DebertaV3Classifier(
+    classifier = keras_hub.models.DebertaV3TextClassifier(
         backbone=backbone,
         preprocessor=preprocessor,
         num_classes=4,
@@ -154,7 +159,7 @@ class DebertaV3Classifier(Classifier):
     """
 
     backbone_cls = DebertaV3Backbone
-    preprocessor_cls = DebertaV3Preprocessor
+    preprocessor_cls = DebertaV3TextClassifierPreprocessor
 
     def __init__(
         self,
--- a/keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py
+++ b/keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py
@@ -12,24 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import keras
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.multi_segment_packer import (
-    MultiSegmentPacker,
+from keras_hub.src.models.deberta_v3.deberta_v3_backbone import (
+    DebertaV3Backbone,
 )
 from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import (
     DebertaV3Tokenizer,
 )
-from keras_hub.src.models.preprocessor import Preprocessor
-from keras_hub.src.utils.tensor_utils import (
-    convert_inputs_to_list_of_tensor_segments,
+from keras_hub.src.models.text_classifier_preprocessor import (
+    TextClassifierPreprocessor,
 )
+from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 
-@keras_hub_export("keras_hub.models.DebertaV3Preprocessor")
-class DebertaV3Preprocessor(Preprocessor):
+@keras_hub_export(
+    [
+        "keras_hub.models.DebertaV3TextClassifierPreprocessor",
+        "keras_hub.models.DebertaV3Preprocessor",
+    ]
+)
+class DebertaV3TextClassifierPreprocessor(TextClassifierPreprocessor):
     """A DeBERTa preprocessing layer which tokenizes and packs inputs.
 
     This preprocessing layer will do three things:
@@ -74,7 +78,7 @@ class DebertaV3Preprocessor(Preprocessor):
     Examples:
     Directly calling the layer on data.
     ```python
-    preprocessor = keras_hub.models.DebertaV3Preprocessor.from_preset(
+    preprocessor = keras_hub.models.TextClassifierPreprocessor.from_preset(
         "deberta_v3_base_en"
     )
 
@@ -110,13 +114,15 @@ class DebertaV3Preprocessor(Preprocessor):
     tokenizer = keras_hub.models.DebertaV3Tokenizer(
         proto=bytes_io.getvalue(),
     )
-    preprocessor = keras_hub.models.DebertaV3Preprocessor(tokenizer)
+    preprocessor = keras_hub.models.DebertaV3TextClassifierPreprocessor(
+        tokenizer
+    )
     preprocessor("The quick brown fox jumped.")
     ```
 
     Mapping with `tf.data.Dataset`.
     ```python
-    preprocessor = keras_hub.models.DebertaV3Preprocessor.from_preset(
+    preprocessor = keras_hub.models.TextClassifierPreprocessor.from_preset(
         "deberta_v3_base_en"
     )
 
@@ -147,60 +153,13 @@ class DebertaV3Preprocessor(Preprocessor):
     ```
     """
 
+    backbone_cls = DebertaV3Backbone
     tokenizer_cls = DebertaV3Tokenizer
 
-    def __init__(
-        self,
-        tokenizer,
-        sequence_length=512,
-        truncate="round_robin",
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.tokenizer = tokenizer
-        self.packer = None
-        self.truncate = truncate
-        self.sequence_length = sequence_length
-
-    def build(self, input_shape):
-        # Defer packer creation to `build()` so that we can be sure tokenizer
-        # assets have loaded when restoring a saved model.
-        self.packer = MultiSegmentPacker(
-            start_value=self.tokenizer.cls_token_id,
-            end_value=self.tokenizer.sep_token_id,
-            pad_value=self.tokenizer.pad_token_id,
-            truncate=self.truncate,
-            sequence_length=self.sequence_length,
-        )
-        self.built = True
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "sequence_length": self.sequence_length,
-                "truncate": self.truncate,
-            }
-        )
-        return config
-
+    @preprocessing_function
     def call(self, x, y=None, sample_weight=None):
-        x = convert_inputs_to_list_of_tensor_segments(x)
-        x = [self.tokenizer(segment) for segment in x]
-        token_ids, _ = self.packer(x)
-        x = {
-            "token_ids": token_ids,
-            "padding_mask": token_ids != self.tokenizer.pad_token_id,
-        }
+        output = super().call(x, y=y, sample_weight=sample_weight)
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output)
+        # Backbone has no segment ID input.
+        del x["segment_ids"]
         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
-
-    @property
-    def sequence_length(self):
-        """The padded length of model input sequences."""
-        return self._sequence_length
-
-    @sequence_length.setter
-    def sequence_length(self, value):
-        self._sequence_length = value
-        if self.packer is not None:
-            self.packer.sequence_length = value
--- a/keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py
+++ b/keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py
@@ -14,6 +14,9 @@
 
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.deberta_v3.deberta_v3_backbone import (
+    DebertaV3Backbone,
+)
 from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
     SentencePieceTokenizer,
 )
@@ -24,7 +27,12 @@ except ImportError:
     tf = None
 
 
-@keras_hub_export("keras_hub.models.DebertaV3Tokenizer")
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.DebertaV3Tokenizer",
+        "keras_hub.models.DebertaV3Tokenizer",
+    ]
+)
 class DebertaV3Tokenizer(SentencePieceTokenizer):
     """DeBERTa tokenizer layer based on SentencePiece.
 
@@ -34,10 +42,6 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
     DeBERTa models and provides a `from_preset()` method to automatically
     download a matching vocabulary for a DeBERTa preset.
 
-    This tokenizer does not provide truncation or padding of inputs. It can be
-    combined with a `keras_hub.models.DebertaV3Preprocessor` layer for input
-    packing.
-
     If input is a batch of strings (rank > 0), the layer will output a
     `tf.RaggedTensor` where the last dimension of the output is ragged.
 
@@ -94,38 +98,37 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
     ```
     """
 
+    backbone_cls = DebertaV3Backbone
+
     def __init__(self, proto, **kwargs):
-        self.cls_token = "[CLS]"
-        self.sep_token = "[SEP]"
-        self.pad_token = "[PAD]"
+        self._add_special_token("[CLS]", "cls_token")
+        self._add_special_token("[SEP]", "sep_token")
+        self._add_special_token("[PAD]", "pad_token")
+        # Also add `tokenizer.start_token` and `tokenizer.end_token` for
+        # compatibility with other tokenizers.
+        self._add_special_token("[CLS]", "start_token")
+        self._add_special_token("[SEP]", "end_token")
+        # Handle mask separately as it's not always in the vocab.
         self.mask_token = "[MASK]"
-
+        self.mask_token_id = None
         super().__init__(proto=proto, **kwargs)
 
+    @property
+    def special_tokens(self):
+        return super().special_tokens + [self.mask_token]
+
+    @property
+    def special_token_ids(self):
+        return super().special_token_ids + [self.mask_token_id]
+
     def set_proto(self, proto):
         super().set_proto(proto)
         if proto is not None:
-            for token in [self.cls_token, self.pad_token, self.sep_token]:
-                if token not in super().get_vocabulary():
-                    raise ValueError(
-                        f"Cannot find token `'{token}'` in the provided "
-                        f"`vocabulary`. Please provide `'{token}'` in your "
-                        "`vocabulary` or use a pretrained `vocabulary` name."
-                    )
-
-            self.cls_token_id = self.token_to_id(self.cls_token)
-            self.sep_token_id = self.token_to_id(self.sep_token)
-            self.pad_token_id = self.token_to_id(self.pad_token)
-            # If the mask token is not in the vocabulary, add it to the end of the
-            # vocabulary.
             if self.mask_token in super().get_vocabulary():
                 self.mask_token_id = super().token_to_id(self.mask_token)
             else:
                 self.mask_token_id = super().vocabulary_size()
         else:
-            self.cls_token_id = None
-            self.sep_token_id = None
-            self.pad_token_id = None
             self.mask_token_id = None
 
     def vocabulary_size(self):
@@ -136,6 +139,8 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
 
     def get_vocabulary(self):
         sentence_piece_vocabulary = super().get_vocabulary()
+        if self.mask_token_id is None:
+            return sentence_piece_vocabulary
         if self.mask_token_id < super().vocabulary_size():
             return sentence_piece_vocabulary
         return sentence_piece_vocabulary + ["[MASK]"]
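The tokenizer rewrite replaces hand-rolled vocabulary checks with `_add_special_token`, which (per the shared tokenizer base class) registers both the token string and its ID as attributes; only `[MASK]` keeps custom handling because it may sit past the end of the SentencePiece vocabulary. A short sketch of the resulting surface, assuming the attributes are wired up as the diff suggests:

```python
import keras_hub

tokenizer = keras_hub.tokenizers.DebertaV3Tokenizer.from_preset(
    "deberta_v3_base_en"
)
print(tokenizer.cls_token, tokenizer.cls_token_id)   # "[CLS]" and its ID
print(tokenizer.start_token == tokenizer.cls_token)  # True (compat alias)
# "[MASK]" is appended one slot past the base vocabulary if the
# SentencePiece proto lacks it, as handled in `set_proto` above.
print(tokenizer.mask_token, tokenizer.mask_token_id)
```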