keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,267 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import random
|
16
|
+
|
17
|
+
from keras_hub.src.api_export import keras_hub_export
|
18
|
+
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
|
19
|
+
PreprocessingLayer,
|
20
|
+
)
|
21
|
+
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
|
22
|
+
from keras_hub.src.utils.tensor_utils import is_int_dtype
|
23
|
+
from keras_hub.src.utils.tensor_utils import is_string_dtype
|
24
|
+
|
25
|
+
try:
|
26
|
+
import tensorflow as tf
|
27
|
+
except ImportError:
|
28
|
+
tf = None
|
29
|
+
|
30
|
+
|
31
|
+
@keras_hub_export("keras_hub.layers.RandomSwap")
class RandomSwap(PreprocessingLayer):
    """Augments input by randomly swapping words.

    This layer comes in handy when you need to generate new data using swap
    augmentations as described in the paper [EDA: Easy Data Augmentation
    Techniques for Boosting Performance on Text Classification Tasks]
    (https://arxiv.org/pdf/1901.11196.pdf). The layer expects the inputs to be
    pre-split into token level inputs. This allows control over the level of
    augmentation; you can split by character for character level swaps, or by
    word for word level swaps.

    Input data should be passed as tensors, `tf.RaggedTensor`s, or lists. For
    batched input, inputs should be a list of lists or a rank two tensor. For
    unbatched inputs, each element should be a list or a rank one tensor.

    Args:
        rate: The probability of a given token being chosen to be swapped
            with another random token.
        max_swaps: The maximum number of swaps to be performed. Must be
            non-negative if given.
        skip_list: A list of token values that should not be considered
            candidates for swapping.
        skip_fn: A function that takes as input a scalar tensor token and
            returns as output a scalar tensor True/False value. A value of
            True indicates that the token should not be considered a
            candidate for swapping. This function must be traceable--it
            should consist of tensorflow operations.
        skip_py_fn: A function that takes as input a python token value and
            returns as output `True` or `False`. A value of True
            indicates that the token should not be considered a candidate
            for swapping. Unlike the `skip_fn` argument, this argument need
            not be traceable--it can be any python function.
        seed: A seed for the random number generator. If `None`, a random
            integer seed is drawn with Python's `random` module.

    At most one of `skip_list`, `skip_fn` and `skip_py_fn` may be provided.

    Examples:

    Word level usage.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_hub.layers.RandomSwap(rate=0.4, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'like I Hey', b'and Keras Tensorflow'], dtype=object)>

    Character level usage.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.unicode_split(["Hey Dude", "Speed Up"], "UTF-8")
    >>> augmenter=keras_hub.layers.RandomSwap(rate=0.4, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'deD yuHe', b'SUede pp'], dtype=object)>

    Usage with skip_list.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_hub.layers.RandomSwap(rate=0.4,
    ...     skip_list=["Keras"], seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'like I Hey', b'Keras and Tensorflow'], dtype=object)>

    Usage with skip_fn.
    >>> def skip_fn(word):
    ...     return tf.strings.regex_full_match(word, r"[I, a].*")
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_hub.layers.RandomSwap(rate=0.9, max_swaps=3,
    ...     skip_fn=skip_fn, seed=11)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'like I Hey', b'Keras and Tensorflow'], dtype=object)>

    Usage with skip_py_fn.
    >>> def skip_py_fn(word):
    ...     return len(word) < 4
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["He was drifting along", "With the wind"])
    >>> augmenter=keras_hub.layers.RandomSwap(rate=0.8, max_swaps=2,
    ...     skip_py_fn=skip_py_fn, seed=15)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'He was along drifting',
    b'wind the With'], dtype=object)>
    """

    def __init__(
        self,
        rate,
        max_swaps=None,
        skip_list=None,
        skip_fn=None,
        skip_py_fn=None,
        seed=None,
        name=None,
        dtype="int32",
        **kwargs,
    ):
        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
            raise ValueError(
                "Output dtype must be an integer type or a string. "
                f"Received: dtype={dtype}"
            )

        super().__init__(name=name, dtype=dtype, **kwargs)

        self.rate = rate
        self.max_swaps = max_swaps
        # `random.randint` requires integer bounds; a float bound such as
        # `1e9` raises `TypeError` on Python 3.12+. `10**9` keeps the same
        # range as before.
        self.seed = random.randint(1, 10**9) if seed is None else seed
        self._generator = tf.random.Generator.from_seed(self.seed)
        self.skip_list = skip_list
        self.skip_fn = skip_fn
        self.skip_py_fn = skip_py_fn
        if self.max_swaps is not None and self.max_swaps < 0:
            raise ValueError(
                "max_swaps must be non-negative. "
                f"Received max_swaps={max_swaps}."
            )

        # Zero skip options is allowed (no tokens are skipped); only
        # passing two or more is rejected.
        if [self.skip_list, self.skip_fn, self.skip_py_fn].count(None) < 2:
            raise ValueError(
                "At most one of skip_list, skip_fn, skip_py_fn may be "
                "provided."
            )

        if self.skip_list:
            # Lookup table mapping each skipped token to True; every other
            # token maps to the default value False.
            self.StaticHashTable = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(
                    tf.convert_to_tensor(self.skip_list),
                    tf.convert_to_tensor([True] * len(self.skip_list)),
                ),
                default_value=False,
            )

    def call(self, inputs):
        """Randomly swap pairs of non-skipped tokens within each sequence.

        Args:
            inputs: Batched or unbatched token input; converted to a ragged
                batch via `convert_to_ragged_batch`.

        Returns:
            The swapped tokens, with the same batching as the input.
        """
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)

        # Compute a per-token boolean mask where True means "skip this token".
        skip_masks = None
        if self.skip_list:
            skip_masks = self.StaticHashTable.lookup(inputs.flat_values)
        elif self.skip_fn:
            skip_masks = tf.map_fn(
                self.skip_fn, inputs.flat_values, fn_output_signature="bool"
            )
        elif self.skip_py_fn:

            def string_fn(token):
                return self.skip_py_fn(token.numpy().decode("utf-8"))

            def int_fn(token):
                return self.skip_py_fn(token.numpy())

            py_fn = string_fn if inputs.dtype == tf.string else int_fn

            skip_masks = tf.map_fn(
                lambda x: tf.py_function(py_fn, [x], "bool"),
                inputs.flat_values,
                fn_output_signature="bool",
            )

        # Candidate positions within each row; skipped tokens are removed.
        positions = tf.ragged.range(inputs.row_lengths())

        if skip_masks is not None:
            skip_masks = tf.logical_not(skip_masks)
            skip_masks.set_shape([None])
            positions = tf.ragged.boolean_mask(
                positions, inputs.with_flat_values(skip_masks)
            )
        # Figure out how many we are going to select.
        token_counts = tf.cast(positions.row_lengths(), "float32")
        num_to_select = tf.random.stateless_binomial(
            shape=tf.shape(token_counts),
            seed=self._generator.make_seeds()[:, 0],
            counts=token_counts,
            probs=self.rate,
        )
        if self.max_swaps is not None:
            num_to_select = tf.math.minimum(num_to_select, self.max_swaps)
        num_to_select = tf.math.minimum(
            num_to_select, tf.cast(positions.row_lengths(), "int32")
        )
        num_to_select = tf.cast(num_to_select, "int64")

        def _swap(x):
            # Per-row helper: perform `row_num_swaps` random pairwise swaps
            # among the candidate positions of a single row.
            row_positions, row_tokens, row_num_swaps = x
            for _ in range(row_num_swaps):
                index = tf.random.stateless_uniform(
                    shape=[2],
                    minval=0,
                    maxval=tf.size(row_positions),
                    dtype="int32",
                    seed=self._generator.make_seeds()[:, 0],
                )
                index1 = row_positions[index[0]]
                index2 = row_positions[index[1]]
                # swap items at the sampled indices with each other
                row_tokens = tf.tensor_scatter_nd_update(
                    row_tokens,
                    [[index1], [index2]],
                    [row_tokens[index2], row_tokens[index1]],
                )
            return row_tokens

        swapped = tf.map_fn(
            _swap,
            (positions, inputs, num_to_select),
            fn_output_signature=tf.RaggedTensorSpec(
                ragged_rank=positions.ragged_rank - 1, dtype=inputs.dtype
            ),
        )
        swapped.flat_values.set_shape([None])

        if unbatched:
            swapped = tf.squeeze(swapped, axis=0)
        return swapped

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "rate": self.rate,
                "max_swaps": self.max_swaps,
                "seed": self.seed,
                "skip_list": self.skip_list,
                "skip_fn": self.skip_fn,
                "skip_py_fn": self.skip_py_fn,
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        # Only the last (token) dimension becomes unknown.
        inputs_shape = list(inputs_shape)
        inputs_shape[-1] = None
        return tuple(inputs_shape)
|
@@ -0,0 +1,219 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
from keras_hub.src.api_export import keras_hub_export
|
17
|
+
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
|
18
|
+
PreprocessingLayer,
|
19
|
+
)
|
20
|
+
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
|
21
|
+
|
22
|
+
try:
|
23
|
+
import tensorflow as tf
|
24
|
+
except ImportError:
|
25
|
+
tf = None
|
26
|
+
|
27
|
+
|
28
|
+
@keras_hub_export("keras_hub.layers.StartEndPacker")
|
29
|
+
class StartEndPacker(PreprocessingLayer):
|
30
|
+
"""Adds start and end tokens to a sequence and pads to a fixed length.
|
31
|
+
|
32
|
+
This layer is useful when tokenizing inputs for tasks like translation,
|
33
|
+
where each sequence should include a start and end marker. It should
|
34
|
+
be called after tokenization. The layer will first trim inputs to fit, then
|
35
|
+
add start/end tokens, and finally pad, if necessary, to `sequence_length`.
|
36
|
+
|
37
|
+
Input data should be passed as tensors, `tf.RaggedTensor`s, or lists. For
|
38
|
+
batched input, inputs should be a list of lists or a rank two tensor. For
|
39
|
+
unbatched inputs, each element should be a list or a rank one tensor.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
sequence_length: int. The desired output length.
|
43
|
+
start_value: int/str/list/tuple. The ID(s) or token(s) that are to be
|
44
|
+
placed at the start of each sequence. The dtype must match the dtype
|
45
|
+
of the input tensors to the layer. If `None`, no start value will be
|
46
|
+
added.
|
47
|
+
end_value: int/str/list/tuple. The ID(s) or token(s) that are to be
|
48
|
+
placed at the end of each input segment. The dtype must match the
|
49
|
+
dtype of the input tensors to the layer. If `None`, no end value
|
50
|
+
will be added.
|
51
|
+
pad_value: int/str. The ID or token that is to be placed into the
|
52
|
+
unused positions after the last segment in the sequence. If `None`,
|
53
|
+
0 or "" will be added depending on the dtype of the input tensor.
|
54
|
+
return_padding_mask: bool. Whether to return a boolean padding mask of
|
55
|
+
all locations that are filled in with the `pad_value`.
|
56
|
+
|
57
|
+
Call arguments:
|
58
|
+
inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings.
|
59
|
+
sequence_length: Pass to override the configured `sequence_length` of
|
60
|
+
the layer.
|
61
|
+
add_start_value: Pass `False` to not append a start value for this
|
62
|
+
input.
|
63
|
+
add_end_value: Pass `False` to not append an end value for this
|
64
|
+
input.
|
65
|
+
|
66
|
+
Examples:
|
67
|
+
|
68
|
+
Unbatched input (int).
|
69
|
+
>>> inputs = [5, 6, 7]
|
70
|
+
>>> start_end_packer = keras_hub.layers.StartEndPacker(
|
71
|
+
... sequence_length=7, start_value=1, end_value=2,
|
72
|
+
... )
|
73
|
+
>>> outputs = start_end_packer(inputs)
|
74
|
+
>>> np.array(outputs)
|
75
|
+
array([1, 5, 6, 7, 2, 0, 0], dtype=int32)
|
76
|
+
|
77
|
+
Batched input (int).
|
78
|
+
>>> inputs = [[5, 6, 7], [8, 9, 10, 11, 12, 13, 14]]
|
79
|
+
>>> start_end_packer = keras_hub.layers.StartEndPacker(
|
80
|
+
... sequence_length=6, start_value=1, end_value=2,
|
81
|
+
... )
|
82
|
+
>>> outputs = start_end_packer(inputs)
|
83
|
+
>>> np.array(outputs)
|
84
|
+
array([[ 1, 5, 6, 7, 2, 0],
|
85
|
+
[ 1, 8, 9, 10, 11, 2]], dtype=int32)
|
86
|
+
|
87
|
+
Unbatched input (str).
|
88
|
+
>>> inputs = tf.constant(["this", "is", "fun"])
|
89
|
+
>>> start_end_packer = keras_hub.layers.StartEndPacker(
|
90
|
+
... sequence_length=6, start_value="<s>", end_value="</s>",
|
91
|
+
... pad_value="<pad>"
|
92
|
+
... )
|
93
|
+
>>> outputs = start_end_packer(inputs)
|
94
|
+
>>> np.array(outputs).astype("U")
|
95
|
+
array(['<s>', 'this', 'is', 'fun', '</s>', '<pad>'], dtype='<U5')
|
96
|
+
|
97
|
+
Batched input (str).
|
98
|
+
>>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
|
99
|
+
>>> start_end_packer = keras_hub.layers.StartEndPacker(
|
100
|
+
... sequence_length=6, start_value="<s>", end_value="</s>",
|
101
|
+
... pad_value="<pad>"
|
102
|
+
... )
|
103
|
+
>>> outputs = start_end_packer(inputs)
|
104
|
+
>>> np.array(outputs).astype("U")
|
105
|
+
array([['<s>', 'this', 'is', 'fun', '</s>', '<pad>'],
|
106
|
+
['<s>', 'awesome', '</s>', '<pad>', '<pad>', '<pad>']], dtype='<U7')
|
107
|
+
|
108
|
+
Multiple start tokens.
|
109
|
+
>>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
|
110
|
+
>>> start_end_packer = keras_hub.layers.StartEndPacker(
|
111
|
+
... sequence_length=6, start_value=["</s>", "<s>"], end_value="</s>",
|
112
|
+
... pad_value="<pad>"
|
113
|
+
... )
|
114
|
+
>>> outputs = start_end_packer(inputs)
|
115
|
+
>>> np.array(outputs).astype("U")
|
116
|
+
array([['</s>', '<s>', 'this', 'is', 'fun', '</s>'],
|
117
|
+
['</s>', '<s>', 'awesome', '</s>', '<pad>', '<pad>']], dtype='<U7')
|
118
|
+
"""
|
119
|
+
|
120
|
+
def __init__(
    self,
    sequence_length,
    start_value=None,
    end_value=None,
    pad_value=None,
    return_padding_mask=False,
    name=None,
    **kwargs,
):
    """Configure the start/end packer.

    Args:
        sequence_length: Target length of the packed output.
        start_value: `None`, an int/str, or a list/tuple of ints/strs.
            Value(s) prepended to each sequence. A scalar is normalized
            to a one-element list.
        end_value: `None`, an int/str, or a list/tuple of ints/strs.
            Value(s) appended to each sequence. A scalar is normalized
            to a one-element list.
        pad_value: Fill value used for positions after the packed tokens.
            It is passed as `default_value` to `RaggedTensor.to_tensor`.
        return_padding_mask: If True, `call` returns `(outputs, mask)`
            instead of just `outputs`.
        name: Optional layer name, forwarded to the base layer.
        **kwargs: Forwarded to the base layer constructor.

    Raises:
        ValueError: If `start_value` or `end_value` is not `None`, an int,
            a str, a list, or a tuple.
    """
    super().__init__(name=name, **kwargs)

    self.sequence_length = sequence_length

    # Maintain private copies of the raw arguments for config purposes:
    # `get_config()` reports these, while the public attributes below hold
    # the normalized (list) form used by `call`.
    self._start_value = start_value
    self._end_value = end_value

    def check_special_value_type(value, value_name):
        # `None` and list/tuple values pass through unchanged; int/str
        # scalars are wrapped in a list so `call` can always take `len()`
        # and build a 1-D tensor. Anything else is rejected up front.
        # Test `is None` explicitly rather than truthiness. Otherwise falsy
        # invalid values (e.g. 0.0) are silently accepted, and numpy arrays
        # raise an unrelated "ambiguous truth value" error.
        if value is None or isinstance(value, (list, tuple)):
            return value
        if isinstance(value, (int, str)):
            return [value]
        raise ValueError(
            f"{value_name} should be of type int/str/list/tuple. "
            f"Received type: `{type(value)}`."
        )

    start_value = check_special_value_type(start_value, "start_value")
    end_value = check_special_value_type(end_value, "end_value")

    self.start_value = start_value
    self.end_value = end_value

    self.pad_value = pad_value
    self.return_padding_mask = return_padding_mask
|
156
|
+
|
157
|
+
def call(
    self,
    inputs,
    sequence_length=None,
    add_start_value=True,
    add_end_value=True,
):
    """Add start/end values to `inputs`, then pad or truncate to a fixed length.

    Args:
        inputs: Token sequences. They are converted to a ragged batch by
            `convert_to_ragged_batch`, which also reports whether the input
            was unbatched.
        sequence_length: Optional per-call override of
            `self.sequence_length`. A falsy value (None or 0) falls back to
            `self.sequence_length`.
        add_start_value: If False, skip prepending `self.start_value`.
        add_end_value: If False, skip appending `self.end_value`.

    Returns:
        A dense tensor of shape `(batch_size, sequence_length)`, with the
        batch axis squeezed away when the input was unbatched. If
        `self.return_padding_mask` is True, returns `(outputs, mask)`, where
        `mask` is a bool tensor of the same shape that is True at positions
        holding packed tokens.
    """
    inputs, unbatched, _ = convert_to_ragged_batch(inputs)
    dtype = inputs.dtype
    batch_size = tf.shape(inputs)[0]
    sequence_length = sequence_length or self.sequence_length

    def tile_special(values):
        # Broadcast a list of special values into a (batch_size, len(values))
        # block matching the input dtype, ready to concat on the last axis.
        values = tf.convert_to_tensor(values, dtype=dtype)
        return tf.repeat(values[tf.newaxis, :], repeats=batch_size, axis=0)

    packed = inputs  # Intermediate result.
    if add_start_value and self.start_value is not None:
        packed = tf.concat([tile_special(self.start_value), packed], axis=-1)
    if add_end_value and self.end_value is not None:
        end_block = tile_special(self.end_value)
        # Trim before appending so the end value(s) survive the final
        # truncation to `sequence_length`.
        packed = packed[..., : sequence_length - len(self.end_value)]
        packed = tf.concat([packed, end_block], axis=-1)

    # Densify: pad short rows with `pad_value`, cut long rows.
    target_shape = (batch_size, sequence_length)
    outputs = packed.to_tensor(
        default_value=self.pad_value,
        shape=target_shape,
    )
    if unbatched:
        outputs = tf.squeeze(outputs, axis=0)

    if not self.return_padding_mask:
        return outputs

    # True wherever a real (packed) token exists; padding becomes False.
    mask = tf.ones_like(packed, dtype="bool").to_tensor(shape=target_shape)
    if unbatched:
        mask = tf.squeeze(mask, axis=0)
    return outputs, mask
|
202
|
+
|
203
|
+
def get_config(self):
    """Return a serializable config dict for this layer.

    Starts from the base layer's config and adds this layer's constructor
    arguments. `start_value` and `end_value` are reported exactly as they
    were passed to the constructor (the private copies), not in the
    normalized list form stored on the public attributes.
    """
    config = super().get_config()
    config.update(
        sequence_length=self.sequence_length,
        start_value=self._start_value,
        end_value=self._end_value,
        pad_value=self.pad_value,
        return_padding_mask=self.return_padding_mask,
    )
    return config
|
215
|
+
|
216
|
+
def compute_output_shape(self, inputs_shape):
    """Return `inputs_shape` as a tuple with its last axis set to `sequence_length`.

    All leading axes are passed through unchanged. This matches `call`,
    which densifies to `(batch_size, sequence_length)` (or squeezes the
    batch axis for unbatched input).

    NOTE(review): when `return_padding_mask=True`, `call` returns
    `(outputs, mask)`, but this method still reports a single shape.
    Confirm whether callers rely on that before changing it.
    """
    dims = [*inputs_shape]
    # Only the trailing (token) axis changes length.
    dims[-1] = self.sequence_length
    return tuple(dims)
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|