keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,181 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
from keras_hub.src.utils.tensor_utils import is_float_dtype
|
20
|
+
|
21
|
+
|
22
|
+
@keras_hub_export("keras_hub.metrics.Perplexity")
class Perplexity(keras.metrics.Metric):
    """Perplexity metric.

    This class implements the perplexity metric. In short, this class calculates
    the cross entropy loss and takes its exponent.
    Note: This implementation is not suitable for fixed-size windows.

    For every call to `update_state()`, the mean per-token cross entropy of the
    batch (over tokens with non-zero weight) is computed and accumulated,
    weighted by the number of sequences in the batch. `result()` returns the
    exponent of the running average.

    `sample_weight` (passed to `update_state()`) may be a scalar, a tensor of
    shape `(batch_size,)` (one weight per sequence) or a tensor of shape
    `(batch_size, seq_len)` (one weight per token); it is broadcast to the
    shape of `y_true`. A batch in which every token is masked or has zero
    weight contributes nothing to the metric.

    Args:
        from_logits: bool. If True, `y_pred` (input to `update_state()`) should
            be the logits as returned by the model. Otherwise, `y_pred` is a
            tensor of probabilities.
        mask_token_id: int. ID of the token to be masked. If provided, the mask
            is computed for this class. Note that if this field is provided, and
            if the `sample_weight` field in `update_state()` is also provided,
            we will compute the final `sample_weight` as the element-wise
            product of the mask and the `sample_weight`.
        dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
            not specified, it defaults to `"float32"`.
        name: string. Name of the metric instance.
        **kwargs: Other keyword arguments.

    Examples:

    1. Calculate perplexity by calling update_state() and result().
    1.1. `sample_weight`, and `mask_token_id` are not provided.
    >>> np.random.seed(42)
    >>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
    >>> target = np.random.randint(10, size=[2, 5])
    >>> logits = np.random.uniform(size=(2, 5, 10))
    >>> perplexity.update_state(target, logits)
    >>> perplexity.result()
    <tf.Tensor: shape=(), dtype=float32, numpy=14.352535>

    1.2. `sample_weight` specified (masking token with ID 0).
    >>> np.random.seed(42)
    >>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
    >>> target = np.random.randint(10, size=[2, 5])
    >>> logits = np.random.uniform(size=(2, 5, 10))
    >>> sample_weight = (target != 0).astype("float32")
    >>> perplexity.update_state(target, logits, sample_weight)
    >>> perplexity.result()
    <tf.Tensor: shape=(), dtype=float32, numpy=14.352535>

    2. Call perplexity directly.
    >>> np.random.seed(42)
    >>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
    >>> target = np.random.randint(10, size=[2, 5])
    >>> logits = np.random.uniform(size=(2, 5, 10))
    >>> perplexity(target, logits)
    <tf.Tensor: shape=(), dtype=float32, numpy=14.352535>

    3. Provide the padding token ID and let the class compute the mask on its
       own.
    >>> np.random.seed(42)
    >>> perplexity = keras_hub.metrics.Perplexity(mask_token_id=0)
    >>> target = np.random.randint(10, size=[2, 5])
    >>> logits = np.random.uniform(size=(2, 5, 10))
    >>> perplexity(target, logits)
    <tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
    """

    def __init__(
        self,
        from_logits=False,
        mask_token_id=None,
        dtype="float32",
        name="perplexity",
        **kwargs,
    ):
        if not is_float_dtype(dtype):
            raise ValueError(
                "`dtype` must be a floating point type. "
                f"Received: dtype={dtype}"
            )

        super().__init__(name=name, dtype=dtype, **kwargs)

        # "sum" reduction: the per-token losses are summed here and normalized
        # by the (weighted) token count in `update_state()`.
        self._crossentropy = keras.losses.SparseCategoricalCrossentropy(
            from_logits=from_logits, reduction="sum"
        )

        self.from_logits = from_logits
        self.mask_token_id = mask_token_id

        # Running sum of (batch_size * mean per-token cross entropy).
        self._aggregate_crossentropy = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="aggregate_crossentropy",
        )
        # Running count of sequences that contributed to the sum above.
        self._number_of_samples = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="number_of_samples",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the cross entropy of one batch.

        Args:
            y_true: Integer token IDs, shape `(batch_size, seq_len)`.
            y_pred: Logits or probabilities, shape
                `(batch_size, seq_len, vocab_size)`.
            sample_weight: Optional weights; a scalar, a `(batch_size,)`
                tensor or a `(batch_size, seq_len)` tensor.
        """
        # Keep labels integral. Casting token IDs to a low-precision float
        # dtype (e.g. "float16" / "bfloat16") would round large IDs to a
        # different token before the loss and the mask comparison see them.
        y_true = ops.cast(y_true, "int32")
        y_pred = ops.cast(y_pred, self.dtype)

        batch_size = ops.cast(ops.shape(y_true)[0], self.dtype)

        if sample_weight is not None:
            sample_weight = ops.cast(sample_weight, self.dtype)
            # A per-sequence weight must line up with the batch axis of
            # `y_true`, not with its trailing (sequence) axis.
            if len(ops.shape(sample_weight)) == 1 and len(ops.shape(y_true)) == 2:
                sample_weight = ops.expand_dims(sample_weight, axis=-1)
            # Materialize one weight per token so that the normalizer below
            # counts (weighted) tokens, not scalars or sequences.
            sample_weight = ops.multiply(
                sample_weight, ops.ones_like(y_true, dtype=self.dtype)
            )

        if self.mask_token_id is not None:
            mask = ops.cast(
                ops.not_equal(y_true, self.mask_token_id),
                self.dtype,
            )
            if sample_weight is None:
                sample_weight = mask
            else:
                sample_weight = ops.multiply(mask, sample_weight)

        # Weighted sum of the per-token cross entropy (scalar).
        crossentropy_sum = ops.cast(
            self._crossentropy(y_true, y_pred, sample_weight=sample_weight),
            self.dtype,
        )

        # Number of (weighted) tokens that contributed to the sum (scalar).
        if sample_weight is not None:
            num_tokens = ops.sum(sample_weight)
        else:
            num_tokens = ops.cast(ops.shape(y_true)[0], self.dtype) * ops.cast(
                ops.shape(y_true)[1], self.dtype
            )

        # A fully masked / zero-weighted batch would yield 0 / 0 = NaN and
        # permanently poison the accumulators. Divide by 1 instead and give
        # such a batch zero weight in the aggregate.
        has_tokens = ops.greater(num_tokens, 0)
        safe_num_tokens = ops.where(
            has_tokens, num_tokens, ops.ones_like(num_tokens)
        )
        crossentropy_value = crossentropy_sum / safe_num_tokens
        effective_batch_size = batch_size * ops.cast(has_tokens, self.dtype)

        self._aggregate_crossentropy.assign_add(
            effective_batch_size * crossentropy_value
        )
        self._number_of_samples.assign_add(effective_batch_size)

    def result(self):
        """Returns `exp` of the mean cross entropy, or 0 before any update."""
        perplexity_score = ops.where(
            ops.equal(ops.convert_to_tensor(self._number_of_samples), 0),
            0,
            ops.exp(self._aggregate_crossentropy / self._number_of_samples),
        )
        return perplexity_score

    def reset_state(self):
        """Clears the accumulated cross entropy and sample count."""
        self._aggregate_crossentropy.assign(0.0)
        self._number_of_samples.assign(0.0)

    def get_config(self):
        """Returns the constructor arguments needed to rebuild this metric."""
        config = super().get_config()
        config.update(
            {
                "from_logits": self.from_logits,
                "mask_token_id": self.mask_token_id,
            }
        )
        return config
|
@@ -0,0 +1,204 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.utils.tensor_utils import is_float_dtype
|
19
|
+
from keras_hub.src.utils.tensor_utils import tensor_to_list
|
20
|
+
|
21
|
+
try:
|
22
|
+
import tensorflow as tf
|
23
|
+
except ImportError:
|
24
|
+
tf = None
|
25
|
+
|
26
|
+
try:
|
27
|
+
from rouge_score import rouge_scorer
|
28
|
+
except ImportError:
|
29
|
+
rouge_scorer = None
|
30
|
+
|
31
|
+
|
32
|
+
class RougeBase(keras.metrics.Metric):
    """ROUGE metric.

    This class implements two variants of the ROUGE metric - ROUGE-N,
    and ROUGE-L.

    Note on input shapes:
    For `y_true` and `y_pred`, this class supports scalar values and batch
    inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`.

    Scoring is delegated to the `rouge_score` package and runs in Python, one
    (reference, hypothesis) pair at a time, so `update_state()` must be called
    eagerly. `sample_weight` is accepted for API compatibility but is ignored.

    Args:
        variant: string. One of "rougeN", "rougeL". For "rougeN", N lies in
            the range [1, 9]. Defaults to `"rouge2"`.
        use_stemmer: bool. Whether Porter Stemmer should be used to strip word
            suffixes to improve matching. Defaults to `False`.
        dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
            not specified, it defaults to `"float32"`.
        name: string. Name of the metric instance.
        **kwargs: Other keyword arguments.

    References:
        - [Lin et al., 2004](https://aclanthology.org/W04-1013/)
    """

    def __init__(
        self,
        variant="rouge2",
        use_stemmer=False,
        dtype="float32",
        name="rouge",
        **kwargs,
    ):
        super().__init__(name=name, dtype=dtype, **kwargs)

        if rouge_scorer is None:
            raise ImportError(
                f"{self.__class__.__name__} requires the `rouge_score` "
                "package. Please install it with `pip install rouge-score`."
            )
        # `update_state()` converts and slices its inputs with TensorFlow;
        # fail fast here rather than with an AttributeError on `None` later.
        if tf is None:
            raise ImportError(
                f"{self.__class__.__name__} requires the `tensorflow` "
                "package. Please install it with `pip install tensorflow`."
            )

        if not is_float_dtype(dtype):
            raise ValueError(
                "`dtype` must be a floating point type. "
                f"Received: dtype={dtype}"
            )

        valid_variants = tuple(f"rouge{order}" for order in range(1, 10)) + (
            "rougeL",
        )
        if variant not in valid_variants:
            raise ValueError(
                "Invalid variant of ROUGE. Should be one of: rougeN, rougeL, "
                "with N ranging from 1 to 9. Received: "
                f"variant={variant}"
            )

        self.variant = variant
        self.use_stemmer = use_stemmer

        # To-do: Add split_summaries and tokenizer options after the maintainers
        # of rouge_scorer have released a new version.
        self._rouge_scorer = rouge_scorer.RougeScorer(
            rouge_types=[self.variant],
            use_stemmer=use_stemmer,
        )

        # Running sums of per-sample scores; divided by the sample count in
        # `result()`.
        self._rouge_precision = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="rouge_precision",
        )
        self._rouge_recall = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="rouge_recall",
        )
        self._rouge_f1_score = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="rouge_f1_score",
        )

        self._number_of_samples = self.add_weight(
            shape=(),
            initializer="zeros",
            dtype=self.dtype,
            name="number_of_samples",
        )

    @staticmethod
    def _validate_and_fix_rank(inputs, tensor_name):
        """Converts `inputs` to a rank-1 string tensor of shape `[batch_size]`.

        Accepts rank 0 (a single string), rank 1 and rank 2 with a trailing
        dimension of 1; raises `ValueError` for anything else.
        """
        if not isinstance(inputs, tf.Tensor):
            inputs = tf.convert_to_tensor(inputs)

        if inputs.shape.rank == 0:
            return inputs[tf.newaxis]
        elif inputs.shape.rank == 1:
            return inputs
        elif inputs.shape.rank == 2:
            if inputs.shape[1] != 1:
                raise ValueError(
                    f"{tensor_name} must be of shape `[batch_size, 1]`. "
                    f"Found shape: {inputs.shape}"
                )
            else:
                return tf.squeeze(inputs, axis=1)
        else:
            raise ValueError(
                f"{tensor_name} must be of rank 0 (scalar input), 1 or 2. "
                f"Found rank: {inputs.shape.rank}"
            )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Scores each (reference, hypothesis) pair and accumulates totals.

        Args:
            y_true: Reference text(s): a Python string, or a batch of strings
                of shape `(batch_size,)` or `(batch_size, 1)`.
            y_pred: Hypothesis text(s), same accepted shapes as `y_true`.
            sample_weight: Ignored.

        Raises:
            ValueError: If an input has an unsupported shape, or if `y_true`
                and `y_pred` have different batch sizes.
        """
        y_true = self._validate_and_fix_rank(y_true, "y_true")
        y_pred = self._validate_and_fix_rank(y_pred, "y_pred")

        # Check before touching any state so a bad call cannot leave the
        # accumulators partially updated (or silently drop predictions).
        if y_true.shape[0] != y_pred.shape[0]:
            raise ValueError(
                "`y_true` and `y_pred` must have the same batch size. "
                f"Received: y_true.shape={y_true.shape}, "
                f"y_pred.shape={y_pred.shape}"
            )

        batch_size = tf.shape(y_true)[0]

        # Sum scores in Python and update each variable once per batch rather
        # than three variable updates per sample.
        precision_total = 0.0
        recall_total = 0.0
        f1_total = 0.0
        for batch_idx in range(batch_size):
            reference = tensor_to_list(y_true[batch_idx])
            hypothesis = tensor_to_list(y_pred[batch_idx])
            score = self._rouge_scorer.score(reference, hypothesis)[
                self.variant
            ]
            precision_total += score.precision
            recall_total += score.recall
            f1_total += score.fmeasure

        self._rouge_precision.assign_add(precision_total)
        self._rouge_recall.assign_add(recall_total)
        self._rouge_f1_score.assign_add(f1_total)
        self._number_of_samples.assign_add(
            ops.cast(batch_size, dtype=self.dtype)
        )

    def result(self):
        """Returns the mean precision, recall and F1 score over all samples.

        Returns a dict with keys `"precision"`, `"recall"` and `"f1_score"`;
        every value is 0.0 if no samples have been seen.
        """
        if self._number_of_samples == 0:
            return {
                "precision": 0.0,
                "recall": 0.0,
                "f1_score": 0.0,
            }

        rouge_precision = self._rouge_precision / self._number_of_samples
        rouge_recall = self._rouge_recall / self._number_of_samples
        rouge_f1_score = self._rouge_f1_score / self._number_of_samples
        return {
            "precision": rouge_precision,
            "recall": rouge_recall,
            "f1_score": rouge_f1_score,
        }

    def reset_state(self):
        """Clears all accumulated scores and the sample count."""
        self._rouge_precision.assign(0.0)
        self._rouge_recall.assign(0.0)
        self._rouge_f1_score.assign(0.0)
        self._number_of_samples.assign(0.0)

    def get_config(self):
        """Returns the constructor arguments needed to rebuild this metric."""
        config = super().get_config()
        config.update(
            {
                "variant": self.variant,
                "use_stemmer": self.use_stemmer,
            }
        )
        return config
|
@@ -0,0 +1,97 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from keras_hub.src.api_export import keras_hub_export
|
16
|
+
from keras_hub.src.metrics.rouge_base import RougeBase
|
17
|
+
|
18
|
+
|
19
|
+
@keras_hub_export("keras_hub.metrics.RougeL")
class RougeL(RougeBase):
    """ROUGE-L metric.

    Implements the ROUGE-L variant of the ROUGE family of metrics, which is
    traditionally used to evaluate summarisation systems. ROUGE-L scores a
    hypothesis against a reference using the length of the longest common
    subsequence the two texts share.

    Note on input shapes:
    `y_true` and `y_pred` may each be a scalar string or a batch of strings
    with shape `()`, `(batch_size,)` or `(batch_size, 1)`.

    Args:
        use_stemmer: bool. Whether to apply the Porter Stemmer to strip word
            suffixes before matching. Defaults to `False`.
        dtype: string or tf.dtypes.Dtype. Precision of metric computation.
            If not specified, it defaults to `"float32"`.
        name: string. Name of the metric instance. Defaults to `"rouge-l"`.
        **kwargs: Other keyword arguments, forwarded to the base class.

    References:
        - [Lin et al., 2004](https://aclanthology.org/W04-1013/)

    Examples:

    1. Python string.
    >>> rouge_l = keras_hub.metrics.RougeL()
    >>> y_true = "the tiny little cat was found under the big funny bed"
    >>> y_pred = "the cat was under the bed"
    >>> rouge_l(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.7058824>

    2. List inputs.
    >>> rouge_l = keras_hub.metrics.RougeL()
    >>> y_true = [
    ...     "the tiny little cat was found under the big funny bed",
    ...     "i really love contributing to KerasHub",
    ... ]
    >>> y_pred = [
    ...     "the cat was under the bed",
    ...     "i love contributing to KerasHub",
    ... ]
    >>> rouge_l(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.80748665>

    3. 2D inputs.
    >>> rouge_l = keras_hub.metrics.RougeL()
    >>> y_true = [
    ...     ["the tiny little cat was found under the big funny bed"],
    ...     ["i really love contributing to KerasHub"],
    ... ]
    >>> y_pred = [
    ...     ["the cat was under the bed"],
    ...     ["i love contributing to KerasHub"],
    ... ]
    >>> rouge_l(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.80748665>
    """

    def __init__(self, use_stemmer=False, name="rouge-l", **kwargs):
        # The ROUGE variant is fixed by this subclass.
        super().__init__(
            variant="rougeL",
            use_stemmer=use_stemmer,
            name=name,
            **kwargs,
        )

    def get_config(self):
        config = super().get_config()
        # `variant` is not an argument of this class's `__init__`, so it
        # must be dropped for the config to round-trip through the
        # constructor.
        config.pop("variant")
        return config
|
@@ -0,0 +1,125 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from keras_hub.src.api_export import keras_hub_export
|
16
|
+
from keras_hub.src.metrics.rouge_base import RougeBase
|
17
|
+
|
18
|
+
|
19
|
+
@keras_hub_export("keras_hub.metrics.RougeN")
class RougeN(RougeBase):
    """ROUGE-N metric.

    This class implements the ROUGE-N variant of the ROUGE metric. The ROUGE-N
    metric is traditionally used for evaluating summarisation systems.
    Succinctly put, ROUGE-N is a score based on the number of matching n-grams
    between the reference text and the hypothesis text.

    Note on input shapes:
    For `y_true` and `y_pred`, this class supports scalar values and batch
    inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`.

    Args:
        order: int. The order of n-grams which are to be matched. It should
            lie in range [1, 9]. Defaults to `2`.
        use_stemmer: bool. Whether Porter Stemmer should be used to strip word
            suffixes to improve matching. Defaults to `False`.
        dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
            not specified, it defaults to `"float32"`.
        name: string. Name of the metric instance.
        **kwargs: Other keyword arguments, forwarded to the base class.

    Raises:
        ValueError: If `order` is not in the range [1, 9].

    References:
        - [Lin et al., 2004](https://aclanthology.org/W04-1013/)

    Examples:

    1. Python string.
    >>> rouge_n = keras_hub.metrics.RougeN(order=2)
    >>> y_true = "the tiny little cat was found under the big funny bed"
    >>> y_pred = "the cat was under the bed"
    >>> rouge_n(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.26666668>

    2. List inputs.
    >>> rouge_n = keras_hub.metrics.RougeN(order=2)
    >>> y_true = [
    ...     "the tiny little cat was found under the big funny bed",
    ...     "i really love contributing to KerasHub",
    ... ]
    >>> y_pred = [
    ...     "the cat was under the bed",
    ...     "i love contributing to KerasHub",
    ... ]
    >>> rouge_n(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.4666667>

    3. 2D inputs.
    >>> rouge_n = keras_hub.metrics.RougeN(order=2)
    >>> y_true = [
    ...     ["the tiny little cat was found under the big funny bed"],
    ...     ["i really love contributing to KerasHub"],
    ... ]
    >>> y_pred = [
    ...     ["the cat was under the bed"],
    ...     ["i love contributing to KerasHub"],
    ... ]
    >>> rouge_n(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.4666667>

    4. Trigrams.
    >>> rouge_n = keras_hub.metrics.RougeN(order=3)
    >>> y_true = [
    ...     "the tiny little cat was found under the big funny bed",
    ...     "i really love contributing to KerasHub",
    ... ]
    >>> y_pred = [
    ...     "the cat was under the bed",
    ...     "i love contributing to KerasHub",
    ... ]
    >>> rouge_n(y_true, y_pred)["f1_score"]
    <tf.Tensor: shape=(), dtype=float32, numpy=0.2857143>
    """

    def __init__(
        self,
        order=2,
        use_stemmer=False,
        name="rouge-n",
        **kwargs,
    ):
        # Validate before the base constructor runs, so an invalid order
        # never produces a (partially) constructed metric.
        if order not in range(1, 10):
            raise ValueError(
                "Invalid `order` value. Should lie in the range [1, 9]. "
                f"Received order={order}"
            )

        # The ROUGE variant name is derived from the order, e.g. "rouge2".
        super().__init__(
            variant=f"rouge{order}",
            use_stemmer=use_stemmer,
            name=name,
            **kwargs,
        )

        self.order = order

    def get_config(self):
        config = super().get_config()
        # `variant` is derived from `order` in `__init__` and is not a
        # constructor argument of this class, so replace it with `order`.
        del config["variant"]

        config.update(
            {
                "order": self.order,
            }
        )
        return config
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
@@ -0,0 +1,20 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from keras_hub.src.models.albert.albert_backbone import AlbertBackbone
|
16
|
+
from keras_hub.src.models.albert.albert_presets import backbone_presets
|
17
|
+
from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer
|
18
|
+
from keras_hub.src.utils.preset_utils import register_presets
|
19
|
+
|
20
|
+
# Runs at import time: registers the ALBERT `backbone_presets` with both
# AlbertBackbone and AlbertTokenizer. Presumably this is what lets those
# classes resolve preset names (e.g. via `from_preset`); confirm in
# keras_hub/src/utils/preset_utils.py.
register_presets(backbone_presets, (AlbertBackbone, AlbertTokenizer))
|