keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,319 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from keras_hub.src.api_export import keras_hub_export
|
16
|
+
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
|
17
|
+
PreprocessingLayer,
|
18
|
+
)
|
19
|
+
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
|
20
|
+
|
21
|
+
try:
|
22
|
+
import tensorflow as tf
|
23
|
+
import tensorflow_text as tf_text
|
24
|
+
except ImportError:
|
25
|
+
tf = None
|
26
|
+
tf_text = None
|
27
|
+
|
28
|
+
|
29
|
+
@keras_hub_export("keras_hub.layers.MultiSegmentPacker")
|
30
|
+
class MultiSegmentPacker(PreprocessingLayer):
    """Packs multiple sequences into a single fixed width model input.

    This layer packs multiple input sequences into a single fixed width sequence
    containing start and end delimiters, forming a dense input suitable for a
    classification task for BERT and BERT-like models.

    Takes as input a tuple of token segments. Each tuple element should contain
    the tokens for a segment, passed as tensors, `tf.RaggedTensor`s, or lists.
    For batched input, each element in the tuple of segments should be a list of
    lists or a rank two tensor. For unbatched inputs, each element should be a
    list or rank one tensor.

    The layer will process inputs as follows:
     - Truncate all input segments to fit within `sequence_length` according to
       the `truncate` strategy.
     - Concatenate all input segments, adding a single `start_value` at the
       start of the entire sequence, a `sep_value` after every segment except
       the last, and an `end_value` after the last segment.
     - Pad the resulting sequence to `sequence_length` using `pad_value`.
     - Calculate a separate tensor of "segment ids", with integer type and the
       same shape as the packed token output, where each integer is the index
       of the segment the token originated from. The segment id of the
       `start_value` is always 0, and the segment id of each `sep_value` or
       `end_value` is the segment that precedes it.

    Args:
        sequence_length: int. The desired output length.
        start_value: int/str/list/tuple. The id(s) or token(s) that are to be
            placed at the start of each sequence (called "[CLS]" for BERT). The
            dtype must match the dtype of the input tensors to the layer.
        end_value: int/str/list/tuple. The id(s) or token(s) that are to be
            placed at the end of the last input segment (called "[SEP]" for
            BERT). The dtype must match the dtype of the input tensors to the
            layer.
        sep_value: int/str/list/tuple. The id(s) or token(s) that are to be
            placed at the end of every segment, except the last segment (called
            "[SEP]" for BERT). If `None`, `end_value` is used. The dtype must
            match the dtype of the input tensors to the layer.
        pad_value: int/str. The id or token that is to be placed into the unused
            positions after the last segment in the sequence
            (called "[PAD]" for BERT).
        truncate: str. The algorithm to truncate a list of batched segments to
            fit a per-example length limit. The value can be either
            `"round_robin"` or `"waterfall"`:
            - `"round_robin"`: Available space is assigned one token at a
                time in a round-robin fashion to the inputs that still need
                some, until the limit is reached.
            - `"waterfall"`: The allocation of the budget is done using a
                "waterfall" algorithm that allocates quota in a
                left-to-right manner and fills up the buckets until we run
                out of budget. It supports an arbitrary number of segments.

    Call arguments:
        inputs: A single segment, or a list/tuple of segments, to pack.
        sequence_length: Optional int. Overrides the layer's
            `sequence_length` for this call, for both truncation and padding.
        add_start_value: bool. Whether to prepend `start_value`.
        add_end_value: bool. Whether to append `end_value` after the last
            segment.

    Returns:
        A tuple with two elements. The first is the dense, packed token
        sequence. The second is an integer tensor of the same shape, containing
        the segment ids.

    Examples:

    *Pack a single input for classification.*
    >>> seq1 = [1, 2, 3, 4]
    >>> packer = keras_hub.layers.MultiSegmentPacker(
    ...     sequence_length=8, start_value=101, end_value=102
    ... )
    >>> token_ids, segment_ids = packer((seq1,))
    >>> np.array(token_ids)
    array([101, 1, 2, 3, 4, 102, 0, 0], dtype=int32)
    >>> np.array(segment_ids)
    array([0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

    *Pack multiple inputs for classification.*
    >>> seq1 = [1, 2, 3, 4]
    >>> seq2 = [11, 12, 13, 14]
    >>> packer = keras_hub.layers.MultiSegmentPacker(
    ...     sequence_length=8, start_value=101, end_value=102
    ... )
    >>> token_ids, segment_ids = packer((seq1, seq2))
    >>> np.array(token_ids)
    array([101, 1, 2, 3, 102, 11, 12, 102], dtype=int32)
    >>> np.array(segment_ids)
    array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)

    *Pack multiple inputs for classification with different sep tokens.*
    >>> seq1 = [1, 2, 3, 4]
    >>> seq2 = [11, 12, 13, 14]
    >>> packer = keras_hub.layers.MultiSegmentPacker(
    ...     sequence_length=8,
    ...     start_value=101,
    ...     end_value=102,
    ...     sep_value=[102, 102],
    ... )
    >>> token_ids, segment_ids = packer((seq1, seq2))
    >>> np.array(token_ids)
    array([101, 1, 2, 102, 102, 11, 12, 102], dtype=int32)
    >>> np.array(segment_ids)
    array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)

    Reference:
        [Devlin et al., 2018](https://arxiv.org/abs/1810.04805).
    """

    def __init__(
        self,
        sequence_length,
        start_value,
        end_value,
        sep_value=None,
        pad_value=None,
        truncate="round_robin",
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.sequence_length = sequence_length
        if truncate not in ("round_robin", "waterfall"):
            raise ValueError(
                "Only 'round_robin' and 'waterfall' algorithms are "
                "supported. Received %s" % truncate
            )
        self.truncate = truncate

        # Maintain private copies of start/end values for config purposes, so
        # `get_config()` returns exactly what the caller passed in rather than
        # the list-normalized form stored below.
        self._start_value = start_value
        self._sep_value = sep_value
        self._end_value = end_value

        def check_special_value_type(value, value_name):
            """Wrap a scalar int/str in a list; validate other truthy values."""
            if isinstance(value, (int, str)):
                return [value]
            if value and not isinstance(value, (list, tuple)):
                raise ValueError(
                    f"{value_name} should be of type int/str/list/tuple. "
                    f"Received type: `{type(value)}`."
                )
            return value

        start_value = check_special_value_type(start_value, "start_value")
        # An unset separator falls back to the end value.
        if sep_value is None:
            sep_value = end_value
        sep_value = check_special_value_type(sep_value, "sep_value")
        end_value = check_special_value_type(end_value, "end_value")

        self.start_value = start_value
        self.sep_value = sep_value
        self.end_value = end_value

        self.pad_value = pad_value

    def get_config(self):
        """Return the layer config, using the constructor's original values."""
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "start_value": self._start_value,
                "end_value": self._end_value,
                "sep_value": self._sep_value,
                "pad_value": self.pad_value,
                "truncate": self.truncate,
            }
        )
        return config

    def _sanitize_inputs(self, inputs):
        """Force inputs to a list of rank 2 ragged tensors.

        Returns:
            A tuple `(inputs, unbatched)`: the converted segments, and the
            shared "unbatched" flag reported by `convert_to_ragged_batch`.

        Raises:
            ValueError: If no inputs are given, or if the segments do not all
                share the same batched/unbatched status.
        """
        # Sanitize inputs.
        if not isinstance(inputs, (list, tuple)):
            inputs = (inputs,)
        if not inputs:
            raise ValueError(
                "At least one input is required for packing. "
                f"Received: `inputs={inputs}`"
            )
        # `convert_to_ragged_batch` yields a 3-tuple per segment; transpose
        # into per-field tuples and ignore the third field.
        inputs, unbatched_list, _ = list(
            zip(*(convert_to_ragged_batch(x) for x in inputs))
        )
        if len(set(unbatched_list)) != 1:
            ranks = [1 if unbatched else 2 for unbatched in unbatched_list]
            raise ValueError(
                "All inputs for packing must have the same rank. "
                f"Received: `inputs={inputs}` with ranks {ranks}"
            )
        return inputs, unbatched_list[0]

    def _trim_inputs(self, inputs, sequence_length=None):
        """Trim inputs to desired length.

        Args:
            inputs: Sequence of segments to trim.
            sequence_length: Optional int. Total packed length to trim for.
                Defaults to `self.sequence_length`.

        Returns:
            The trimmed segments, as returned by the `tensorflow_text`
            trimmer.

        Raises:
            ValueError: If `self.truncate` is not a supported strategy.
        """
        sequence_length = sequence_length or self.sequence_length
        num_segments = len(inputs)
        # NOTE: room for the start and end values is always reserved here,
        # even when the caller disables them via `add_start_value` /
        # `add_end_value`.
        num_special_tokens = (
            len(self.start_value)
            + (num_segments - 1) * len(self.sep_value)
            + len(self.end_value)
        )
        budget = sequence_length - num_special_tokens
        if self.truncate == "round_robin":
            return tf_text.RoundRobinTrimmer(budget).trim(inputs)
        elif self.truncate == "waterfall":
            return tf_text.WaterfallTrimmer(budget).trim(inputs)
        else:
            raise ValueError("Unsupported truncate: %s" % self.truncate)

    def _combine_inputs(
        self,
        segments,
        add_start_value=True,
        add_end_value=True,
    ):
        """Combine inputs with start and end values added.

        Returns:
            A tuple `(token_ids, segment_ids)` produced by concatenating the
            segments and special values along axis 1.
        """
        dtype = segments[0].dtype
        batch_size = segments[0].nrows()
        start_value = tf.convert_to_tensor(self.start_value, dtype=dtype)
        sep_value = tf.convert_to_tensor(self.sep_value, dtype=dtype)
        end_value = tf.convert_to_tensor(self.end_value, dtype=dtype)

        # Tile each special value so there is one copy per batch row.
        start_columns = tf.repeat(
            start_value[tf.newaxis, :], repeats=batch_size, axis=0
        )
        sep_columns = tf.repeat(
            sep_value[tf.newaxis, :], repeats=batch_size, axis=0
        )
        end_columns = tf.repeat(
            end_value[tf.newaxis, :], repeats=batch_size, axis=0
        )
        ones_sep_columns = tf.ones_like(sep_columns, dtype="int32")
        ones_end_columns = tf.ones_like(end_columns, dtype="int32")

        segments_to_combine = []
        segment_ids_to_combine = []
        if add_start_value:
            segments_to_combine.append(start_columns)
            # The start value always gets segment id 0.
            start_segment = tf.zeros_like(start_columns, dtype="int32")
            segment_ids_to_combine.append(start_segment)

        for i, seg in enumerate(segments):
            # Combine all segments.
            segments_to_combine.append(seg)

            # Combine segment ids.
            segment_ids_to_combine.append(tf.ones_like(seg, dtype="int32") * i)

            # Account for the sep/end tokens here. Each one takes the id of
            # the segment it follows.
            if i == len(segments) - 1:
                if add_end_value:
                    segments_to_combine.append(end_columns)
                    segment_ids_to_combine.append(ones_end_columns * i)
            else:
                segments_to_combine.append(sep_columns)
                segment_ids_to_combine.append(ones_sep_columns * i)

        token_ids = tf.concat(segments_to_combine, 1)
        segment_ids = tf.concat(segment_ids_to_combine, 1)
        return token_ids, segment_ids

    def call(
        self,
        inputs,
        sequence_length=None,
        add_start_value=True,
        add_end_value=True,
    ):
        """Trim, join and pad the input segments.

        Args:
            inputs: A single segment, or a list/tuple of segments.
            sequence_length: Optional int. Overrides `self.sequence_length`
                for this call.
            add_start_value: bool. Whether to prepend `start_value`.
            add_end_value: bool. Whether to append `end_value` after the last
                segment.

        Returns:
            A tuple `(token_ids, segment_ids)` of dense tensors. The leading
            batch axis is squeezed out when the inputs were unbatched.
        """
        inputs, unbatched = self._sanitize_inputs(inputs)

        # Resolve the override once so that trimming and padding agree on the
        # target length; otherwise a shorter override would make the final
        # `to_tensor` cut off the end value.
        sequence_length = sequence_length or self.sequence_length

        segments = self._trim_inputs(inputs, sequence_length)
        token_ids, segment_ids = self._combine_inputs(
            segments,
            add_start_value=add_start_value,
            add_end_value=add_end_value,
        )
        # Pad to dense tensor output.
        shape = tf.cast([-1, sequence_length], "int64")
        token_ids = token_ids.to_tensor(
            shape=shape, default_value=self.pad_value
        )
        segment_ids = segment_ids.to_tensor(shape=shape)
        # Remove the batch dim if added.
        if unbatched:
            token_ids = tf.squeeze(token_ids, 0)
            segment_ids = tf.squeeze(segment_ids, 0)

        return (token_ids, segment_ids)

    def compute_output_shape(self, inputs_shape):
        """Return the input shape with its last axis set to `sequence_length`.

        If the first element of `inputs_shape` is itself a tuple, that first
        element is used as the shape.
        """
        if isinstance(inputs_shape[0], tuple):
            inputs_shape = inputs_shape[0]
        inputs_shape = list(inputs_shape)
        inputs_shape[-1] = self.sequence_length
        return tuple(inputs_shape)
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import tree
|
17
|
+
|
18
|
+
from keras_hub.src.utils.tensor_utils import assert_tf_libs_installed
|
19
|
+
from keras_hub.src.utils.tensor_utils import (
|
20
|
+
convert_to_backend_tensor_or_python_list,
|
21
|
+
)
|
22
|
+
|
23
|
+
try:
|
24
|
+
import tensorflow as tf
|
25
|
+
except ImportError:
|
26
|
+
tf = None
|
27
|
+
|
28
|
+
|
29
|
+
class PreprocessingLayer(keras.layers.Layer):
    """Preprocessing layer base class.

    Runs the wrapped layer call on CPU and, when the Keras backend is not
    TensorFlow and execution is eager, maps the outputs through
    `convert_to_backend_tensor_or_python_list`.
    """

    def __init__(self, **kwargs):
        # Presumably raises when the TensorFlow libraries are missing; see
        # `assert_tf_libs_installed` for the exact behavior.
        assert_tf_libs_installed(type(self).__name__)

        super().__init__(**kwargs)
        # Flags read by the base Keras layer. NOTE(review): semantics are
        # defined by Keras internals, not by this file.
        self._allow_non_tensor_positional_args = True
        self._convert_input_args = False
        # Most preprocessing has no build.
        if not hasattr(self, "build"):
            self.built = True

    def get_build_config(self):
        """Return `None`; there is no build config to report."""
        return None

    def __call__(self, *args, **kwargs):
        """Invoke the layer on CPU and convert outputs when required."""
        # Always place on CPU for preprocessing, to avoid expensive back and
        # forth copies to GPU before the trainable model.
        with tf.device("cpu"):
            result = super().__call__(*args, **kwargs)

        # Jax and Torch lack native string and ragged types.
        # If we are running on those backends and not running with tf.data
        # (we are outside a tf.function), we convert all ragged and string
        # tensors to pythonic types.
        on_tf_backend = keras.config.backend() == "tensorflow"
        inside_tf_graph = not tf.executing_eagerly()
        if on_tf_backend or inside_tf_graph:
            return result

        return tree.map_structure(
            convert_to_backend_tensor_or_python_list, result
        )
|
@@ -0,0 +1,271 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import random
|
16
|
+
|
17
|
+
from keras_hub.src.api_export import keras_hub_export
|
18
|
+
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
|
19
|
+
PreprocessingLayer,
|
20
|
+
)
|
21
|
+
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
|
22
|
+
from keras_hub.src.utils.tensor_utils import is_int_dtype
|
23
|
+
from keras_hub.src.utils.tensor_utils import is_string_dtype
|
24
|
+
|
25
|
+
try:
|
26
|
+
import tensorflow as tf
|
27
|
+
except ImportError:
|
28
|
+
tf = None
|
29
|
+
|
30
|
+
|
31
|
+
@keras_hub_export("keras_hub.layers.RandomDeletion")
class RandomDeletion(PreprocessingLayer):
    """Augments input by randomly deleting tokens.

    This layer comes in handy when you need to generate new data using deletion
    augmentation as described in the paper [EDA: Easy Data Augmentation
    Techniques for Boosting Performance on Text Classification Tasks]
    (https://arxiv.org/pdf/1901.11196.pdf). The layer expects the inputs to be
    pre-split into token level inputs. This allows control over the level of
    augmentation, you can split by character for character level deletions, or
    by word for word level deletions.

    Input data should be passed as tensors, `tf.RaggedTensor`s, or lists. For
    batched input, inputs should be a list of lists or a rank two tensor. For
    unbatched inputs, each element should be a list or a rank one tensor.

    Args:
        rate: The probability of a token being chosen for deletion. Must be
            between 0 and 1 (both inclusive).
        max_deletions: The maximum number of tokens to delete. Must be
            non-negative when provided.
        skip_list: A list of token values that should not be considered
            candidates for deletion.
        skip_fn: A function that takes as input a scalar tensor token and
            returns as output a scalar tensor True/False value. A value of
            True indicates that the token should not be considered a
            candidate for deletion. This function must be traceable--it
            should consist of tensorflow operations.
        skip_py_fn: A function that takes as input a python token value and
            returns as output `True` or `False`. A value of True
            indicates that the token should not be considered a candidate
            for deletion. Unlike the `skip_fn` argument, this argument need
            not be traceable--it can be any python function.
        seed: A seed for the random number generator.

    At most one of `skip_list`, `skip_fn` and `skip_py_fn` may be provided.

    Examples:

    Word level usage.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_hub.layers.RandomDeletion(rate=0.4, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I like', b'and'],
    dtype=object)>

    Character level usage.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.unicode_split(["Hey Dude", "Speed Up"], "UTF-8")
    >>> augmenter=keras_hub.layers.RandomDeletion(rate=0.4, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'H Dude', b'pedUp'],
    dtype=object)>

    Usage with skip_list.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_hub.layers.RandomDeletion(rate=0.4,
    ...     skip_list=["Keras", "Tensorflow"], seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'I like', b'Keras Tensorflow'], dtype=object)>

    Usage with skip_fn.
    >>> def skip_fn(word):
    ...     return tf.strings.regex_full_match(word, r"\\pP")
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_hub.layers.RandomDeletion(rate=0.4,
    ...     skip_fn=skip_fn, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I like', b'and'],
    dtype=object)>

    Usage with skip_py_fn.
    >>> def skip_py_fn(word):
    ...     return len(word) < 4
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_hub.layers.RandomDeletion(rate=0.4,
    ...     skip_py_fn=skip_py_fn, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'Hey I', b'and Tensorflow'], dtype=object)>
    """

    def __init__(
        self,
        rate,
        max_deletions=None,
        skip_list=None,
        skip_fn=None,
        skip_py_fn=None,
        seed=None,
        name=None,
        dtype="int32",
        **kwargs,
    ):
        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
            raise ValueError(
                "Output dtype must be an integer type or a string. "
                f"Received: dtype={dtype}"
            )

        super().__init__(dtype=dtype, name=name, **kwargs)

        self.rate = rate
        self.max_deletions = max_deletions
        # `random.randint` requires integer bounds; a float bound such as
        # `1e9` raises `TypeError` on Python 3.12+.
        self.seed = random.randint(1, 10**9) if seed is None else seed
        self._generator = tf.random.Generator.from_seed(self.seed)
        self.skip_list = skip_list
        self.skip_fn = skip_fn
        self.skip_py_fn = skip_py_fn
        if self.max_deletions is not None and self.max_deletions < 0:
            raise ValueError(
                "max_deletions must be non-negative. "
                f"Received max_deletions={max_deletions}."
            )

        if self.rate > 1 or self.rate < 0:
            raise ValueError(
                "Rate must be between 0 and 1 (both inclusive). "
                f"Received: rate={rate}"
            )

        # Fewer than two `None`s means two or more skip options were passed.
        # Passing none of them is allowed.
        if [self.skip_list, self.skip_fn, self.skip_py_fn].count(None) < 2:
            raise ValueError(
                "At most one of `skip_list`, `skip_fn`, `skip_py_fn` may be "
                "provided."
            )

        if self.skip_list:
            # Lookup table mapping every skip-list token to True; any other
            # token resolves to the default, False.
            self.StaticHashTable = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(
                    tf.convert_to_tensor(self.skip_list),
                    tf.convert_to_tensor([True] * len(self.skip_list)),
                ),
                default_value=False,
            )

    def call(self, inputs):
        """Delete randomly chosen tokens from `inputs`.

        Args:
            inputs: Pre-split tokens, in any form accepted by
                `convert_to_ragged_batch`.

        Returns:
            The inputs with the selected tokens removed. The leading batch
            axis is squeezed out again when `convert_to_ragged_batch`
            reported the input as unbatched.
        """
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)

        # Build a flat boolean mask that is True for tokens which must not
        # be deleted, using whichever skip option was configured.
        skip_masks = None
        if self.skip_list:
            skip_masks = self.StaticHashTable.lookup(inputs.flat_values)
        elif self.skip_fn:
            skip_masks = tf.map_fn(
                self.skip_fn, inputs.flat_values, fn_output_signature="bool"
            )
        elif self.skip_py_fn:

            def string_fn(token):
                # String tokens arrive as bytes; decode before calling the
                # user's python function.
                return self.skip_py_fn(token.numpy().decode("utf-8"))

            def int_fn(token):
                return self.skip_py_fn(token.numpy())

            py_fn = string_fn if inputs.dtype == tf.string else int_fn

            skip_masks = tf.map_fn(
                lambda x: tf.py_function(py_fn, [x], "bool"),
                inputs.flat_values,
                fn_output_signature="bool",
            )

        # Label every token with its index into the flat values, keeping the
        # row structure of the inputs.
        positions_flat = tf.range(tf.size(inputs.flat_values))
        positions = inputs.with_flat_values(positions_flat)
        if skip_masks is not None:
            # Invert the mask so True marks deletion candidates, then drop
            # the skipped positions.
            skip_masks = tf.logical_not(skip_masks)
            skip_masks.set_shape([None])
            positions = tf.ragged.boolean_mask(
                positions, inputs.with_flat_values(skip_masks)
            )

        # Figure out how many we are going to select: one binomial draw per
        # row over that row's candidate count.
        token_counts = tf.cast(positions.row_lengths(), "float32")
        num_to_select = tf.random.stateless_binomial(
            shape=tf.shape(token_counts),
            seed=self._generator.make_seeds()[:, 0],
            counts=token_counts,
            probs=self.rate,
        )
        if self.max_deletions is not None:
            num_to_select = tf.math.minimum(num_to_select, self.max_deletions)
        num_to_select = tf.cast(num_to_select, "int64")

        # Shuffle and trim to items that are going to be selected.
        def _shuffle_and_trim(x):
            positions, top_n = x
            shuffled = tf.random.shuffle(positions, seed=self.seed)
            return shuffled[:top_n]

        selected_for_mask = tf.map_fn(
            _shuffle_and_trim,
            (positions, num_to_select),
            fn_output_signature=tf.RaggedTensorSpec(
                ragged_rank=positions.ragged_rank - 1, dtype=positions.dtype
            ),
        )
        selected_for_mask.flat_values.set_shape([None])

        # Construct the mask which is a boolean RT.
        # Scatter 0's to positions that have been selected for deletion.
        update_values = tf.zeros_like(selected_for_mask.flat_values, "int32")
        update_indices = selected_for_mask.flat_values
        update_indices = tf.expand_dims(update_indices, -1)
        update_indices = tf.cast(update_indices, "int32")
        mask_flat = tf.ones_like(inputs.flat_values, dtype="int32")
        mask_flat = tf.tensor_scatter_nd_update(
            mask_flat, update_indices, update_values
        )
        mask = tf.cast(inputs.with_flat_values(mask_flat), "bool")

        # Keep only the tokens whose mask entry is still True.
        inputs = tf.ragged.boolean_mask(inputs, mask)

        if unbatched:
            inputs = tf.squeeze(inputs, axis=0)

        return inputs

    def get_config(self):
        """Return the parent config extended with this layer's arguments."""
        config = super().get_config()
        config.update(
            {
                "rate": self.rate,
                "max_deletions": self.max_deletions,
                "seed": self.seed,
                "skip_list": self.skip_list,
                "skip_fn": self.skip_fn,
                "skip_py_fn": self.skip_py_fn,
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        """Return `inputs_shape` as a tuple with its last axis set to `None`."""
        inputs_shape = list(inputs_shape)
        inputs_shape[-1] = None
        return tuple(inputs_shape)
|