keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py
@@ -0,0 +1,176 @@
````python
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.tokenizers.word_piece_tokenizer import pretokenize

try:
    import tensorflow as tf
    from tensorflow_text.tools.wordpiece_vocab import (
        wordpiece_tokenizer_learner_lib as learner,
    )
except ImportError:
    tf = None
    learner = None


@keras_hub_export("keras_hub.tokenizers.compute_word_piece_vocabulary")
def compute_word_piece_vocabulary(
    data,
    vocabulary_size,
    vocabulary_output_file=None,
    lowercase=False,
    strip_accents=False,
    split=True,
    split_on_cjk=True,
    suffix_indicator="##",
    reserved_tokens=["[PAD]", "[CLS]", "[SEP]", "[UNK]", "[MASK]"],
):
    r"""A utility to train a WordPiece vocabulary.

    Trains a WordPiece vocabulary from an input dataset or a list of filenames.

    For custom data loading and pretokenization (`split=False`), the input
    `data` should be a `tf.data.Dataset`. If `data` is a list of filenames,
    the file format is required to be plain text files, and the text would be
    read in line by line during training.

    Args:
        data: A `tf.data.Dataset`, or a list of filenames.
        vocabulary_size: int. The maximum size of a vocabulary to be trained.
        vocabulary_output_file: str. The location to write a
            vocabulary file. defaults to `None`.
        lowercase: bool. If `True`, the input text will be
            lowercased before tokenization. Defaults to `False`.
        strip_accents: bool. If `True`, all accent marks will
            be removed from text before tokenization. Defaults to `False`.
        split: bool. If `True`, input will be split on
            whitespace and punctuation marks, and all punctuation marks will be
            kept as tokens. If `False`, input should be split ("pre-tokenized")
            before calling the tokenizer, and passed as a dense or ragged tensor
            of whole words. `split` is required to be `True` when `data` is a
            list of filenames. Defaults to `True`.
        split_on_cjk: bool. If `True`, input will be split
            on CJK characters, i.e., Chinese, Japanese, Korean and Vietnamese
            characters (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
            Note that this is applicable only when `split` is `True`.
            Defaults to `True`.
        suffix_indicator: str. The characters prepended to a
            WordPiece to indicate that it is a suffix to another subword.
            E.g. `"##ing"`. Defaults to `"##"`.
        reserved_tokens: list of strings. A list of tokens that must be
            included in the vocabulary.

    Returns:
        Returns a list of vocabulary terms.

    Examples:

    Basic Usage (from Dataset).
    >>> inputs = tf.data.Dataset.from_tensor_slices(["bat sat pat mat rat"])
    >>> vocab = compute_word_piece_vocabulary(inputs, 13)
    >>> vocab
    ['[PAD]', '[CLS]', '[SEP]', '[UNK]', '[MASK]', 'a', 'b', 'm', 'p', 'r', 's', 't', '##at']
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(vocabulary=vocab, oov_token="[UNK]")
    >>> outputs = inputs.map(tokenizer.tokenize)
    >>> for x in outputs:
    ...     print(x)
    tf.Tensor([ 6 12 10 12  8 12  7 12  9 12], shape=(10,), dtype=int32)

    Basic Usage (from filenames).
    ```python
    with open("test.txt", "w+") as f:
        f.write("bat sat pat mat rat\n")
    inputs = ["test.txt"]
    vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(inputs, 13)
    ```

    Custom Split Usage (from Dataset).
    >>> def normalize_and_split(x):
    ...     "Strip punctuation and split on whitespace."
    ...     x = tf.strings.regex_replace(x, r"\p{P}", "")
    ...     return tf.strings.split(x)
    >>> inputs = tf.data.Dataset.from_tensor_slices(["bat sat: pat mat rat.\n"])
    >>> split_inputs = inputs.map(normalize_and_split)
    >>> vocab = compute_word_piece_vocabulary(
    ...     split_inputs, 13, split=False,
    ... )
    >>> vocab
    ['[PAD]', '[CLS]', '[SEP]', '[UNK]', '[MASK]', 'a', 'b', 'm', 'p', 'r', 's', 't', '##at']
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(vocabulary=vocab)
    >>> inputs.map(tokenizer.tokenize)

    Custom Split Usage (from filenames).
    ```python
    def normalize_and_split(x):
        "Strip punctuation and split on whitespace."
        x = tf.strings.regex_replace(x, r"\p{P}", "")
        return tf.strings.split(x)
    with open("test.txt", "w+") as f:
        f.write("bat sat: pat mat rat.\n")
    inputs = tf.data.TextLineDataset(["test.txt"])
    split_inputs = inputs.map(normalize_and_split)
    vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
        split_inputs, 13, split=False
    )
    tokenizer = keras_hub.tokenizers.WordPieceTokenizer(vocabulary=vocab)
    inputs.map(tokenizer.tokenize)
    ```
    """
    # Read data files.
    if not isinstance(data, (list, tf.data.Dataset)):
        raise ValueError(
            "The `data` argument must be either `tf.data.Dataset` or `list`. "
            f"Received: {type(data)}."
        )
    if isinstance(data, list):
        # Processing list of file paths.
        if not split:
            raise ValueError(
                "When learning a vocab from files, `split` must be `True`. "
                "To compute a vocabulary with custom split rules, load your "
                "data as a dataset, split it, and pass it to "
                "`compute_word_piece_vocabulary()` with split=False."
            )
        path_ds = tf.data.Dataset.from_tensor_slices(data)
        # Uses map to read filepaths.
        data = path_ds.map(
            lambda path: tf.io.read_file(path),
            num_parallel_calls=tf.data.AUTOTUNE,
        )

    words_data = data.map(
        lambda text: pretokenize(
            text, lowercase, strip_accents, split, split_on_cjk
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    word_counts = learner.count_words(words_data)
    # Train tokenizer.
    vocab = learner.learn(
        word_counts,
        vocab_size=vocabulary_size,
        reserved_tokens=reserved_tokens,
        include_joiner_token=True,
        joiner=suffix_indicator,
    )
    if len(vocab) > vocabulary_size:
        vocab = vocab[:vocabulary_size]
    if vocabulary_output_file is not None:
        vocab_text = "".join([line + "\n" for line in vocab])
        # Write vocab to file.
        with open(vocabulary_output_file, "w", encoding="utf-8") as vocab_file:
            vocab_file.write(vocab_text)
    else:
        return vocab
````
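One detail of the trailer above that the docstring does not spell out: when `vocabulary_output_file` is given, the vocabulary is written to disk and the call returns `None`; the term list is only returned in the `else` branch. A minimal sketch of the file-output path, assuming `WordPieceTokenizer` also accepts a plain-text vocabulary file path (one token per line, matching the write format above); the `vocab.txt` name is arbitrary:

```python
import tensorflow as tf
import keras_hub

inputs = tf.data.Dataset.from_tensor_slices(["bat sat pat mat rat"])

# With `vocabulary_output_file` set, the trained vocabulary is written to
# disk (one token per line) and the call returns `None`.
keras_hub.tokenizers.compute_word_piece_vocabulary(
    inputs, 13, vocabulary_output_file="vocab.txt"
)

# Assumption: the tokenizer can load the vocabulary back from the file path.
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(vocabulary="vocab.txt")
print(tokenizer.get_vocabulary()[:5])
```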
One of the package's 13-line `__init__.py` modules (license header only, matching the many `+13` entries in the file list above)
@@ -0,0 +1,13 @@
```python
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
keras_hub/src/utils/keras_utils.py
@@ -0,0 +1,130 @@
```python
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import keras
from absl import logging
from packaging.version import parse

from keras_hub.src.utils.tensor_utils import is_tensor_type

try:
    import tensorflow as tf
except ImportError:
    tf = None


def clone_initializer(initializer):
    """Clones an initializer to ensure a new seed.

    As of tensorflow 2.10, we need to clone user passed initializers when
    invoking them twice to avoid creating the same randomized initialization.
    """
    # If we get a string or dict, just return as we cannot and should not clone.
    if not isinstance(initializer, keras.initializers.Initializer):
        return initializer
    config = initializer.get_config()
    return initializer.__class__.from_config(config)


def convert_inputs_to_list_of_tensor_segments(x):
    """Converts user inputs to a list of a tensor segments.

    For models and layers which accept lists of string tensors to pack together,
    this method converts user inputs to a uniform format in a way that can be
    considered canonical for the library.

    We handle the following:

    - A single string will be converted to a tensor and wrapped in a list.
    - A list of strings will be converted to a tensor and wrapped in a list.
    - A single tensor will be wrapped in a list.
    - A list of tensors will be passed through unaltered.

    All other inputs will result in an error. This effectively means that users
    who would like to pack multiple segments together should convert those
    segments to tensors before calling the layer. This removes any ambiguity
    in the input for those cases.
    """
    # Check the input type.
    is_string = isinstance(x, (str, bytes))
    is_tensor = is_tensor_type(x)
    is_string_list = (
        isinstance(x, (list, tuple)) and x and isinstance(x[0], (str, bytes))
    )
    is_tensor_list = isinstance(x, (list, tuple)) and x and is_tensor_type(x[0])

    if is_string or is_string_list:
        # Automatically convert raw strings or string lists to tensors.
        # Wrap this input as a single (possibly batched) segment.
        x = [tf.convert_to_tensor(x)]
    elif is_tensor:
        # Automatically wrap a single tensor as a single segment.
        x = [x]
    elif is_tensor_list:
        # Pass lists of tensors though unaltered.
        x = x
    else:
        # Error for all other input.
        raise ValueError(
            f"Unsupported input for `x`. `x` should be a string, a list of "
            "strings, or a list of tensors. If passing multiple segments "
            "which should packed together, please convert your inputs to a "
            f"list of tensors. Received `x={x}`"
        )
    return x


def print_msg(message, line_break=True):
    """Print the message to absl logging or stdout."""
    # Copied from core Keras.
    if keras.utils.is_interactive_logging_enabled():
        if line_break:
            sys.stdout.write(message + "\n")
        else:
            sys.stdout.write(message)
        sys.stdout.flush()
    else:
        logging.info(message)


@keras.saving.register_keras_serializable(package="keras_hub")
def gelu_approximate(x):
    return keras.activations.gelu(x, approximate=True)


def has_quantization_support():
    return False if parse(keras.version()) < parse("3.4.0") else True


def assert_quantization_support():
    if not has_quantization_support():
        raise ValueError(
            "Quantization API requires Keras >= 3.4.0 to function "
            f"correctly. Received: '{keras.version()}'"
        )


def standardize_data_format(data_format):
    if data_format is None:
        return keras.config.image_data_format()
    data_format = str(data_format).lower()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(
            "The `data_format` argument must be one of "
            "{'channels_first', 'channels_last'}. "
            f"Received: data_format={data_format}"
        )
    return data_format
```
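Since `clone_initializer` is the piece of this module most likely to be used outside the library, here is a minimal sketch of its intent. The layer names are hypothetical; the comments restate the docstring's rationale rather than asserting anything beyond it:

```python
import keras

from keras_hub.src.utils.keras_utils import clone_initializer

initializer = keras.initializers.GlorotUniform()

# Per the docstring above, rebuilding the initializer from its config avoids
# repeating the same randomized initialization when one initializer object
# is passed to several weights.
query_dense = keras.layers.Dense(
    8, kernel_initializer=clone_initializer(initializer)
)
value_dense = keras.layers.Dense(
    8, kernel_initializer=clone_initializer(initializer)
)

# Strings and dicts pass through untouched; they carry no random state.
assert clone_initializer("glorot_uniform") == "glorot_uniform"
```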
keras_hub/src/utils/pipeline_model.py
@@ -0,0 +1,293 @@
```python
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math

import keras
from keras import ops
from keras import tree

from keras_hub.src.utils.tensor_utils import is_tensor_type

try:
    import tensorflow as tf
except ImportError:
    tf = None


def _convert_inputs_to_dataset(
    x=None,
    y=None,
    sample_weight=None,
    batch_size=None,
):
    """Convert inputs to a `tf.data.Dataset`.

    This is a stand in for the `TensorLikeDataAdapter` in core Keras.
    """
    if isinstance(x, tf.data.Dataset):
        if y is not None:
            raise ValueError(
                "When `x` is a `tf.data.Dataset`, please do not provide "
                f"`y`. Received: `type(y)={type(y)}`."
            )
        if sample_weight is not None:
            raise ValueError(
                "When `x` is a `tf.data.Dataset`, please do not provide "
                "`sample_weight`. Received: "
                f"`type(sample_weight)={type(sample_weight)}`."
            )
        if batch_size is not None:
            raise ValueError(
                "When `x` is a `tf.data.Dataset`, please do not provide "
                "`batch_size`. Received: "
                f"`type(batch_size)={type(batch_size)}`."
            )
        return x

    inputs = keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
    try:

        def convert(x):
            if isinstance(x, (tf.Tensor, tf.RaggedTensor)):
                return x
            if hasattr(x, "__array__"):
                return ops.convert_to_numpy(x)
            return x

        inputs = tree.map_structure(convert, inputs)
        ds = tf.data.Dataset.from_tensor_slices(inputs)
    except ValueError as e:
        # If our inputs are unbatched, re-raise with a more friendly error
        # message the default from tf.data. We expect this to come up with
        # some frequency, so it's important to have a good sign post here.
        if "only supported for rank >= 1" in str(e):
            raise ValueError(
                "`x`, `y`, and `sample_weight` must have a batch dimension "
                "when calling `fit()`, `evaluate()`, and `predict()`. Received "
                "an input with rank 0. Please add an outer dimension to your "
                "input, e.g., wrap it in a list."
            ) from e
        raise e

    return ds.batch(batch_size or 32)


def _train_validation_split(arrays, validation_split):
    """Split arrays into train and validation subsets in deterministic order.

    This is copied directly from core Keras.
    """

    def _can_split(t):
        return is_tensor_type(t) or t is None

    flat_arrays = tree.flatten(arrays)
    unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
    if unsplitable:
        raise ValueError(
            "`validation_split` is only supported for Tensors or NumPy "
            "arrays, found following types in the input: {}".format(unsplitable)
        )

    if all(t is None for t in flat_arrays):
        return arrays, arrays

    first_non_none = None
    for t in flat_arrays:
        if t is not None:
            first_non_none = t
            break

    # Assumes all arrays have the same batch shape or are `None`.
    batch_dim = int(first_non_none.shape[0])
    split_at = int(math.floor(batch_dim * (1.0 - validation_split)))

    if split_at == 0 or split_at == batch_dim:
        raise ValueError(
            "Training data contains {batch_dim} samples, which is not "
            "sufficient to split it into a validation and training set as "
            "specified by `validation_split={validation_split}`. Either "
            "provide more data, or a different value for the "
            "`validation_split` argument.".format(
                batch_dim=batch_dim, validation_split=validation_split
            )
        )

    def _split(t, start, end):
        if t is None:
            return t
        return t[start:end]

    train_arrays = tree.map_structure(
        functools.partial(_split, start=0, end=split_at), arrays
    )
    val_arrays = tree.map_structure(
        functools.partial(_split, start=split_at, end=batch_dim), arrays
    )

    return train_arrays, val_arrays


@keras.saving.register_keras_serializable(package="keras_hub")
class PipelineModel(keras.Model):
    """A model which allows automatically applying preprocessing."""

    def __init__(self, *args, **kwargs):
        # Workaround for https://github.com/keras-team/keras/issues/17270
        # Reset any attempt to overwrite this classes base class to this class
        # can continue to be used for functional and non-functional models.
        PipelineModel.__bases__ = (keras.Model,)
        super().__init__(*args, **kwargs)

    def preprocess_samples(self, x, y=None, sample_weight=None):
        """An overridable function which preprocesses entire samples."""
        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

    # ========================================================================
    # Below are overrides to keras.Model methods to apply the functions above.
    # ========================================================================
    def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        sample_weight=None,
        validation_data=None,
        validation_split=None,
        **kwargs,
    ):
        if validation_split and validation_data is None:
            (x, y, sample_weight), validation_data = _train_validation_split(
                (x, y, sample_weight), validation_split=validation_split
            )

        x = _convert_inputs_to_dataset(x, y, sample_weight, batch_size)
        x = x.map(
            self.preprocess_samples, num_parallel_calls=tf.data.AUTOTUNE
        ).prefetch(tf.data.AUTOTUNE)

        if validation_data is not None:
            if not isinstance(validation_data, tf.data.Dataset):
                (vx, vy, vsw) = keras.utils.unpack_x_y_sample_weight(
                    validation_data
                )
                validation_data = _convert_inputs_to_dataset(
                    vx, vy, vsw, batch_size
                )

        return super().fit(
            x=x,
            y=None,
            batch_size=None,
            sample_weight=None,
            validation_data=validation_data,
            **kwargs,
        )

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        sample_weight=None,
        **kwargs,
    ):
        # During `fit()`, `keras.Model` attempts to cache the validation
        # dataset and ignores the values for `x`, `y`, and `sample_weight`.
        # We don't want that behavior here, as the validation dataset still
        # needs preprocessing.
        kwargs.pop("_use_cached_eval_dataset", None)
        x = _convert_inputs_to_dataset(x, y, sample_weight, batch_size)
        x = x.map(
            self.preprocess_samples, num_parallel_calls=tf.data.AUTOTUNE
        ).prefetch(tf.data.AUTOTUNE)
        return super().evaluate(
            x=x,
            y=None,
            batch_size=None,
            **kwargs,
        )

    def predict(
        self,
        x=None,
        batch_size=None,
        **kwargs,
    ):
        x = _convert_inputs_to_dataset(x, None, None, batch_size)
        x = x.map(
            self.preprocess_samples, num_parallel_calls=tf.data.AUTOTUNE
        ).prefetch(tf.data.AUTOTUNE)
        return super().predict(
            x=x,
            batch_size=None,
            **kwargs,
        )

    def train_on_batch(
        self,
        x,
        y=None,
        sample_weight=None,
        **kwargs,
    ):
        data = self.preprocess_samples(x, y, sample_weight)
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
        x = ops.convert_to_tensor(x)
        if y is not None:
            y = ops.convert_to_tensor(y)
        if sample_weight is not None:
            sample_weight = ops.convert_to_tensor(sample_weight)
        return super().train_on_batch(
            x=x,
            y=y,
            sample_weight=sample_weight,
            **kwargs,
        )

    def test_on_batch(
        self,
        x,
        y=None,
        sample_weight=None,
        **kwargs,
    ):
        data = self.preprocess_samples(x, y, sample_weight)
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
        x = ops.convert_to_tensor(x)
        if y is not None:
            y = ops.convert_to_tensor(y)
        if sample_weight is not None:
            sample_weight = ops.convert_to_tensor(sample_weight)
        return super().test_on_batch(
            x=x,
            y=y,
            sample_weight=sample_weight,
            **kwargs,
        )

    def predict_on_batch(
        self,
        x,
        **kwargs,
    ):
        data = self.preprocess_samples(x)
        x, _, _ = keras.utils.unpack_x_y_sample_weight(data)
        x = ops.convert_to_tensor(x)
        return super().predict_on_batch(
            x=x,
            **kwargs,
        )
```
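To make the override chain concrete, a minimal, hypothetical subclass: `ToyPipeline` and its length-based featurization are invented for illustration, but the flow is the one above, where `fit()` builds a `tf.data.Dataset` from the raw inputs and maps `preprocess_samples` over the batches. The sketch assumes the TensorFlow backend is available:

```python
import keras
import tensorflow as tf

from keras_hub.src.utils.pipeline_model import PipelineModel


class ToyPipeline(PipelineModel):
    """Hypothetical model that featurizes raw strings inside the pipeline."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(1)

    def preprocess_samples(self, x, y=None, sample_weight=None):
        # Runs inside tf.data, on batches built by `_convert_inputs_to_dataset`.
        x = tf.cast(tf.strings.length(x), "float32")[:, tf.newaxis]
        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

    def call(self, x):
        return self.dense(x)


model = ToyPipeline()
model.compile(optimizer="adam", loss="mse")
# Raw strings go straight into `fit()`; preprocessing happens in the pipeline.
model.fit(
    x=["a", "bb", "ccc", "dddd"],
    y=[[1.0], [2.0], [3.0], [4.0]],
    batch_size=2,
)
```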