keras_hub_nightly-0.15.0.dev20240823171555-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
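
The wheel installs a single top-level `keras_hub` package (per `top_level.txt` above). As a hedged sketch, not taken from the diff itself: pinning this exact nightly build and checking its version might look like the following, where the `__version__` re-export at the package root is an assumption based on the bundled `version_utils.py`.

```python
# Shell (plain pip install, pinning the nightly build listed above):
#   pip install keras-hub-nightly==0.15.0.dev20240823171555

import keras_hub

# keras_hub/src/version_utils.py defines the version string; we assume it is
# re-exported at the package root as `__version__`.
print(keras_hub.__version__)  # expected: "0.15.0.dev20240823171555"
```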
keras_hub/src/models/causal_lm.py
@@ -0,0 +1,383 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from functools import partial

import keras
from keras import ops
from keras import tree

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task
from keras_hub.src.samplers.serialization import get as get_sampler
from keras_hub.src.utils.tensor_utils import tensor_to_list

try:
    import tensorflow as tf
except ImportError:
    tf = None


@keras_hub_export("keras_hub.models.CausalLM")
class CausalLM(Task):
    """Base class for generative language modeling tasks.

    `CausalLM` tasks wrap a `keras_hub.models.Backbone` and
    a `keras_hub.models.Preprocessor` to create a model that can be used for
    generation and generative fine-tuning.

    `CausalLM` tasks provide an additional, high-level `generate()` function
    which can be used to auto-regressively sample a model token by token with a
    string in, string out signature. The `compile()` method of all `CausalLM`
    classes contains an additional `sampler` argument, which can be used to pass
    a `keras_hub.samplers.Sampler` to control how the predicted distribution
    will be sampled.

    When calling `fit()`, the tokenized input will be predicted token-by-token
    with a causal mask applied, which gives both a pre-training and supervised
    fine-tuning setup for controlling inference-time generation.

    All `CausalLM` tasks include a `from_preset()` constructor which can be used
    to load a pre-trained config and weights.

    Example:
    ```python
    # Load a GPT2 backbone with pre-trained weights.
    causal_lm = keras_hub.models.CausalLM.from_preset(
        "gpt2_base_en",
    )
    causal_lm.compile(sampler="top_k")
    causal_lm.generate("Keras is a", max_length=64)

    # Load a Mistral instruction tuned checkpoint at bfloat16 precision.
    causal_lm = keras_hub.models.CausalLM.from_preset(
        "mistral_instruct_7b_en",
        dtype="bfloat16",
    )
    causal_lm.compile(sampler="greedy")
    causal_lm.generate("Keras is a", max_length=64)
    ```
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Default compilation.
        self.compile()

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        weighted_metrics="auto",
        sampler="top_k",
        **kwargs,
    ):
        """Configures the `CausalLM` task for training and generation.

        The `CausalLM` task extends the default compilation signature of
        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
        `weighted_metrics`. To override these defaults, pass any value
        to these arguments during compilation.

        The `CausalLM` task adds a new `sampler` to `compile`, which can be used
        to control the sampling strategy used with the `generate` function.

        Note that because training inputs include padded tokens which are
        excluded from the loss, it is almost always a good idea to compile with
        `weighted_metrics` and not `metrics`.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses the default optimizer
                for the given model and task. See `keras.Model.compile` and
                `keras.optimizers` for more info on possible `optimizer` values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, where a
                `keras.losses.SparseCategoricalCrossentropy` loss will be
                applied for the token classification `CausalLM` task. See
                `keras.Model.compile` and `keras.losses` for more info on
                possible `loss` values.
            weighted_metrics: `"auto"`, or a list of metrics to be evaluated by
                the model during training and testing. Defaults to `"auto"`,
                where a `keras.metrics.SparseCategoricalAccuracy` will be
                applied to track the accuracy of the model at guessing masked
                token values. See `keras.Model.compile` and `keras.metrics` for
                more info on possible `weighted_metrics` values.
            sampler: A sampler name, or a `keras_hub.samplers.Sampler` instance.
                Configures the sampling method used during `generate()` calls.
                See `keras_hub.samplers` for a full list of built-in sampling
                strategies.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        if optimizer == "auto":
            optimizer = keras.optimizers.Adam(2e-5)
        if loss == "auto":
            loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        if weighted_metrics == "auto":
            weighted_metrics = [keras.metrics.SparseCategoricalAccuracy()]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            weighted_metrics=weighted_metrics,
            **kwargs,
        )
        self.sampler = get_sampler(sampler)
        # Clear the compiled generate function.
        self.generate_function = None

    def generate_step(self, inputs, stop_token_ids):
        """Run generation on a single batch of input."""
        raise NotImplementedError

    def make_generate_function(self):
        """Create or return the compiled generation function."""
        if self.generate_function is not None:
            return self.generate_function

        self.generate_function = self.generate_step
        if keras.config.backend() == "torch":
            import torch

            def wrapped_generate_function(
                inputs,
                stop_token_ids=None,
            ):
                with torch.no_grad():
                    return self.generate_step(inputs, stop_token_ids)

            self.generate_function = wrapped_generate_function
        elif keras.config.backend() == "tensorflow" and not self.run_eagerly:
            # `jit_compile` is a property of keras.Model after TF 2.12.
            # Use `getattr()` for backwards compatibility.
            jit_compile = getattr(self, "jit_compile", True)
            self.generate_function = tf.function(
                self.generate_step, jit_compile=jit_compile
            )
        elif keras.config.backend() == "jax" and not self.run_eagerly:
            import jax

            @partial(jax.jit, static_argnames=["stop_token_ids"])
            def compiled_generate_function(inputs, stop_token_ids, state):
                (
                    sampler_variables,
                    trainable_variables,
                    non_trainable_variables,
                ) = state
                mapping = itertools.chain(
                    zip(self.sampler.variables, sampler_variables),
                    zip(self.trainable_variables, trainable_variables),
                    zip(self.non_trainable_variables, non_trainable_variables),
                )

                with keras.StatelessScope(state_mapping=mapping) as scope:
                    outputs = self.generate_step(inputs, stop_token_ids)

                # Get updated sampler variables from the stateless scope.
                sampler_variables = []
                for v in self.sampler.variables:
                    new_v = scope.get_current_value(v)
                    sampler_variables.append(new_v if new_v is not None else v)
                return outputs, sampler_variables

            def wrapped_generate_function(
                inputs,
                stop_token_ids=None,
            ):
                if isinstance(stop_token_ids, list):
                    stop_token_ids = tuple(stop_token_ids)

                # Create an explicit tuple of all variable state.
                state = (
                    self.sampler.variables,
                    # Use the explicit variable.value to preserve the
                    # sharding spec of distribution.
                    [v.value for v in self.trainable_variables],
                    [v.value for v in self.non_trainable_variables],
                )
                inputs = tree.map_structure(ops.convert_to_tensor, inputs)
                outputs, sampler_variables = compiled_generate_function(
                    inputs,
                    stop_token_ids,
                    state,
                )
                # Only assign the sampler variables (random seeds), as other
                # model variables should never be updated in generation.
                for ref_v, v in zip(self.sampler.variables, sampler_variables):
                    ref_v.assign(v)
                return outputs

            self.generate_function = wrapped_generate_function

        return self.generate_function

    def _normalize_generate_inputs(
        self,
        inputs,
    ):
        """Normalize user input to the generate function.

        This function converts all inputs to tensors, adds a batch dimension if
        necessary, and returns an iterable "dataset like" object (either an
        actual `tf.data.Dataset` or a list with a single batch element).
        """
        input_is_scalar = False

        if isinstance(inputs, tf.data.Dataset):
            return inputs, input_is_scalar

        def normalize(x):
            x_is_scalar = False
            if isinstance(x, str) or isinstance(x, list):
                x = tf.convert_to_tensor(x)

            if isinstance(x, tf.Tensor) and x.shape.rank == 0:
                x_is_scalar = True
                x = x[tf.newaxis]

            return x, x_is_scalar

        if isinstance(inputs, dict):
            for key in inputs:
                inputs[key], input_is_scalar = normalize(inputs[key])
        else:
            inputs, input_is_scalar = normalize(inputs)

        # We avoid converting to a dataset purely for speed; for a single batch
        # of input, creating a dataset would add significant overhead.
        return [inputs], input_is_scalar

    def _normalize_generate_outputs(
        self,
        outputs,
        input_is_scalar,
    ):
        """Normalize user output from the generate function.

        This function converts all output to numpy (for integer output), or
        python strings (for string output). If a batch dimension was added to
        the input, it is removed from the output (so generate can be string in,
        string out).
        """

        def normalize(x):
            if isinstance(x[0], list):
                outputs = []
                for batch in x:
                    for e in batch:
                        outputs.append(e)
                return outputs[0] if input_is_scalar else outputs
            if isinstance(x[0], tf.Tensor) and x[0].dtype == tf.string:
                outputs = tf.concat(x, axis=0)
                outputs = tf.squeeze(outputs, 0) if input_is_scalar else outputs
                return tensor_to_list(outputs)
            outputs = ops.concatenate(x, axis=0)
            outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs
            return ops.convert_to_numpy(outputs)

        if isinstance(outputs[0], dict):
            normalized = {}
            for key in outputs[0]:
                normalized[key] = normalize([x[key] for x in outputs])
            return normalized
        return normalize([x for x in outputs])

    def generate(
        self,
        inputs,
        max_length=None,
        stop_token_ids="auto",
    ):
        """Generate text given prompt `inputs`.

        This method generates text based on given `inputs`. The sampling method
        used for generation can be set via the `compile()` method.

        If `inputs` are a `tf.data.Dataset`, outputs will be generated
        "batch-by-batch" and concatenated. Otherwise, all inputs will be handled
        as a single batch.

        If a `preprocessor` is attached to the model, `inputs` will be
        preprocessed inside the `generate()` function and should match the
        structure expected by the `preprocessor` layer (usually raw strings).
        If a `preprocessor` is not attached, inputs should match the structure
        expected by the `backbone`. See the example usage above for a
        demonstration of each.

        Args:
            inputs: python data, tensor data, or a `tf.data.Dataset`. If a
                `preprocessor` is attached to the model, `inputs` should match
                the structure expected by the `preprocessor` layer. If a
                `preprocessor` is not attached, `inputs` should match the
                structure expected by the `backbone` model.
            max_length: Optional. int. The max length of the generated sequence.
                Will default to the max configured `sequence_length` of the
                `preprocessor`. If `preprocessor` is `None`, `inputs` should be
                padded to the desired maximum length and this argument will be
                ignored.
            stop_token_ids: Optional. `None`, `"auto"`, or tuple of token ids.
                Defaults to `"auto"`, which uses the
                `preprocessor.tokenizer.end_token_id`. Not specifying a
                preprocessor will produce an error. `None` stops generation
                only after generating `max_length` tokens. You may also specify
                a list of token ids the model should stop on. Note that each
                token id will be interpreted as a separate stop token;
                multi-token stop sequences are not supported.
        """
        # Setup our three main passes.
        # 1. Optionally preprocessing strings to dense integer tensors.
        # 2. Generate new tokens via a compiled function on dense tensors.
        # 3. Optionally postprocess dense integer tensors back to string.
        generate_function = self.make_generate_function()

        if self.preprocessor is None and stop_token_ids == "auto":
            raise ValueError(
                'A `preprocessor` must be attached to the model if `stop_token_ids="auto"`. '
                "Currently `preprocessor=None`. To call `generate()` with preprocessing "
                "detached, either pass `stop_token_ids=None` to always generate until "
                "`max_length` or pass a tuple of token ids that should terminate generation "
                "as `stop_token_ids`."
            )
        elif stop_token_ids == "auto":
            stop_token_ids = [self.preprocessor.tokenizer.end_token_id]

        def preprocess(x):
            return self.preprocessor.generate_preprocess(
                x, sequence_length=max_length
            )

        def generate(x):
            return generate_function(x, stop_token_ids=stop_token_ids)

        def postprocess(x):
            return self.preprocessor.generate_postprocess(x)

        # Normalize inputs, apply our three passes, and normalize outputs.
        inputs, input_is_scalar = self._normalize_generate_inputs(inputs)

        if self.preprocessor is not None:
            if isinstance(inputs, tf.data.Dataset):
                inputs = inputs.map(preprocess, tf.data.AUTOTUNE)
                inputs = inputs.prefetch(tf.data.AUTOTUNE)
            else:
                # Fast path for non-dataset, single-batch input.
                inputs = [preprocess(x) for x in inputs]

        outputs = [generate(x) for x in inputs]

        if self.preprocessor is not None:
            outputs = [postprocess(x) for x in outputs]

        return self._normalize_generate_outputs(outputs, input_is_scalar)
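
To make the three-pass `generate()` flow above concrete, here is a minimal usage sketch. It assumes the `gpt2_base_en` preset named in the class docstring is available for download, and that `TopKSampler` (from the bundled `keras_hub/src/samplers/top_k_sampler.py`) is exported as `keras_hub.samplers.TopKSampler`.

```python
import keras_hub

# from_preset() attaches both the backbone and a preprocessor, so generate()
# can run string in, string out.
causal_lm = keras_hub.models.CausalLM.from_preset("gpt2_base_en")

# `sampler` is the extra compile() argument described above; a string name
# ("top_k", "greedy", ...) or a Sampler instance both work.
causal_lm.compile(sampler=keras_hub.samplers.TopKSampler(k=10))

# stop_token_ids="auto" (the default) stops at the tokenizer's end token;
# pass stop_token_ids=None to always generate max_length tokens.
outputs = causal_lm.generate(["Keras is a"], max_length=32)
```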
keras_hub/src/models/classifier.py
@@ -0,0 +1,109 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task


@keras_hub_export("keras_hub.models.Classifier")
class Classifier(Task):
    """Base class for all classification tasks.

    `Classifier` tasks wrap a `keras_hub.models.Backbone` and
    a `keras_hub.models.Preprocessor` to create a model that can be used for
    sequence classification. `Classifier` tasks take an additional
    `num_classes` argument, controlling the number of predicted output classes.

    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
    labels where `x` is a string and `y` is an integer from `[0, num_classes)`.

    All `Classifier` tasks include a `from_preset()` constructor which can be
    used to load a pre-trained config and weights.

    Example:
    ```python
    # Load a BERT classifier with pre-trained weights.
    classifier = keras_hub.models.Classifier.from_preset(
        "bert_base_en",
        num_classes=2,
    )
    # Fine-tune on IMDb movie reviews (or any dataset).
    imdb_train, imdb_test = tfds.load(
        "imdb_reviews",
        split=["train", "test"],
        as_supervised=True,
        batch_size=16,
    )
    classifier.fit(imdb_train, validation_data=imdb_test)
    # Predict two new examples.
    classifier.predict(["What an amazing movie!", "A total waste of my time."])
    ```
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Default compilation.
        self.compile()

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        metrics="auto",
        **kwargs,
    ):
        """Configures the `Classifier` task for training.

        The `Classifier` task extends the default compilation signature of
        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
        `metrics`. To override these defaults, pass any value
        to these arguments during compilation.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses the default optimizer
                for the given model and task. See `keras.Model.compile` and
                `keras.optimizers` for more info on possible `optimizer` values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, where a
                `keras.losses.SparseCategoricalCrossentropy` loss will be
                applied for the classification task. See
                `keras.Model.compile` and `keras.losses` for more info on
                possible `loss` values.
            metrics: `"auto"`, or a list of metrics to be evaluated by
                the model during training and testing. Defaults to `"auto"`,
                where a `keras.metrics.SparseCategoricalAccuracy` will be
                applied to track the accuracy of the model during training.
                See `keras.Model.compile` and `keras.metrics` for
                more info on possible `metrics` values.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        if optimizer == "auto":
            optimizer = keras.optimizers.Adam(5e-5)
        if loss == "auto":
            activation = getattr(self, "activation", None)
            activation = keras.activations.get(activation)
            from_logits = activation != keras.activations.softmax
            loss = keras.losses.SparseCategoricalCrossentropy(from_logits)
        if metrics == "auto":
            metrics = [keras.metrics.SparseCategoricalAccuracy()]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            **kwargs,
        )
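
The `"auto"` defaults in `Classifier.compile()` above can each be overridden independently. A short sketch, assuming the `bert_base_en` preset named in the docstring:

```python
import keras
import keras_hub

classifier = keras_hub.models.Classifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
# Replace the "auto" defaults with explicit choices; any argument left at
# "auto" keeps the behavior shown in compile() above. from_logits=True
# assumes the task head was built without a softmax activation.
classifier.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
```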
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.