keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,394 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import collections
|
16
|
+
import math
|
17
|
+
|
18
|
+
import keras
|
19
|
+
from keras import ops
|
20
|
+
|
21
|
+
from keras_hub.src.api_export import keras_hub_export
|
22
|
+
from keras_hub.src.utils.tensor_utils import is_float_dtype
|
23
|
+
from keras_hub.src.utils.tensor_utils import tensor_to_list
|
24
|
+
|
25
|
+
try:
|
26
|
+
import tensorflow as tf
|
27
|
+
except ImportError:
|
28
|
+
tf = None
|
29
|
+
|
30
|
+
|
31
|
+
# Substring clean-ups applied, in order, by `Bleu._tokenizer` before the
# regex-based tokenization below. Each entry is `(pattern, replacement)`.
# The last four entries un-escape HTML entities; `&amp;` is handled before
# `&lt;` / `&gt;` so that a doubly-escaped `&amp;lt;` ends up as `<`.
REPLACE_SUBSTRINGS = [
    ("<skipped>", ""),
    ("-\n", ""),
    ("\n", " "),
    ("&quot;", '"'),
    ("&amp;", "&"),
    ("&lt;", "<"),
    ("&gt;", ">"),
]


# Regex rewrites applied, in order, after `REPLACE_SUBSTRINGS`. Each entry is
# `(pattern, rewrite)` and is passed to `tf.strings.regex_replace`.
REGEX_PATTERNS = [
    # language-dependent part (assuming Western languages)
    (r"([\{-\~\[-\` -\&\(-\+\:-\@\/])", r" \1 "),
    # tokenize period and comma unless preceded by a digit
    (r"([^0-9])([\.,])", r"\1 \2 "),
    # tokenize period and comma unless followed by a digit
    (r"([\.,])([^0-9])", r" \1 \2"),
    # tokenize dash when preceded by a digit
    (r"([0-9])(-)", r"\1 \2 "),
    # If last character is "." or ",", add space.
    # NOTE(review): the rewrite references group `\1`, but this pattern
    # defines no capture group -- confirm how the regex engine treats it.
    (r"[\.,]$", r" \0 \1"),
    # one space only between words
    (r"\s+", r" "),
]
|
56
|
+
|
57
|
+
|
58
|
+
@keras_hub_export("keras_hub.metrics.Bleu")
|
59
|
+
class Bleu(keras.metrics.Metric):
|
60
|
+
"""BLEU metric.
|
61
|
+
|
62
|
+
This class implements the BLEU metric. BLEU is generally used to evaluate
|
63
|
+
machine translation systems. By default, this implementation replicates
|
64
|
+
SacreBLEU, but user-defined tokenizers can be passed to deal with other
|
65
|
+
languages.
|
66
|
+
|
67
|
+
For BLEU score, we count the number of matching n-grams in the candidate
|
68
|
+
translation and the reference text. We find the "clipped count" of matching
|
69
|
+
n-grams so as to not give a high score to a (reference, prediction) pair
|
70
|
+
with redundant, repeated tokens. Secondly, BLEU score tends to reward
|
71
|
+
shorter predictions more, which is why a brevity penalty is applied to
|
72
|
+
penalise short predictions. For more details, see the following article:
|
73
|
+
https://cloud.google.com/translate/automl/docs/evaluate#bleu.
|
74
|
+
|
75
|
+
Note on input shapes:
|
76
|
+
For unbatched inputs, `y_pred` should be a tensor of shape `()`, and
|
77
|
+
`y_true` should be a tensor of shape `(num_references,)`. For batched
|
78
|
+
inputs, `y_pred` should be a tensor of shape `(batch_size,)`,
|
79
|
+
and `y_true` should be a tensor of shape `(batch_size, num_references)`. In
|
80
|
+
case of batched inputs, `y_true` can also be a ragged tensor of shape
|
81
|
+
`(batch_size, None)` if different samples have different number of
|
82
|
+
references.
|
83
|
+
|
84
|
+
Args:
|
85
|
+
tokenizer: callable. A function that takes a string `tf.RaggedTensor`
|
86
|
+
(of any shape), and tokenizes the strings in the tensor. If the
|
87
|
+
tokenizer is not specified, the default tokenizer is used. The
|
88
|
+
default tokenizer replicates the behaviour of SacreBLEU's
|
89
|
+
`"tokenizer_13a"` tokenizer
|
90
|
+
(https://github.com/mjpost/sacrebleu/blob/v2.1.0/sacrebleu/tokenizers/tokenizer_13a.py).
|
91
|
+
max_order: int. The maximum n-gram order to use. For example, if
|
92
|
+
`max_order` is set to 3, unigrams, bigrams, and trigrams will be
|
93
|
+
considered. Defaults to `4`.
|
94
|
+
smooth: bool. Whether to apply Lin et al. 2004 smoothing to the BLEU
|
95
|
+
score. Adds 1 to the matched n-gram count (i.e., numerator) and 1
|
96
|
+
to the total n-gram count (i.e., denominator) for every order while
|
97
|
+
calculating precision. Defaults to `False`.
|
98
|
+
dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
|
99
|
+
not specified, it defaults to `"float32"`.
|
100
|
+
name: string. Name of the metric instance.
|
101
|
+
**kwargs: Other keyword arguments.
|
102
|
+
|
103
|
+
References:
|
104
|
+
- [Papineni et al., 2002](https://aclanthology.org/P02-1040/)
|
105
|
+
- [SacreBLEU](https://github.com/mjpost/sacrebleu)
|
106
|
+
- [Lin et al., 2004](https://aclanthology.org/P04-1077/)
|
107
|
+
"""
|
108
|
+
|
109
|
+
def __init__(
|
110
|
+
self,
|
111
|
+
tokenizer=None,
|
112
|
+
max_order=4,
|
113
|
+
smooth=False,
|
114
|
+
dtype="float32",
|
115
|
+
name="bleu",
|
116
|
+
**kwargs,
|
117
|
+
):
|
118
|
+
super().__init__(name=name, dtype=dtype, **kwargs)
|
119
|
+
|
120
|
+
if not is_float_dtype(dtype):
|
121
|
+
raise ValueError(
|
122
|
+
"`dtype` must be a floating point type. "
|
123
|
+
f"Received: dtype={dtype}"
|
124
|
+
)
|
125
|
+
|
126
|
+
self.tokenizer = tokenizer
|
127
|
+
self.max_order = max_order
|
128
|
+
self.smooth = smooth
|
129
|
+
|
130
|
+
self._matches = self.add_weight(
|
131
|
+
shape=(self.max_order,),
|
132
|
+
initializer="zeros",
|
133
|
+
dtype=self.dtype,
|
134
|
+
name="bleu_matches",
|
135
|
+
)
|
136
|
+
self._possible_matches = self.add_weight(
|
137
|
+
shape=(self.max_order,),
|
138
|
+
initializer="zeros",
|
139
|
+
dtype=self.dtype,
|
140
|
+
name="bleu_possible_matches",
|
141
|
+
)
|
142
|
+
self._translation_length = self.add_weight(
|
143
|
+
shape=(),
|
144
|
+
initializer="zeros",
|
145
|
+
dtype=self.dtype,
|
146
|
+
name="bleu_translation_length",
|
147
|
+
)
|
148
|
+
self._reference_length = self.add_weight(
|
149
|
+
shape=(),
|
150
|
+
initializer="zeros",
|
151
|
+
dtype=self.dtype,
|
152
|
+
name="bleu_reference_length",
|
153
|
+
)
|
154
|
+
self._bleu = self.add_weight(
|
155
|
+
shape=(),
|
156
|
+
initializer="zeros",
|
157
|
+
dtype=self.dtype,
|
158
|
+
name="bleu",
|
159
|
+
)
|
160
|
+
|
161
|
+
def _tokenizer(self, inputs):
|
162
|
+
"""
|
163
|
+
Tokenizes the input strings. By default, replicates the behaviour of
|
164
|
+
SacreBLEU's default tokenizer, namely, `tokenizer_13a`.
|
165
|
+
"""
|
166
|
+
if self.tokenizer:
|
167
|
+
return self.tokenizer(inputs)
|
168
|
+
|
169
|
+
for pattern, replacement in REPLACE_SUBSTRINGS + REGEX_PATTERNS:
|
170
|
+
inputs = tf.strings.regex_replace(
|
171
|
+
input=inputs,
|
172
|
+
pattern=pattern,
|
173
|
+
rewrite=replacement,
|
174
|
+
replace_global=True,
|
175
|
+
name=None,
|
176
|
+
)
|
177
|
+
inputs = tf.strings.split(inputs)
|
178
|
+
return inputs
|
179
|
+
|
180
|
+
def _get_ngrams(self, segment, max_order):
|
181
|
+
"""Extracts all n-grams up to a given maximum order from an input segment.
|
182
|
+
|
183
|
+
Uses Python ops. Inspired from
|
184
|
+
https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.
|
185
|
+
|
186
|
+
Args:
|
187
|
+
segment: list. Text segment from which n-grams will be
|
188
|
+
extracted.
|
189
|
+
max_order: int. Maximum length in tokens of the n-grams returned
|
190
|
+
by this method.
|
191
|
+
"""
|
192
|
+
ngram_counts = collections.Counter()
|
193
|
+
for order in range(1, max_order + 1):
|
194
|
+
for i in range(0, len(segment) - order + 1):
|
195
|
+
ngram = tuple(segment[i : i + order])
|
196
|
+
ngram_counts[ngram] += 1
|
197
|
+
return ngram_counts
|
198
|
+
|
199
|
+
def _corpus_bleu(
|
200
|
+
self,
|
201
|
+
reference_corpus,
|
202
|
+
translation_corpus,
|
203
|
+
matches_by_order,
|
204
|
+
possible_matches_by_order,
|
205
|
+
translation_length,
|
206
|
+
reference_length,
|
207
|
+
max_order=4,
|
208
|
+
smooth=False,
|
209
|
+
):
|
210
|
+
"""Corpus BLEU implementation using Python ops.
|
211
|
+
|
212
|
+
Computes BLEU score of translated segments against one or more
|
213
|
+
references. Inspired from
|
214
|
+
https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.
|
215
|
+
|
216
|
+
Args:
|
217
|
+
reference_corpus: list of lists of references for each
|
218
|
+
translation. Each reference should be tokenized into a list
|
219
|
+
of tokens.
|
220
|
+
translation_corpus: list of translations to score. Each
|
221
|
+
translation should be tokenized into a list of tokens.
|
222
|
+
matches_by_order: list of floats containing the initial number
|
223
|
+
of matches for each order.
|
224
|
+
possible_matches_by_order: list of floats containing the initial
|
225
|
+
number of possible matches for each order.
|
226
|
+
translation_length: float. Initial number of tokens in all the
|
227
|
+
translations.
|
228
|
+
reference_length: float. Initial number of tokens in all the
|
229
|
+
references.
|
230
|
+
max_order: int. Maximum n-gram order to use when computing
|
231
|
+
BLEU score.
|
232
|
+
smooth: boolean. Whether or not to apply Lin et al. 2004
|
233
|
+
smoothing.
|
234
|
+
"""
|
235
|
+
for references, translation in zip(
|
236
|
+
reference_corpus, translation_corpus
|
237
|
+
):
|
238
|
+
reference_length += min(len(r) for r in references)
|
239
|
+
translation_length += len(translation)
|
240
|
+
|
241
|
+
merged_ref_ngram_counts = collections.Counter()
|
242
|
+
for reference in references:
|
243
|
+
merged_ref_ngram_counts |= self._get_ngrams(
|
244
|
+
reference, max_order
|
245
|
+
)
|
246
|
+
translation_ngram_counts = self._get_ngrams(translation, max_order)
|
247
|
+
overlap = translation_ngram_counts & merged_ref_ngram_counts
|
248
|
+
for ngram in overlap:
|
249
|
+
matches_by_order[len(ngram) - 1] += overlap[ngram]
|
250
|
+
for order in range(1, max_order + 1):
|
251
|
+
possible_matches = len(translation) - order + 1
|
252
|
+
if possible_matches > 0:
|
253
|
+
possible_matches_by_order[order - 1] += possible_matches
|
254
|
+
|
255
|
+
precisions = [0] * max_order
|
256
|
+
for i in range(0, max_order):
|
257
|
+
if smooth:
|
258
|
+
precisions[i] = (matches_by_order[i] + 1.0) / (
|
259
|
+
possible_matches_by_order[i] + 1.0
|
260
|
+
)
|
261
|
+
else:
|
262
|
+
if possible_matches_by_order[i] > 0:
|
263
|
+
precisions[i] = (
|
264
|
+
float(matches_by_order[i])
|
265
|
+
/ possible_matches_by_order[i]
|
266
|
+
)
|
267
|
+
else:
|
268
|
+
precisions[i] = 0.0
|
269
|
+
|
270
|
+
if min(precisions) > 0:
|
271
|
+
p_log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions)
|
272
|
+
geo_mean = math.exp(p_log_sum)
|
273
|
+
else:
|
274
|
+
geo_mean = 0
|
275
|
+
|
276
|
+
ratio = float(translation_length) / reference_length
|
277
|
+
|
278
|
+
if ratio > 1.0:
|
279
|
+
bp = 1.0
|
280
|
+
else:
|
281
|
+
bp = math.exp(1 - 1.0 / ratio)
|
282
|
+
|
283
|
+
bleu = geo_mean * bp
|
284
|
+
|
285
|
+
return (
|
286
|
+
bleu,
|
287
|
+
matches_by_order,
|
288
|
+
possible_matches_by_order,
|
289
|
+
translation_length,
|
290
|
+
reference_length,
|
291
|
+
)
|
292
|
+
|
293
|
+
def _calculate_bleu_score(self, references, translation):
|
294
|
+
if isinstance(references, (tf.Tensor, tf.RaggedTensor)):
|
295
|
+
references = tensor_to_list(references)
|
296
|
+
if isinstance(translation, (tf.Tensor, tf.RaggedTensor)):
|
297
|
+
translation = tensor_to_list(translation)
|
298
|
+
|
299
|
+
matches = self._matches.numpy()
|
300
|
+
possible_matches = self._possible_matches.numpy()
|
301
|
+
translation_length = self._translation_length.numpy()
|
302
|
+
reference_length = self._reference_length.numpy()
|
303
|
+
|
304
|
+
(
|
305
|
+
bleu_score,
|
306
|
+
matches,
|
307
|
+
possible_matches,
|
308
|
+
translation_length,
|
309
|
+
reference_length,
|
310
|
+
) = self._corpus_bleu(
|
311
|
+
reference_corpus=references,
|
312
|
+
translation_corpus=translation,
|
313
|
+
matches_by_order=matches,
|
314
|
+
possible_matches_by_order=possible_matches,
|
315
|
+
translation_length=translation_length,
|
316
|
+
reference_length=reference_length,
|
317
|
+
max_order=self.max_order,
|
318
|
+
smooth=self.smooth,
|
319
|
+
)
|
320
|
+
return (
|
321
|
+
bleu_score,
|
322
|
+
matches,
|
323
|
+
possible_matches,
|
324
|
+
translation_length,
|
325
|
+
reference_length,
|
326
|
+
)
|
327
|
+
|
328
|
+
def update_state(self, y_true, y_pred, sample_weight=None):
|
329
|
+
def validate_and_fix_rank(inputs, tensor_name, base_rank=0):
|
330
|
+
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
|
331
|
+
inputs = tf.convert_to_tensor(inputs)
|
332
|
+
|
333
|
+
if inputs.shape.rank == base_rank:
|
334
|
+
return inputs[tf.newaxis]
|
335
|
+
elif inputs.shape.rank == base_rank + 1:
|
336
|
+
return inputs
|
337
|
+
elif inputs.shape.rank == base_rank + 2:
|
338
|
+
if tf.shape(inputs)[-1] != 1:
|
339
|
+
raise ValueError(
|
340
|
+
f"{tensor_name} is of rank {input.shape.rank}. The "
|
341
|
+
f"last dimension must be of size 1."
|
342
|
+
)
|
343
|
+
return tf.squeeze(inputs, axis=-1)
|
344
|
+
else:
|
345
|
+
raise ValueError(
|
346
|
+
f"{tensor_name} must be of rank {base_rank}, {base_rank+1} "
|
347
|
+
f"or {base_rank+2}. Found rank: {inputs.shape.rank}"
|
348
|
+
)
|
349
|
+
|
350
|
+
y_true = validate_and_fix_rank(y_true, "y_true", 1)
|
351
|
+
y_pred = validate_and_fix_rank(y_pred, "y_pred", 0)
|
352
|
+
|
353
|
+
# Tokenize the inputs.
|
354
|
+
y_true = self._tokenizer(y_true)
|
355
|
+
y_pred = self._tokenizer(y_pred)
|
356
|
+
|
357
|
+
(
|
358
|
+
bleu_score,
|
359
|
+
matches,
|
360
|
+
possible_matches,
|
361
|
+
translation_length,
|
362
|
+
reference_length,
|
363
|
+
) = self._calculate_bleu_score(y_true, y_pred)
|
364
|
+
|
365
|
+
self._matches.assign(matches)
|
366
|
+
self._possible_matches.assign(possible_matches)
|
367
|
+
self._translation_length.assign(translation_length)
|
368
|
+
self._reference_length.assign(reference_length)
|
369
|
+
self._bleu.assign(bleu_score)
|
370
|
+
|
371
|
+
def result(self):
|
372
|
+
return self._bleu
|
373
|
+
|
374
|
+
def reset_state(self):
|
375
|
+
self._matches.assign(
|
376
|
+
ops.zeros(shape=(self.max_order,), dtype=self.dtype)
|
377
|
+
)
|
378
|
+
self._possible_matches.assign(
|
379
|
+
ops.zeros(shape=(self.max_order,), dtype=self.dtype)
|
380
|
+
)
|
381
|
+
self._translation_length.assign(0.0)
|
382
|
+
self._reference_length.assign(0.0)
|
383
|
+
self._bleu.assign(0.0)
|
384
|
+
|
385
|
+
def get_config(self):
    """Return the parent config extended with this metric's own settings."""
    metric_settings = {
        "tokenizer": self.tokenizer,
        "max_order": self.max_order,
        "smooth": self.smooth,
    }
    config = super().get_config()
    config.update(metric_settings)
    return config
|
@@ -0,0 +1,197 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
|
17
|
+
from keras_hub.src.api_export import keras_hub_export
|
18
|
+
from keras_hub.src.utils.tensor_utils import is_float_dtype
|
19
|
+
|
20
|
+
try:
|
21
|
+
import tensorflow as tf
|
22
|
+
except ImportError:
|
23
|
+
tf = None
|
24
|
+
|
25
|
+
|
26
|
+
@keras_hub_export("keras_hub.metrics.EditDistance")
class EditDistance(keras.metrics.Metric):
    """Edit Distance metric.

    This class implements the edit distance metric, sometimes called
    Levenshtein Distance, as a `keras.metrics.Metric`. Essentially, edit
    distance is the least number of operations required to convert one string to
    another, where an operation can be one of substitution, deletion or
    insertion. By default, this metric will compute the normalized score, where
    the unnormalized edit distance score is divided by the number of tokens in
    the reference text.

    This class can be used to compute character error rate (CER) and word error
    rate (WER). You simply have to pass the appropriate tokenized text, and set
    `normalize` to True.

    Note on input shapes:
    `y_true` and `y_pred` can either be tensors of rank 1 or tensors of rank 2
    (ragged or dense). These tensors contain tokenized text. Every element of
    a dense rank 2 tensor is counted as a token.

    Args:
        normalize: bool. If True, the computed number of operations
            (substitutions + deletions + insertions) across all samples is
            divided by the aggregate number of tokens in all reference texts. If
            False, number of operations are calculated for every sample, and
            averaged over all the samples.
        dtype: string or tf.dtypes.DType. Precision of metric computation. If
            not specified, it defaults to `"float32"`.
        name: string. Name of the metric instance.
        **kwargs: Other keyword arguments.

    References:
        - [Morris et al.](https://www.researchgate.net/publication/221478089)

    Examples:

    Various Input Types.

    Single-level Python list.
    >>> edit_distance = keras_hub.metrics.EditDistance()
    >>> y_true = "the tiny little cat was found under the big funny bed".split()
    >>> y_pred = "the cat was found under the bed".split()
    >>> edit_distance(y_true, y_pred)
    <tf.Tensor: shape=(), dtype=float32, numpy=0.36363637>

    Nested Python list.
    >>> edit_distance = keras_hub.metrics.EditDistance()
    >>> y_true = [
    ...     "the tiny little cat was found under the big funny bed".split(),
    ...     "it is sunny today".split(),
    ... ]
    >>> y_pred = [
    ...     "the cat was found under the bed".split(),
    ...     "it is sunny but with a hint of cloud cover".split(),
    ... ]
    >>> edit_distance(y_true, y_pred)
    <tf.Tensor: shape=(), dtype=float32, numpy=0.73333335>
    """

    def __init__(
        self,
        normalize=True,
        dtype="float32",
        name="edit_distance",
        **kwargs,
    ):
        super().__init__(name=name, dtype=dtype, **kwargs)

        if not is_float_dtype(dtype):
            raise ValueError(
                "`dtype` must be a floating point type. "
                f"Received: dtype={dtype}"
            )

        self.normalize = normalize

        # Running sum of the raw (unnormalized) edit distances of all samples.
        self._aggregate_unnormalized_edit_distance = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="aggregate_unnormalized_edit_distance",
        )
        # Only the denominator that `result()` needs for the chosen mode is
        # created: total reference tokens when normalizing, sample count
        # otherwise.
        if normalize:
            self._aggregate_reference_length = self.add_weight(
                shape=(),
                initializer="zeros",
                dtype=self.dtype,
                name="aggregate_reference_length",
            )
        else:
            self._number_of_samples = self.add_weight(
                shape=(),
                initializer="zeros",
                dtype=self.dtype,
                name="number_of_samples",
            )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the edit distances of a batch of samples.

        Args:
            y_true: reference token sequences; rank 1 or rank 2.
            y_pred: hypothesis token sequences; rank 1 or rank 2.
            sample_weight: accepted for signature compatibility; it is not
                used by this metric.

        Raises:
            ValueError: if `y_true` or `y_pred` is not of rank 1 or 2.
        """

        def validate_and_fix_rank(inputs, tensor_name):
            if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
                inputs = tf.ragged.constant(inputs)

            if inputs.shape.rank == 1:
                # A single sample: wrap it into a batch of one row.
                return tf.RaggedTensor.from_tensor(inputs[tf.newaxis])
            elif inputs.shape.rank == 2:
                if isinstance(inputs, tf.Tensor):
                    # Dense batches have no `flat_values`, which the
                    # normalized path below relies on; make them ragged.
                    return tf.RaggedTensor.from_tensor(inputs)
                return inputs
            else:
                raise ValueError(
                    f"{tensor_name} must be of rank 1 or 2. "
                    f"Found rank: {inputs.shape.rank}"
                )

        y_true = validate_and_fix_rank(y_true, "y_true")
        y_pred = validate_and_fix_rank(y_pred, "y_pred")

        if self.normalize:
            # Count every reference token in the batch.
            self._aggregate_reference_length.assign_add(
                tf.cast(tf.size(y_true.flat_values), dtype=self.dtype)
            )

        def calculate_edit_distance(args):
            reference, hypothesis = args

            # `tf.edit_distance` operates on sparse tensors; wrap each
            # sequence in a list so it becomes a batch of one.
            reference = tf.sparse.from_dense([reference])
            hypothesis = tf.sparse.from_dense([hypothesis])

            edit_distance = tf.squeeze(
                tf.edit_distance(
                    hypothesis=hypothesis,
                    truth=reference,
                    normalize=False,
                )
            )

            self._aggregate_unnormalized_edit_distance.assign_add(
                tf.cast(edit_distance, dtype=self.dtype)
            )
            if not self.normalize:
                self._number_of_samples.assign_add(tf.cast(1, dtype=self.dtype))
            # The mapped function is used only for its side effects on the
            # metric variables; the returned value is a placeholder.
            return 0

        _ = tf.map_fn(
            fn=calculate_edit_distance,
            elems=(y_true, y_pred),
            fn_output_signature="int8",
        )

    def result(self):
        """Return the aggregated edit distance, or 0.0 before any update.

        With `normalize=True` the summed edit distance is divided by the
        total number of reference tokens, otherwise by the number of samples.
        """
        if self.normalize:
            # Guard against dividing by zero when no tokens were seen.
            if self._aggregate_reference_length == 0:
                return 0.0
            return (
                self._aggregate_unnormalized_edit_distance
                / self._aggregate_reference_length
            )
        # Guard against dividing by zero when no samples were seen.
        if self._number_of_samples == 0:
            return 0.0
        return (
            self._aggregate_unnormalized_edit_distance / self._number_of_samples
        )

    def reset_state(self):
        """Reset the accumulators that exist for the configured mode."""
        self._aggregate_unnormalized_edit_distance.assign(0.0)
        if self.normalize:
            self._aggregate_reference_length.assign(0.0)
        else:
            self._number_of_samples.assign(0.0)

    def get_config(self):
        """Return the parent config extended with `normalize`."""
        config = super().get_config()
        config.update({"normalize": self.normalize})
        return config
|