keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,608 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import json
|
16
|
+
import os
|
17
|
+
import pathlib
|
18
|
+
import re
|
19
|
+
|
20
|
+
import keras
|
21
|
+
import tensorflow as tf
|
22
|
+
from absl.testing import parameterized
|
23
|
+
from keras import ops
|
24
|
+
from keras import tree
|
25
|
+
|
26
|
+
from keras_hub.src.layers.modeling.reversible_embedding import (
|
27
|
+
ReversibleEmbedding,
|
28
|
+
)
|
29
|
+
from keras_hub.src.tokenizers.tokenizer import Tokenizer
|
30
|
+
from keras_hub.src.utils.keras_utils import has_quantization_support
|
31
|
+
from keras_hub.src.utils.tensor_utils import is_float_dtype
|
32
|
+
|
33
|
+
|
34
|
+
def convert_to_comparible_type(x):
    """Normalize a value so that test assertions can compare it.

    - String tensors (dense or ragged) are turned into plain Python `str`
      values, nested in lists for non-scalar inputs.
    - Any other TensorFlow tensor (dense or ragged) is returned unchanged.
    - Any other object exposing `__array__` (e.g. jax or torch tensors) is
      converted to a NumPy array.
    - Everything else is returned as-is.
    """
    is_string = getattr(x, "dtype", None) == tf.string
    if not is_string:
        if isinstance(x, (tf.Tensor, tf.RaggedTensor)):
            # Non-string TF tensors are compared natively by `tf.test`.
            return x
        return ops.convert_to_numpy(x) if hasattr(x, "__array__") else x

    if isinstance(x, tf.RaggedTensor):
        python_value = x.to_list()
    elif isinstance(x, tf.Tensor):
        values = x.numpy()
        # Scalars stay a single `bytes`; higher ranks become nested lists.
        python_value = values if x.shape.rank == 0 else values.tolist()
    else:
        python_value = x
    return tree.map_structure(lambda item: item.decode("utf-8"), python_value)
|
51
|
+
|
52
|
+
|
53
|
+
class TestCase(tf.test.TestCase, parameterized.TestCase):
|
54
|
+
"""Base test case class for KerasHub."""
|
55
|
+
|
56
|
+
def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
    """Assert `x1` and `x2` are element-wise close after normalization.

    Every leaf of both structures is passed through
    `convert_to_comparible_type` before delegating to the parent
    `assertAllClose` with the given tolerances.
    """

    # `_MetricDict` only appears with tf.keras; unwrap it into a plain
    # dict. This hack can be removed once fully migrated to Keras 3.
    def _unwrap_metric_dict(value):
        if value.__class__.__name__ == "_MetricDict":
            return dict(value)
        return value

    lhs = _unwrap_metric_dict(x1)
    rhs = _unwrap_metric_dict(x2)
    lhs = tree.map_structure(convert_to_comparible_type, lhs)
    rhs = tree.map_structure(convert_to_comparible_type, rhs)
    super().assertAllClose(lhs, rhs, atol=atol, rtol=rtol, msg=msg)
|
66
|
+
|
67
|
+
def assertEqual(self, x1, x2, msg=None):
    """Assert equality after normalizing every leaf of both arguments."""
    normalized = [
        tree.map_structure(convert_to_comparible_type, value)
        for value in (x1, x2)
    ]
    super().assertEqual(*normalized, msg=msg)
|
71
|
+
|
72
|
+
def assertAllEqual(self, x1, x2, msg=None):
    """Assert element-wise equality after normalizing both structures."""
    expected, actual = map(
        lambda value: tree.map_structure(convert_to_comparible_type, value),
        (x1, x2),
    )
    super().assertAllEqual(expected, actual, msg=msg)
|
76
|
+
|
77
|
+
def assertDTypeEqual(self, x, expected_dtype, msg=None):
    """Check the dtype of `x` against `expected_dtype`.

    `x.dtype` is first normalized with `keras.backend.standardize_dtype`.
    The comparison goes straight to the parent `assertEqual`, skipping
    this class's value-normalizing override.
    """
    standardized = keras.backend.standardize_dtype(x.dtype)
    return super().assertEqual(standardized, expected_dtype, msg=msg)
|
80
|
+
|
81
|
+
def run_layer_test(
    self,
    cls,
    init_kwargs,
    input_data,
    expected_output_shape,
    expected_output_data=None,
    expected_num_trainable_weights=0,
    expected_num_non_trainable_weights=0,
    expected_num_non_trainable_variables=0,
    run_training_check=True,
    run_precision_checks=True,
):
    """Run basic tests for a modeling layer.

    Each phase below uses its own fresh `cls(**init_kwargs)` instance:

    1. Config serialization round trip (`run_serialization_test`).
    2. Explicit `build()` from the input shapes, then weight-count checks.
    3. Symbolic call on `keras.KerasTensor` inputs, then weight-count,
       output-shape and output-dtype checks.
    4. Eager call on `input_data` with the same output checks, plus a value
       comparison against `expected_output_data` when it is given.
    5. Optionally, a compiled `fit` on the eager-phase layer
       (`run_training_check`) and `run_precision_test`
       (`run_precision_checks`).

    Args:
        cls: The layer class under test.
        init_kwargs: Keyword arguments used to construct every instance.
        input_data: An array-like with `.shape`/`.dtype`, or a dict of
            them. Dict entries are passed to the layer as keyword arguments.
        expected_output_shape: Expected (possibly nested) output shape(s);
            `None` outputs are expected to map to `None`.
        expected_output_data: Optional expected output values, checked only
            in the eager phase.
        expected_num_trainable_weights: Expected
            `len(layer.trainable_weights)`.
        expected_num_non_trainable_weights: Expected
            `len(layer.non_trainable_weights)`.
        expected_num_non_trainable_variables: Expected
            `len(layer.non_trainable_variables)`.
        run_training_check: Whether to run the compiled training step.
        run_precision_checks: Whether to run `run_precision_test`.
    """
    # Serialization test.
    layer = cls(**init_kwargs)
    self.run_serialization_test(layer)

    def run_build_asserts(layer):
        # The layer must be built and own exactly the expected number of
        # trainable/non-trainable weights and non-trainable variables.
        self.assertTrue(layer.built)
        self.assertLen(
            layer.trainable_weights,
            expected_num_trainable_weights,
            msg="Unexpected number of trainable_weights",
        )
        self.assertLen(
            layer.non_trainable_weights,
            expected_num_non_trainable_weights,
            msg="Unexpected number of non_trainable_weights",
        )
        self.assertLen(
            layer.non_trainable_variables,
            expected_num_non_trainable_variables,
            msg="Unexpected number of non_trainable_variables",
        )

    def run_output_asserts(layer, output, eager=False):
        # Compare the (possibly nested) output shapes; `None` outputs are
        # kept as `None` rather than dereferenced.
        output_shape = tree.map_structure(
            lambda x: None if x is None else x.shape, output
        )
        self.assertEqual(
            expected_output_shape,
            output_shape,
            msg="Unexpected output shape",
        )
        # Only the first flattened output's dtype is checked against the
        # layer's dtype.
        output_dtype = tree.flatten(output)[0].dtype
        self.assertEqual(
            keras.backend.standardize_dtype(layer.dtype),
            keras.backend.standardize_dtype(output_dtype),
            msg="Unexpected output dtype",
        )
        # Values can only be compared for concrete (eager) outputs.
        if eager and expected_output_data is not None:
            self.assertAllClose(expected_output_data, output)

    def run_training_step(layer, input_data, output_data):
        # Wrap the layer in a minimal model and run `fit` with SGD + MSE,
        # using `output_data` as the regression target.
        class TestModel(keras.Model):
            def __init__(self, layer):
                super().__init__()
                self.layer = layer

            def call(self, x):
                # Dict inputs are forwarded as keyword arguments.
                if isinstance(x, dict):
                    return self.layer(**x)
                else:
                    return self.layer(x)

        input_data = tree.map_structure(
            lambda x: ops.convert_to_numpy(x), input_data
        )
        output_data = tree.map_structure(
            lambda x: ops.convert_to_numpy(x), output_data
        )
        model = TestModel(layer)
        # Temporarily disable jit compilation on torch backend.
        jit_compile = keras.config.backend() != "torch"
        model.compile(optimizer="sgd", loss="mse", jit_compile=jit_compile)
        model.fit(input_data, output_data, verbose=0)

    # Build test.
    layer = cls(**init_kwargs)
    if isinstance(input_data, dict):
        # Dict inputs map to `build(<name>_shape=...)` keyword arguments.
        shapes = {k + "_shape": v.shape for k, v in input_data.items()}
        layer.build(**shapes)
    else:
        layer.build(input_data.shape)
    run_build_asserts(layer)

    # Symbolic call test.
    keras_tensor_inputs = tree.map_structure(
        lambda x: keras.KerasTensor(x.shape, x.dtype), input_data
    )
    layer = cls(**init_kwargs)
    if isinstance(keras_tensor_inputs, dict):
        keras_tensor_outputs = layer(**keras_tensor_inputs)
    else:
        keras_tensor_outputs = layer(keras_tensor_inputs)
    run_build_asserts(layer)
    run_output_asserts(layer, keras_tensor_outputs)

    # Eager call test and compiled training test.
    layer = cls(**init_kwargs)
    if isinstance(input_data, dict):
        output_data = layer(**input_data)
    else:
        output_data = layer(input_data)
    run_output_asserts(layer, output_data, eager=True)

    if run_training_check:
        # The eager layer's own output is reused as the training target.
        run_training_step(layer, input_data, output_data)

    if run_precision_checks:
        self.run_precision_test(cls, init_kwargs, input_data)
|
193
|
+
|
194
|
+
def run_preprocessing_layer_test(
    self,
    cls,
    init_kwargs,
    input_data,
    expected_output=None,
    expected_detokenize_output=None,
):
    """Run basic tests for a preprocessing layer.

    Checks that the layer survives a config serialization round trip. It
    also checks that calling it directly, mapping it over an unbatched
    `tf.data` dataset, and mapping it over a batched dataset all give the
    same result. For `Tokenizer` layers, `detokenize` of the direct-call
    output is also compared against `expected_detokenize_output`.

    Args:
        cls: The preprocessing layer class under test.
        init_kwargs: Keyword arguments used to construct `cls`.
        input_data: Layer input. A tuple is unpacked into positional
            arguments, mimicking `tf.data` unpacking behavior.
        expected_output: Optional expected output of the direct call. The
            check is skipped only when this is `None`.
        expected_detokenize_output: Optional expected result of
            `layer.detokenize(output)` for tokenizers. Defaults to
            `input_data` when `None`.
    """
    layer = cls(**init_kwargs)
    # Check serialization (without a full save).
    self.run_serialization_test(layer)

    ds = tf.data.Dataset.from_tensor_slices(input_data)

    # Run with direct call.
    if isinstance(input_data, tuple):
        # Mimic tf.data unpacking behavior for preprocessing layers.
        output = layer(*input_data)
    else:
        output = layer(input_data)

    # For tokenizers only, also check detokenize.
    if isinstance(layer, Tokenizer):
        # Compare against `None` explicitly: truth-testing an array or
        # tensor with several elements raises instead of picking the
        # default, and a falsy expectation must not be silently replaced.
        if expected_detokenize_output is None:
            expected_detokenize_output = input_data
        detokenize_output = layer.detokenize(output)
        self.assertAllEqual(detokenize_output, expected_detokenize_output)

    # Run with an unbatched dataset.
    output_ds = ds.map(layer).ragged_batch(1_000)
    self.assertAllClose(output, output_ds.get_single_element())

    # Run with a batched dataset.
    output_ds = ds.batch(1_000).map(layer)
    self.assertAllClose(output, output_ds.get_single_element())

    if expected_output is not None:
        self.assertAllClose(output, expected_output)
|
233
|
+
|
234
|
+
def run_preprocessor_test(
    self,
    cls,
    init_kwargs,
    input_data,
    expected_output=None,
    expected_detokenize_output=None,
    token_id_key="token_ids",
):
    """Run basic tests for a Model Preprocessor layer.

    Runs `run_preprocessing_layer_test`, then checks that the last
    dimension of `output[token_id_key]` equals `layer.sequence_length`.
    The check is done once for the configured value and again after
    setting `layer.sequence_length = 17` on the same instance.

    Args:
        cls: The preprocessor class under test.
        init_kwargs: Keyword arguments used to construct `cls`.
        input_data: Preprocessor input; a tuple is unpacked into positional
            arguments.
        expected_output: Optional expected output, forwarded to
            `run_preprocessing_layer_test`.
        expected_detokenize_output: Optional expected detokenized output,
            forwarded to `run_preprocessing_layer_test`.
        token_id_key: Key of the token id tensor in the `x` part of the
            preprocessor output.
    """
    self.run_preprocessing_layer_test(
        cls,
        init_kwargs,
        input_data,
        expected_output=expected_output,
        expected_detokenize_output=expected_detokenize_output,
    )

    def token_id_length(layer):
        # Call the layer on `input_data` and return the last dimension of
        # the token id output.
        if isinstance(input_data, tuple):
            output = layer(*input_data)
        else:
            output = layer(input_data)
        # Keep only `x` if the output is an `(x, y, sample_weight)` tuple.
        output, _, _ = keras.utils.unpack_x_y_sample_weight(output)
        return ops.shape(output[token_id_key])[-1]

    # Use the `init_kwargs` argument (not `self.init_kwargs`) so the check
    # runs on the configuration the caller asked to test. It must also
    # work for test classes that never set `self.init_kwargs`.
    layer = cls(**init_kwargs)
    self.assertEqual(token_id_length(layer), layer.sequence_length)
    # Update the sequence length.
    layer.sequence_length = 17
    self.assertEqual(token_id_length(layer), 17)
|
269
|
+
|
270
|
+
def run_serialization_test(self, instance):
    """Check idempotency of serialize/deserialize.

    Note: this is a much faster test than saving.

    Two round trips are checked:
    - `get_config()` / `cls.from_config()`;
    - `keras.saving.serialize_keras_object()` /
      `keras.saving.deserialize_keras_object()` through a JSON string.

    In both cases the revived object's config must serialize to the same
    JSON as the original's. Unless `instance` is a `Tokenizer` on the
    TensorFlow backend, each revived object must also expose the same
    `dir()` names. For the second round trip, `__annotations__` is ignored.
    """

    def as_json(config):
        # Canonical JSON text, so configs compare independently of key order.
        return json.dumps(config, sort_keys=True, indent=4)

    def names_without_annotations(names):
        remaining = list(names)
        if "__annotations__" in remaining:
            remaining.remove("__annotations__")
        return set(remaining)

    check_dir = not (
        keras.config.backend() == "tensorflow"
        and isinstance(instance, Tokenizer)
    )

    # Round trip 1: get_config -> from_config.
    cls = instance.__class__
    cfg = instance.get_config()
    cfg_json = as_json(cfg)
    ref_dir = list(dir(instance))
    revived_instance = cls.from_config(cfg)
    self.assertEqual(cfg_json, as_json(revived_instance.get_config()))
    if check_dir:
        self.assertEqual(set(ref_dir), set(dir(revived_instance)))

    # Round trip 2: serialize_keras_object -> JSON -> deserialize.
    serialized_json = as_json(keras.saving.serialize_keras_object(instance))
    revived_instance = keras.saving.deserialize_keras_object(
        json.loads(serialized_json)
    )
    self.assertEqual(cfg_json, as_json(revived_instance.get_config()))
    if check_dir:
        self.assertEqual(
            names_without_annotations(ref_dir),
            names_without_annotations(dir(revived_instance)),
        )
|
305
|
+
|
306
|
+
def run_precision_test(self, cls, init_kwargs, input_data):
    """Check the dtypes of `cls` under reduced-precision dtype policies.

    For each of "mixed_float16", "mixed_bfloat16" and "bfloat16", a fresh
    `cls` is built from `init_kwargs` with `dtype` set to that policy
    (overriding any dtype in `init_kwargs`). It is then called on
    `input_data`; dict inputs are passed as keyword arguments except for
    `keras.Model` instances. The following must hold:
    - floating-point outputs and output specs use the policy's compute
      dtype;
    - floating-point weights use its variable dtype;
    - every sublayer other than `Softmax`/`InputLayer` reports both policy
      dtypes.

    The whole check is skipped on the torch backend when CUDA is not
    available.
    """
    # Never test mixed precision on torch CPU. Torch lacks support.
    if keras.config.backend() == "torch":
        import torch

        if not torch.cuda.is_available():
            return

    exempt_sublayer_types = (keras.layers.Softmax, keras.layers.InputLayer)
    for policy_name in ("mixed_float16", "mixed_bfloat16", "bfloat16"):
        policy = keras.mixed_precision.Policy(policy_name)
        layer = cls(**{**init_kwargs, "dtype": policy})

        # Models take their (possibly dict) input positionally; plain
        # layers receive dict inputs as keyword arguments.
        if isinstance(input_data, dict) and not isinstance(layer, keras.Model):
            call_args, call_kwargs = (), input_data
        else:
            call_args, call_kwargs = (input_data,), {}
        output_data = layer(*call_args, **call_kwargs)
        output_spec = layer.compute_output_spec(*call_args, **call_kwargs)

        float_outputs = (
            t for t in tree.flatten(output_data) if is_float_dtype(t.dtype)
        )
        for tensor in float_outputs:
            self.assertDTypeEqual(tensor, policy.compute_dtype)

        float_specs = (
            s for s in tree.flatten(output_spec) if is_float_dtype(s.dtype)
        )
        for spec in float_specs:
            self.assertDTypeEqual(spec, policy.compute_dtype)

        float_weights = (w for w in layer.weights if is_float_dtype(w.dtype))
        for weight in float_weights:
            self.assertDTypeEqual(weight, policy.variable_dtype)

        for sublayer in layer._flatten_layers():
            if isinstance(sublayer, exempt_sublayer_types):
                continue
            self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
            self.assertEqual(policy.variable_dtype, sublayer.variable_dtype)
|
342
|
+
|
343
|
+
def run_quantization_test(self, instance, cls, init_kwargs, input_data):
    """Check that `cls` can be built and called with quantized sublayers.

    For each of the "int8" and "float8" modes:
    1. Build a `DTypePolicyMap` that assigns `f"{mode}_from_float32"` to
       every sublayer of `instance` whose exact type is quantizable in that
       mode (`Dense`/`EinsumDense`, plus `Embedding`/`ReversibleEmbedding`
       for int8).
    2. Build a new `cls` with that map as its `dtype` and call it eagerly.
    3. Check that each matching sublayer reports the mode as
       `quantization_mode`.
    4. Check that the config round-trips through `from_config`.
    5. Check that the weights load into the revived model.
    """
    for mode in ("int8", "float8"):
        quantizable_types = [keras.layers.Dense, keras.layers.EinsumDense]
        if mode == "int8":
            quantizable_types += [keras.layers.Embedding, ReversibleEmbedding]

        # Manually configure DTypePolicyMap to avoid intensive computation
        # in `Model.quantize`.
        policy_map = keras.dtype_policies.DTypePolicyMap("float32")
        for sublayer in instance._flatten_layers():
            if type(sublayer) in quantizable_types:
                policy_map[sublayer.path] = keras.dtype_policies.get(
                    f"{mode}_from_float32"
                )

        # Instantiate the layer and call it eagerly; models take dict
        # inputs positionally, plain layers as keyword arguments.
        model = cls(**{**init_kwargs, "dtype": policy_map})
        if isinstance(input_data, dict) and not isinstance(model, keras.Model):
            model(**input_data)
        else:
            model(input_data)

        # Verify sublayer's dtype policy.
        for sublayer in model._flatten_layers():
            if type(sublayer) in quantizable_types:
                self.assertEqual(mode, sublayer.quantization_mode)

        # `get_config` roundtrip.
        cfg = model.get_config()
        revived_model = cls.from_config(cfg)
        self.assertEqual(cfg, revived_model.get_config())

        # Check weights loading.
        revived_model.set_weights(model.get_weights())
|
381
|
+
|
382
|
+
def run_model_saving_test(
    self,
    cls,
    init_kwargs,
    input_data,
    atol=None,
    rtol=None,
):
    """Save and load a model from disk and assert output is unchanged.

    Args:
        cls: Model class under test.
        init_kwargs: Keyword arguments used to construct the model.
        input_data: Inputs for the forward pass before and after reloading.
        atol: Optional absolute tolerance for the output comparison. When
            `None`, `assertAllClose`'s own default is used.
        rtol: Optional relative tolerance for the output comparison. When
            `None`, `assertAllClose`'s own default is used.
    """
    model = cls(**init_kwargs)
    model_output = model(input_data)
    path = os.path.join(self.get_temp_dir(), "model.keras")
    # The `.keras` extension already selects the Keras v3 format. Keras 3
    # deprecates the `save_format` argument, so it is not passed.
    model.save(path)
    restored_model = keras.models.load_model(path)

    # Check we got the real object back.
    self.assertIsInstance(restored_model, cls)

    # Check that output matches. Tolerances are forwarded only when given,
    # so the comparison is unchanged for existing callers.
    tolerances = {}
    if atol is not None:
        tolerances["atol"] = atol
    if rtol is not None:
        tolerances["rtol"] = rtol
    restored_output = restored_model(input_data)
    self.assertAllClose(model_output, restored_output, **tolerances)
|
401
|
+
|
402
|
+
def run_backbone_test(
    self,
    cls,
    init_kwargs,
    input_data,
    expected_output_shape,
    variable_length_data=None,
    run_mixed_precision_check=True,
    run_quantization_check=True,
):
    """Run basic tests for a backbone, including compilation.

    The checks run in this order:

    - config serialization;
    - an eager call, with its output shape(s) compared to
      `expected_output_shape`;
    - an eager token-embedding call, if the backbone has one;
    - calls on variable-length inputs;
    - compiled `predict` on raw data and on a `tf.data` dataset;
    - the backbone name matching the snake_cased class name;
    - optionally, the mixed-precision and quantization checks.

    Args:
        cls: Backbone class under test.
        init_kwargs: Keyword arguments used to construct the backbone.
        input_data: Inputs for the forward passes.
        expected_output_shape: Expected output shape, or a dict mapping
            output keys to expected shapes.
        variable_length_data: Optional iterable of input batches. If `None`,
            batches are made by slicing axis 1 of every input to lengths
            2, 3 and 4.
        run_mixed_precision_check: Whether to run `run_precision_test`.
        run_quantization_check: Whether to run `run_quantization_test`
            (only when `has_quantization_support()` is true).
    """
    backbone = cls(**init_kwargs)
    # Check serialization (without a full save).
    self.run_serialization_test(backbone)

    # Call model eagerly.
    output = backbone(input_data)
    if isinstance(expected_output_shape, dict):
        for key in expected_output_shape:
            self.assertEqual(output[key].shape, expected_output_shape[key])
    else:
        self.assertEqual(output.shape, expected_output_shape)
    if backbone.token_embedding is not None:
        # Check we can embed tokens eagerly; only the call itself matters.
        backbone.token_embedding(ops.zeros((2, 3), dtype="int32"))

    # Check variable length sequences.
    if variable_length_data is None:
        # If no variable length data passed, assume the second axis of all
        # inputs is our sequence axis and create it ourselves.
        variable_length_data = [
            tree.map_structure(
                lambda x: x[:, :seq_length, ...], input_data
            )
            for seq_length in (2, 3, 4)
        ]
    for batch in variable_length_data:
        backbone(batch)

    # Check compiled predict function.
    backbone.predict(input_data)
    # Convert to numpy first, torch GPU tensor -> tf.data will error.
    numpy_data = tree.map_structure(ops.convert_to_numpy, input_data)
    # Create a dataset.
    input_dataset = tf.data.Dataset.from_tensor_slices(numpy_data).batch(2)
    backbone.predict(input_dataset)

    # Check name maps to classname.
    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", cls.__name__)
    name = re.sub("([a-z])([A-Z])", r"\1_\2", name).lower()
    # `assertRegexpMatches` is a deprecated alias that was removed from
    # unittest in Python 3.12; `assertRegex` is the supported name.
    self.assertRegex(backbone.name, name)

    # Check mixed precision.
    if run_mixed_precision_check:
        self.run_precision_test(cls, init_kwargs, input_data)

    # Check quantization.
    if run_quantization_check and has_quantization_support():
        self.run_quantization_test(backbone, cls, init_kwargs, input_data)
|
461
|
+
|
462
|
+
def run_vision_backbone_test(
    self,
    cls,
    init_kwargs,
    input_data,
    expected_output_shape,
    variable_length_data=None,
    run_mixed_precision_check=True,
    run_quantization_check=True,
    run_data_format_check=True,
):
    """Run basic tests for a vision backbone, including compilation.

    `run_backbone_test` is first run with `image_data_format` set to
    "channels_last". When `run_data_format_check` is set and the backend
    supports it, it is run again in "channels_first":

    - rank-3 or rank-4 `input_data` is transposed to channels-first;
    - an `image_shape` init kwarg is rearranged from (H, W, C) to
      (C, H, W).

    The global `image_data_format` is always restored afterwards, even if a
    check fails.

    Args:
        cls: Backbone class under test.
        init_kwargs: Keyword arguments used to construct the backbone.
        input_data: Image inputs, assumed to be in "channels_last" format.
        expected_output_shape: Forwarded to `run_backbone_test` unchanged
            for both data formats.
        variable_length_data: Forwarded to `run_backbone_test` unchanged
            for both data formats.
        run_mixed_precision_check: Forwarded to `run_backbone_test`.
        run_quantization_check: Forwarded to `run_backbone_test`.
        run_data_format_check: Whether to also test "channels_first".
    """
    can_run_data_format_check = True
    if (
        keras.config.backend() == "tensorflow"
        and not tf.config.list_physical_devices("GPU")
    ):
        # Never test the "channels_first" format on tensorflow CPU.
        # Tensorflow lacks support for "channels_first" convolution.
        can_run_data_format_check = False

    ori_data_format = keras.config.image_data_format()
    # Restore in `finally` so a failing check does not leak a modified
    # global `image_data_format` into later tests.
    try:
        keras.config.set_image_data_format("channels_last")
        self.run_backbone_test(
            cls=cls,
            init_kwargs=init_kwargs,
            input_data=input_data,
            expected_output_shape=expected_output_shape,
            variable_length_data=variable_length_data,
            run_mixed_precision_check=run_mixed_precision_check,
            run_quantization_check=run_quantization_check,
        )

        # Check data_format. We assume that `input_data` is in
        # "channels_last" format.
        if run_data_format_check and can_run_data_format_check:
            keras.config.set_image_data_format("channels_first")
            input_data_shape = ops.shape(input_data)
            if len(input_data_shape) == 3:
                input_data = ops.transpose(input_data, axes=(2, 0, 1))
            elif len(input_data_shape) == 4:
                input_data = ops.transpose(input_data, axes=(0, 3, 1, 2))
            if "image_shape" in init_kwargs:
                init_kwargs = init_kwargs.copy()
                image_shape = tuple(init_kwargs["image_shape"])
                # Move only the channel axis to the front, matching the
                # input transpose above: (H, W, C) -> (C, H, W). Reversing
                # the whole tuple would give (C, W, H), which is wrong for
                # non-square images.
                init_kwargs["image_shape"] = (
                    image_shape[-1:] + image_shape[:-1]
                )
            # NOTE(review): `variable_length_data` and
            # `expected_output_shape` are forwarded untransposed. When
            # `variable_length_data` is None, `run_backbone_test` slices
            # axis 1, which is the channel axis in this format. Confirm
            # that this is intended.
            self.run_backbone_test(
                cls=cls,
                init_kwargs=init_kwargs,
                input_data=input_data,
                expected_output_shape=expected_output_shape,
                variable_length_data=variable_length_data,
                run_mixed_precision_check=run_mixed_precision_check,
                run_quantization_check=run_quantization_check,
            )
    finally:
        # Restore the original `image_data_format`.
        keras.config.set_image_data_format(ori_data_format)
|
521
|
+
|
522
|
+
def run_task_test(
    self,
    cls,
    init_kwargs,
    train_data,
    expected_output_shape=None,
    batch_size=2,
):
    """Run basic tests for a task, including compilation.

    Builds the task and checks config serialization. It then exercises
    `predict` and `fit` in three ways:

    - on the raw unpacked data;
    - on a batched `tf.data` dataset;
    - on a dataset pre-mapped through the task's preprocessor, with
      `task.preprocessor` temporarily set to `None`.

    Predictions from all three paths must match.

    Args:
        cls: Task class under test.
        init_kwargs: Keyword arguments used to construct the task.
        train_data: Data accepted by both
            `keras.utils.unpack_x_y_sample_weight` and
            `tf.data.Dataset.from_tensor_slices`.
        expected_output_shape: Optional (nested) shape that the prediction
            output must match.
        batch_size: Batch size for the `tf.data` dataset.
    """
    task = cls(**init_kwargs)
    # Check serialization (without a full save).
    self.run_serialization_test(task)
    preprocessor = task.preprocessor
    ds = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size)
    x, y, sw = keras.utils.unpack_x_y_sample_weight(train_data)

    # Test predict.
    output = task.predict(x)
    if expected_output_shape is not None:
        output_shape = tree.map_structure(lambda t: t.shape, output)
        self.assertAllClose(output_shape, expected_output_shape)
    # With a dataset.
    output_ds = task.predict(ds)
    self.assertAllClose(output, output_ds)
    # With split preprocessing: detach the preprocessor and feed data that
    # was already preprocessed by the dataset pipeline.
    task.preprocessor = None
    output_split = task.predict(ds.map(preprocessor))
    task.preprocessor = preprocessor
    self.assertAllClose(output, output_split)

    # Test fit.
    task.fit(x, y, sample_weight=sw)
    # With a dataset.
    task.fit(ds)
    # With split preprocessing.
    task.preprocessor = None
    task.fit(ds.map(preprocessor))
    task.preprocessor = preprocessor
|
560
|
+
|
561
|
+
def run_preset_test(
    self,
    cls,
    preset,
    input_data,
    init_kwargs=None,
    expected_output=None,
    expected_output_shape=None,
    expected_partial_output=None,
):
    """Run instantiation and a forward pass for a preset.

    Checks that:

    - loading an unknown preset ("clowntown") raises;
    - `preset` loads and runs on `input_data`;
    - for `keras.Model` instances, `preset` also loads and runs with
      `load_weights=False`.

    The output of the first (weighted) instance is optionally compared with
    the `expected_*` arguments.

    Args:
        cls: Class exposing `from_preset`.
        preset: Preset handle to load.
        input_data: Forward-pass inputs. For the first call, a tuple is
            unpacked into positional arguments, mimicking tf.data unpacking
            for preprocessing layers.
        init_kwargs: Optional extra kwargs forwarded to every `from_preset`
            call. `None` means no extra kwargs.
        expected_output: Optional full expected output.
        expected_output_shape: Optional (nested) expected output shape.
        expected_partial_output: Optional structure of 1-D values. Each one
            is compared with the leading entries of the flattened
            corresponding output, using `atol=0.01` and `rtol=0.01`.
    """
    # `None` instead of a `{}` default avoids a shared mutable default.
    if init_kwargs is None:
        init_kwargs = {}

    with self.assertRaises(Exception):
        cls.from_preset("clowntown", **init_kwargs)

    instance = cls.from_preset(preset, **init_kwargs)

    if isinstance(input_data, tuple):
        # Mimic tf.data unpacking behavior for preprocessing layers.
        output = instance(*input_data)
    else:
        output = instance(input_data)

    if isinstance(instance, keras.Model):
        instance = cls.from_preset(
            preset, load_weights=False, **init_kwargs
        )
        instance(input_data)

    if expected_output is not None:
        self.assertAllClose(output, expected_output)

    if expected_output_shape is not None:
        output_shape = tree.map_structure(lambda t: t.shape, output)
        self.assertAllClose(output_shape, expected_output_shape)

    if expected_partial_output is not None:
        # Allow passing a partial output snippet of the last dimension.
        # We want check stability, but the full output would be too long.
        def compare(actual, expected):
            expected = ops.convert_to_numpy(expected)
            self.assertEqual(len(expected.shape), 1)
            actual = ops.reshape(actual, (-1,))[: expected.shape[0]]
            self.assertAllClose(actual, expected, atol=0.01, rtol=0.01)

        tree.map_structure(compare, output, expected_partial_output)
|
606
|
+
|
607
|
+
def get_test_data_dir(self):
    """Return, as a string, the `test_data` directory beside this module."""
    module_dir = pathlib.Path(__file__).parent
    return str(module_dir.joinpath("test_data"))
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|