keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,237 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
from keras import random
|
18
|
+
|
19
|
+
from keras_hub.src.api_export import keras_hub_export
|
20
|
+
from keras_hub.src.utils.tensor_utils import any_equal
|
21
|
+
|
22
|
+
|
23
|
+
@keras_hub_export("keras_hub.samplers.Sampler")
class Sampler:
    """Base sampler class.

    Args:
        temperature: float. optional. Used to control the
            randomness of the sampling. The higher the temperature, the
            more diverse the samples. Defaults to `1.0`.

    Call arguments:
        next: A callable invoked as `next(prompt, cache, index)` that
            returns a 3-tuple. The first element is the logits for the
            token at `index` and the last element is the updated `cache`;
            the middle element (presumably hidden states) is ignored by
            this base class.
        prompt: A 2D tensor of token ids indexed as `(batch, position)`.
            It is filled in one column at a time starting at `index`, and
            the updated tensor is returned.
        cache: Optional. Passed to `next` on every step and replaced by the
            cache `next` returns. `None` is treated as an empty tuple.
        index: Optional. The first position of `prompt` to sample into.
            Defaults to `0`. At most `prompt.shape[-1] - index` steps run.
        mask: Optional. A tensor with the same shape as `prompt`, cast to
            bool. Positions where it is `True` are never overwritten and
            are excluded from the stop-token check.
        stop_token_ids: Optional. Token ids that end generation: the loop
            stops once every sequence contains one of them at a position
            where `mask` is `False` (as computed by `any_equal`). If
            `None`, the loop runs until the end of `prompt`.
        hidden_states: Optional. Not used by this base implementation.
        model: Optional. On the JAX backend, the model's trainable and
            non-trainable variables are threaded through the loop state.

    This base class can be extended to implement different auto-regressive
    sampling methods. To do so, override the `get_next_token()` method, which
    computes the next token based on a probability distribution over all
    possible vocab entries.

    Example:

    ```python
    causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Greedy search with some tokens forbidden.
    class CustomSampler(keras_hub.samplers.Sampler):
        def __init__(self, forbidden_tokens, **kwargs):
            super().__init__(**kwargs)
            self.forbidden_tokens = forbidden_tokens

        def get_next_token(self, probs):
            batch_size, vocab_size = keras.ops.shape(probs)
            for token_id in self.forbidden_tokens:
                update = keras.ops.zeros((batch_size, 1))
                probs = keras.ops.slice_update(probs, (0, token_id), update)
            return keras.ops.argmax(probs, axis=-1)

    # 257 = "a" with a leading space, 262 = "the" with a leading space.
    causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262]))
    causal_lm.summary()
    causal_lm.generate(["That's strange"])
    ```
    """

    def __init__(
        self,
        temperature=1.0,
    ):
        self.temperature = temperature
        # `SeedGenerator` attributes are appended here by `__setattr__`.
        # NOTE: subclasses must call `super().__init__()` before assigning a
        # `SeedGenerator` attribute, otherwise `__setattr__` fails because
        # this list does not exist yet (and it would be reset here anyway).
        self._seed_generators = []

    def __setattr__(self, name, value):
        """Set an attribute, tracking any `SeedGenerator` assigned."""
        # We could update to the `Tracker` class from keras-core if our needs
        # become more advanced (e.g. list assignment, nested trackables). For
        # now, we only track `SeedGenerator` instances directly on the sampler.
        # Note that reassigning an attribute appends again; the previously
        # assigned generator stays tracked.
        if isinstance(value, random.SeedGenerator):
            self._seed_generators.append(value)
        return super().__setattr__(name, value)

    @property
    def variables(self):
        """The `state` of every tracked `SeedGenerator`, in tracking order."""
        variables = []
        for sg in self._seed_generators:
            variables.append(sg.state)
        return variables

    def __call__(
        self,
        next,
        prompt,
        cache=None,
        index=0,
        mask=None,
        stop_token_ids=None,
        hidden_states=None,
        model=None,
    ):
        """Run the sampling loop and return the updated `prompt`."""
        max_length = ops.shape(prompt)[-1]
        # Make sure `max_length` and `index` are the same dtype.
        index = ops.cast(index, "int32")
        max_length = ops.cast(max_length, "int32")
        if mask is None:
            mask = ops.zeros_like(prompt, dtype="bool")
        else:
            mask = ops.cast(mask, dtype="bool")
        # `ops.while_loop` will not accept `None` as a value for `loop_vars`.
        cache = () if cache is None else cache

        def cond(prompt, cache, index):
            # Loop-continuation predicate; always continue without stop ids.
            if stop_token_ids is None:
                return True
            # Stop if all sequences have produced a *new* id from stop_token_ids.
            end_tokens = any_equal(prompt, stop_token_ids, ~mask)
            prompt_done = ops.any(end_tokens, axis=-1)
            return ops.logical_not(ops.all(prompt_done))

        def body(prompt, cache, index):
            # One sampling step: fill column `index` of `prompt`.
            # Compute the softmax distribution for the next token.
            logits, _, cache = next(prompt, cache, index)
            probabilities = self.compute_probabilities(logits)
            # Compute the next token.
            next_token = self.get_next_token(probabilities)
            # Don't overwrite anywhere mask is True.
            next_token = ops.cast(next_token, prompt.dtype)
            next_token = ops.where(mask[:, index], prompt[:, index], next_token)
            # Update the prompt with the next token.
            next_token = next_token[:, None]
            prompt = ops.slice_update(prompt, [0, index], next_token)

            # Return the next prompt, cache and incremented index.
            return (prompt, cache, index + 1)

        prompt, _, _ = self.run_loop(
            cond,
            body,
            loop_vars=(prompt, cache, index),
            maximum_iterations=(max_length - index),
            model=model,
        )
        return prompt

    def compute_probabilities(self, logits):
        """Compute token probabilities from logits.

        `logits` are cast to float32 regardless of their input dtype, divided
        by `self.temperature`, and passed through a softmax.
        """
        logits = ops.cast(logits, "float32")
        return keras.activations.softmax(logits / self.temperature)

    def run_loop(
        self, cond, body, model=None, loop_vars=None, maximum_iterations=None
    ):
        """Run ops.while_loops with a `StatelessScope` if necessary.

        Args:
            cond: Callable taking the loop variables and returning whether
                to continue.
            body: Callable taking the loop variables and returning the
                updated loop variables.
            model: Optional model whose variables are threaded through the
                loop state on the JAX backend.
            loop_vars: Tuple of initial loop variables.
            maximum_iterations: Upper bound on the number of iterations.

        Returns:
            The final loop variables.
        """
        if keras.config.backend() == "jax":
            # On JAX the variable values are carried explicitly in the loop
            # state (presumably because the traced loop body must be
            # stateless) and exposed to `body` via a `StatelessScope`.
            import itertools

            if model:
                model_trainable_variables = model.trainable_variables
                model_non_trainable_variables = model.non_trainable_variables
            else:
                model_trainable_variables = []
                model_non_trainable_variables = []

            def stateless_cond(state, *loop_vars):
                # `state` is only needed by the body; drop it here.
                return cond(*loop_vars)

            def stateless_body(state, *loop_vars):
                (
                    sampler_variables,
                    trainable_variables,
                    non_trainable_variables,
                ) = state
                # Map each real variable to its current value in the loop.
                mapping = itertools.chain(
                    zip(self.variables, sampler_variables),
                    zip(model_trainable_variables, trainable_variables),
                    zip(model_non_trainable_variables, non_trainable_variables),
                )
                with keras.StatelessScope(state_mapping=mapping) as scope:
                    loop_vars = body(*loop_vars)

                # Read back sampler variable updates (e.g. advanced seed
                # state); keep the old value when `body` did not assign one.
                sampler_variables = []
                for v in self.variables:
                    new_v = scope.get_current_value(v)
                    sampler_variables.append(new_v if new_v is not None else v)
                # Model variables are passed through unchanged.
                state = (
                    sampler_variables,
                    trainable_variables,
                    non_trainable_variables,
                )
                return state, *loop_vars

            variables = [ops.convert_to_tensor(v) for v in self.variables]
            trainable_variables = [
                ops.convert_to_tensor(v) for v in model_trainable_variables
            ]
            non_trainable_variables = [
                ops.convert_to_tensor(v) for v in model_non_trainable_variables
            ]
            state = (
                variables,
                trainable_variables,
                non_trainable_variables,
            )
            state, *loop_vars = ops.while_loop(
                cond=stateless_cond,
                body=stateless_body,
                loop_vars=(state, *loop_vars),
                maximum_iterations=maximum_iterations,
            )
            # Write the final sampler variable values back to the real
            # variables so state persists across calls.
            for ref_v, v in zip(self.variables, state[0]):
                ref_v.assign(v)
        else:
            loop_vars = ops.while_loop(
                cond=cond,
                body=body,
                loop_vars=(loop_vars),
                maximum_iterations=maximum_iterations,
            )
        return loop_vars

    def get_next_token(self, probabilities):
        """Get the next token based on given probability distribution over tokens.

        Subclasses must implement this method.

        Args:
            probabilities: a Tensor, the probability distribution for next
                token over all vocab tokens.

        Returns:
            The selected token ids. `__call__` casts them to the prompt dtype
            and expects one id per row of `prompt`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def from_config(cls, config):
        """Create a sampler from a dict of constructor arguments."""
        return cls(**config)

    def get_config(self):
        """Return the constructor arguments needed to recreate this sampler."""
        return {"temperature": self.temperature}
|
@@ -0,0 +1,97 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
|
17
|
+
from keras_hub.src.api_export import keras_hub_export
|
18
|
+
from keras_hub.src.samplers.beam_sampler import BeamSampler
|
19
|
+
from keras_hub.src.samplers.contrastive_sampler import ContrastiveSampler
|
20
|
+
from keras_hub.src.samplers.greedy_sampler import GreedySampler
|
21
|
+
from keras_hub.src.samplers.random_sampler import RandomSampler
|
22
|
+
from keras_hub.src.samplers.top_k_sampler import TopKSampler
|
23
|
+
from keras_hub.src.samplers.top_p_sampler import TopPSampler
|
24
|
+
|
25
|
+
|
26
|
+
@keras_hub_export("keras_hub.samplers.serialize")
def serialize(sampler):
    """Serialize a sampler with Keras's object serialization.

    Args:
        sampler: The sampler to serialize.

    Returns:
        The output of `keras.saving.serialize_keras_object(sampler)`.
    """
    serialized = keras.saving.serialize_keras_object(sampler)
    return serialized
|
29
|
+
|
30
|
+
|
31
|
+
@keras_hub_export("keras_hub.samplers.deserialize")
def deserialize(config, custom_objects=None):
    """Return a `Sampler` object from its config.

    Args:
        config: A serialized sampler config, or the short lowercase name of
            a built-in sampler (e.g. `"greedy"`, `"top_k"`).
        custom_objects: Optional mapping of names to custom classes,
            forwarded to `keras.saving.deserialize_keras_object`.

    Returns:
        The deserialized sampler.
    """
    # Short aliases resolving to the built-in sampler classes.
    builtin_samplers = {
        "beam": BeamSampler,
        "contrastive": ContrastiveSampler,
        "greedy": GreedySampler,
        "random": RandomSampler,
        "top_k": TopKSampler,
        "top_p": TopPSampler,
    }
    sampler = keras.saving.deserialize_keras_object(
        config,
        module_objects=builtin_samplers,
        custom_objects=custom_objects,
        printable_module_name="samplers",
    )
    return sampler
|
48
|
+
|
49
|
+
|
50
|
+
@keras_hub_export("keras_hub.samplers.get")
def get(identifier):
    """Retrieve a KerasHub sampler by the identifier.

    The `identifier` may be the lowercase string name of a built-in sampler,
    a sampler class, or an existing sampler instance.

    >>> identifier = 'greedy'
    >>> sampler = keras_hub.samplers.get(identifier)

    You can also specify `config` of the sampler to this function by passing
    a dict containing `class_name` and `config` as an identifier. Also note
    that the `class_name` must map to a `Sampler` class.

    >>> cfg = {'class_name': 'keras_hub>GreedySampler', 'config': {}}
    >>> sampler = keras_hub.samplers.get(cfg)

    In the case that the `identifier` is a class, this method will return a new
    instance of the class by calling its constructor with no arguments. Any
    other callable (such as an existing sampler instance) is returned
    unchanged.

    Args:
        identifier: `None`, a lowercase string name, a config dict, a sampler
            class, or a callable such as a sampler instance.

    Returns:
        Sampler instance based on the input identifier, or `None` if
        `identifier` is `None`.

    Raises:
        KeyError: If `identifier` is a string that is not lowercase.
        ValueError: If the input identifier is not a supported type or in a bad
            format.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    if isinstance(identifier, str):
        if not identifier.islower():
            raise KeyError(
                "`keras_hub.samplers.get()` must take a lowercase string "
                f"identifier, but received: {identifier}."
            )
        return deserialize(identifier)
    if isinstance(identifier, type):
        # Classes are callable too; checking first honors the documented
        # contract of returning a fresh instance rather than the class.
        return identifier()
    if callable(identifier):
        # Sampler instances define `__call__` and pass through unchanged.
        return identifier
    raise ValueError(
        "Could not interpret sampler identifier: " + str(identifier)
    )
|
@@ -0,0 +1,92 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from keras import ops
|
16
|
+
from keras import random
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
from keras_hub.src.samplers.sampler import Sampler
|
20
|
+
|
21
|
+
|
22
|
+
@keras_hub_export("keras_hub.samplers.TopKSampler")
class TopKSampler(Sampler):
    """Top-K Sampler class.

    This sampler implements top-k sampling. At each step it keeps the `k`
    most probable tokens and randomly selects one of them, with selection
    chance proportional to its probability.

    Args:
        k: int, the `k` value of top-k. Defaults to `5`.
        seed: int. The random seed. Defaults to `None`.
        **kwargs: Forwarded to `keras_hub.samplers.Sampler`, e.g.
            `temperature`.

    Call arguments:
        Same as `keras_hub.samplers.Sampler`: `next`, `prompt`, `cache`,
        `index`, `mask`, `stop_token_ids`, `hidden_states` and `model`.

    Examples:
    ```python
    causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Pass by name to compile.
    causal_lm.compile(sampler="top_k")
    causal_lm.generate(["Keras is a"])

    # Pass by object to compile.
    sampler = keras_hub.samplers.TopKSampler(k=5, temperature=0.7)
    causal_lm.compile(sampler=sampler)
    causal_lm.generate(["Keras is a"])
    ```
    """

    def __init__(
        self,
        k=5,
        seed=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.k = k
        self.seed = seed
        # Assigned after `super().__init__()` so the base class can track it
        # and expose its state through `variables`.
        self.seed_generator = random.SeedGenerator(seed)

    def get_next_token(self, probabilities):
        """Sample one token id per row from the `k` most likely tokens.

        Args:
            probabilities: Tensor of next-token probabilities whose last
                axis is the vocabulary (assumed `(batch, vocab)`).

        Returns:
            Tensor with one sampled vocabulary id per row.
        """
        # Filter out top-k tokens; order is irrelevant since we sample.
        top_k_pred, top_k_indices = ops.top_k(
            probabilities,
            k=self.k,
            sorted=False,
        )
        # tf does not support half precision multinomial sampling, so make
        # sure we have full precision here. Cast *before* the log so the
        # log-probabilities themselves are computed in float32.
        top_k_logits = ops.log(ops.cast(top_k_pred, "float32"))
        # Sample an index into the top-k set for each row.
        sample_indices = random.categorical(
            top_k_logits,
            1,
            seed=self.seed_generator,
            dtype="int32",
        )

        # Rearrange to get the next token idx from the original order.
        output = ops.take_along_axis(top_k_indices, sample_indices, axis=-1)
        return ops.squeeze(output, axis=-1)

    def get_config(self):
        """Return the constructor arguments needed to recreate this sampler."""
        config = super().get_config()
        config.update(
            {
                "k": self.k,
                "seed": self.seed,
            }
        )
        return config
|
@@ -0,0 +1,113 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from keras import ops
|
16
|
+
from keras import random
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
from keras_hub.src.samplers.sampler import Sampler
|
20
|
+
|
21
|
+
|
22
|
+
@keras_hub_export("keras_hub.samplers.TopPSampler")
class TopPSampler(Sampler):
    """Top-P Sampler class.

    This sampler implements top-p search algorithm. Top-p search selects tokens
    from the smallest subset of output probabilities that sum to greater than
    `p`. Put in another way, top-p will first order token predictions by
    likelihood, and ignore all tokens after the cumulative probability of
    selected tokens exceeds `p`, then select a token from the remaining tokens.

    Args:
        p: float. The `p` value of top-p. Defaults to `0.1`.
        k: int. If set, this argument defines a
            heuristic "top-k" cutoff applied before the "top-p" sampling. All
            logits not in the top `k` will be discarded, and the remaining
            logits will be sorted to find a cutoff point for `p`. Setting this
            arg can significantly speed sampling up by reducing the number
            of tokens to sort. Defaults to `None`.
        seed: int. The random seed. Defaults to `None`.

    Call arguments:
        {{call_args}}

    Examples:
    ```python
    causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Pass by name to compile.
    causal_lm.compile(sampler="top_p")
    causal_lm.generate(["Keras is a"])

    # Pass by object to compile.
    sampler = keras_hub.samplers.TopPSampler(p=0.1, k=1_000)
    causal_lm.compile(sampler=sampler)
    causal_lm.generate(["Keras is a"])
    ```
    """

    def __init__(
        self,
        p=0.1,
        k=None,
        seed=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.p = p
        self.k = k
        # Stored separately so `get_config` can round-trip the raw seed.
        self.seed = seed
        self.seed_generator = random.SeedGenerator(seed)

    def get_next_token(self, probabilities):
        """Sample the next token id from the top-p nucleus.

        Args:
            probabilities: tensor of token probabilities. The `[:, ...]`
                slicing below requires rank 2 (presumably
                `(batch_size, vocab_size)` — confirm against `Sampler`).

        Returns:
            A tensor of token indices into the last axis of
            `probabilities`, with the trailing sample axis squeezed away.
        """
        # Without `k`, every token is a candidate; with `k`, only the top
        # `k` are sorted, which is cheaper for large vocabularies.
        cutoff = ops.shape(probabilities)[1]
        if self.k is not None:
            cutoff = self.k
        sorted_preds, sorted_indices = ops.top_k(
            probabilities, k=cutoff, sorted=True
        )
        # Calculate cumulative probability distribution.
        cumulative_probabilities = ops.cumsum(sorted_preds, axis=-1)
        # Create a mask for the tokens to keep.
        keep_mask = cumulative_probabilities <= self.p
        # Shift right by one so token `i` is kept iff the mass of the tokens
        # *before* it is still <= p. This includes the first token that
        # crosses `p` and guarantees the most likely token is always kept.
        shifted_keep_mask = ops.concatenate(
            [ops.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
        )
        # Zero out tokens outside the nucleus; categorical sampling does not
        # require the remaining mass to be renormalized.
        nucleus_probabilities = ops.where(
            shifted_keep_mask,
            sorted_preds,
            ops.zeros_like(sorted_preds),
        )
        # tf does not support half precision multinomial sampling, so make
        # sure we have full precision here (cast before the log so the log
        # is not evaluated in reduced precision).
        nucleus_logits = ops.log(ops.cast(nucleus_probabilities, "float32"))
        sorted_next_token = random.categorical(
            nucleus_logits,
            1,
            seed=self.seed_generator,
            dtype="int32",
        )
        # Map the sampled position in sorted order back to a vocabulary id.
        output = ops.take_along_axis(sorted_indices, sorted_next_token, axis=-1)
        return ops.squeeze(output, axis=-1)

    def get_config(self):
        """Return the serializable config, extending the base config."""
        config = super().get_config()
        config.update(
            {
                "p": self.p,
                "k": self.k,
                "seed": self.seed,
            }
        )
        return config
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|