keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +2 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/api/utils/__init__.py +22 -0
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +46 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +230 -68
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,186 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import keras
|
15
|
+
from keras import layers
|
16
|
+
|
17
|
+
from keras_hub.src.models.stable_diffusion_v3.vae_attention import VAEAttention
|
18
|
+
from keras_hub.src.utils.keras_utils import standardize_data_format
|
19
|
+
|
20
|
+
|
21
|
+
class VAEImageDecoder(keras.Model):
    """Functional Keras model that decodes a latent tensor into an image.

    The graph is: a 3x3 input projection, a resnet block, a `VAEAttention`
    layer and a second resnet block (all using `stackwise_num_filters[0]`),
    followed by one stack of resnet blocks per entry of
    `stackwise_num_filters`. Every stack except the last is followed by a 2x
    `UpSampling2D` and a 3x3 convolution. The output head is group
    normalization, a swish activation and a 3x3 convolution to
    `output_channels`.

    Args:
        stackwise_num_filters: sequence of ints. Number of filters used by
            each stack. The first entry is also used for the input
            projection, input blocks and attention layer.
        stackwise_num_blocks: sequence of ints. Number of resnet blocks in
            each stack; indexed in parallel with `stackwise_num_filters`.
        output_channels: int. Number of filters of the final convolution.
            Defaults to `3`.
        latent_shape: tuple. Shape (without batch dimension) passed to
            `layers.Input`. Defaults to `(None, None, 16)`.
        data_format: `None` or str. Normalized with
            `standardize_data_format`; selects the channel axis used by the
            group normalization layers.
        dtype: `None`, str or dtype policy. Forwarded to every layer; when
            not `None` it is also installed as this model's dtype policy.
        **kwargs: forwarded to `keras.Model.__init__`.
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=3,
        latent_shape=(None, None, 16),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)
        gn_axis = -1 if data_format == "channels_last" else 1

        # === Functional Model ===
        latent_inputs = layers.Input(shape=latent_shape)

        x = layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="input_projection",
        )(latent_inputs)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block0",
        )
        x = VAEAttention(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_attention",
        )(x)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block1",
        )

        # Stacks.
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                x = apply_resnet_block(
                    x,
                    filters,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"block{i}_{j}",
                )
            if i != len(stackwise_num_filters) - 1:
                # No upsampling in the last block.
                x = layers.UpSampling2D(
                    2,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}",
                )(x)
                x = layers.Conv2D(
                    filters,
                    3,
                    1,
                    padding="same",
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}_conv",
                )(x)

        # Output block.
        x = layers.GroupNormalization(
            groups=32,
            axis=gn_axis,
            epsilon=1e-6,
            dtype=dtype,
            name="output_norm",
        )(x)
        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
        image_outputs = layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="output_projection",
        )(x)
        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)

        # === Config ===
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = output_channels
        self.latent_shape = latent_shape

        if dtype is not None:
            try:
                self.dtype_policy = keras.dtype_policies.get(dtype)
            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
            except AttributeError:
                if isinstance(dtype, keras.DTypePolicy):
                    dtype = dtype.name
                self.dtype_policy = keras.DTypePolicy(dtype)

    def get_config(self):
        """Return the base config extended with this model's own arguments.

        Each key matches the name of the corresponding `__init__` parameter
        so the returned dict can be passed back to the constructor.
        """
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
                # Must be `latent_shape` (the constructor argument name);
                # any other key would be forwarded to `keras.Model` through
                # `**kwargs` when the config is used to rebuild the model.
                "latent_shape": self.latent_shape,
            }
        )
        return config
|
142
|
+
|
143
|
+
|
144
|
+
def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
    """Apply a pre-activation residual block to `x` and return the result.

    The main branch is (group norm -> swish -> 3x3 conv) applied twice. The
    input is added back to the main branch output; when the number of input
    channels differs from `filters`, the input first goes through a 1x1
    convolution so both branches have `filters` channels.

    Args:
        x: input tensor. Its channel count is read from `x.shape` on the
            axis selected by `data_format`.
        filters: int. Number of filters of both 3x3 convolutions (and of the
            residual projection, when one is needed).
        data_format: `None` or str. Normalized with
            `standardize_data_format`; `"channels_last"` puts channels on
            axis `-1`, anything else on axis `1`.
        dtype: dtype or dtype policy forwarded to every layer of the block.
        name: str. Prefix used for the names of the normalization and
            convolution layers.

    Returns:
        The output tensor of the block.
    """
    data_format = standardize_data_format(data_format)
    gn_axis = -1 if data_format == "channels_last" else 1
    input_filters = x.shape[gn_axis]

    residual = x
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1"
    )(x)
    x = layers.Activation("swish", dtype=dtype)(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv1",
    )(x)
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2"
    )(x)
    # Pass `dtype` here as well, so this activation follows the same dtype
    # policy as every other layer in the block instead of the global one.
    x = layers.Activation("swish", dtype=dtype)(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv2",
    )(x)
    if input_filters != filters:
        # Channel counts differ: project the shortcut so it can be added.
        residual = layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_residual_projection",
        )(residual)
    x = layers.Add(dtype=dtype)([residual, x])
    return x
|
@@ -14,7 +14,6 @@
|
|
14
14
|
|
15
15
|
from keras_hub.src.models.t5.t5_backbone import T5Backbone
|
16
16
|
from keras_hub.src.models.t5.t5_presets import backbone_presets
|
17
|
-
from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
|
18
17
|
from keras_hub.src.utils.preset_utils import register_presets
|
19
18
|
|
20
|
-
register_presets(backbone_presets,
|
19
|
+
register_presets(backbone_presets, T5Backbone)
|
@@ -13,12 +13,18 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
from keras_hub.src.api_export import keras_hub_export
|
16
|
+
from keras_hub.src.models.t5.t5_backbone import T5Backbone
|
16
17
|
from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
|
17
18
|
SentencePieceTokenizer,
|
18
19
|
)
|
19
20
|
|
20
21
|
|
21
|
-
@keras_hub_export(
|
22
|
+
@keras_hub_export(
|
23
|
+
[
|
24
|
+
"keras_hub.tokenizers.T5Tokenizer",
|
25
|
+
"keras_hub.models.T5Tokenizer",
|
26
|
+
]
|
27
|
+
)
|
22
28
|
class T5Tokenizer(SentencePieceTokenizer):
|
23
29
|
"""T5 tokenizer layer based on SentencePiece.
|
24
30
|
|
@@ -74,27 +80,11 @@ class T5Tokenizer(SentencePieceTokenizer):
|
|
74
80
|
```
|
75
81
|
"""
|
76
82
|
|
77
|
-
|
78
|
-
self.end_token = "</s>"
|
79
|
-
self.pad_token = "<pad>"
|
83
|
+
backbone_cls = T5Backbone
|
80
84
|
|
85
|
+
def __init__(self, proto, **kwargs):
|
86
|
+
# T5 uses the same start token as end token, i.e., "</s>".
|
87
|
+
self._add_special_token("</s>", "end_token")
|
88
|
+
self._add_special_token("</s>", "start_token")
|
89
|
+
self._add_special_token("<pad>", "pad_token")
|
81
90
|
super().__init__(proto=proto, **kwargs)
|
82
|
-
|
83
|
-
def set_proto(self, proto):
|
84
|
-
super().set_proto(proto)
|
85
|
-
if proto is not None:
|
86
|
-
for token in [self.end_token, self.pad_token]:
|
87
|
-
if token not in self.get_vocabulary():
|
88
|
-
raise ValueError(
|
89
|
-
f"Cannot find token `'{token}'` in the provided "
|
90
|
-
f"`vocabulary`. Please provide `'{token}'` in your "
|
91
|
-
"`vocabulary` or use a pretrained `vocabulary` name."
|
92
|
-
)
|
93
|
-
self.end_token_id = self.token_to_id(self.end_token)
|
94
|
-
self.pad_token_id = self.token_to_id(self.pad_token)
|
95
|
-
# T5 uses the same start token as end token, i.e., "<\s>".
|
96
|
-
self.start_token_id = self.end_token_id
|
97
|
-
else:
|
98
|
-
self.end_token_id = None
|
99
|
-
self.pad_token_id = None
|
100
|
-
self.start_token_id = None
|
keras_hub/src/models/task.py
CHANGED
@@ -22,18 +22,11 @@ from rich import table as rich_table
|
|
22
22
|
from keras_hub.src.api_export import keras_hub_export
|
23
23
|
from keras_hub.src.utils.keras_utils import print_msg
|
24
24
|
from keras_hub.src.utils.pipeline_model import PipelineModel
|
25
|
-
from keras_hub.src.utils.preset_utils import CONFIG_FILE
|
26
|
-
from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
|
27
25
|
from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE
|
28
26
|
from keras_hub.src.utils.preset_utils import TASK_WEIGHTS_FILE
|
29
|
-
from keras_hub.src.utils.preset_utils import
|
30
|
-
from keras_hub.src.utils.preset_utils import
|
31
|
-
from keras_hub.src.utils.preset_utils import
|
32
|
-
from keras_hub.src.utils.preset_utils import get_file
|
33
|
-
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
|
34
|
-
from keras_hub.src.utils.preset_utils import list_presets
|
35
|
-
from keras_hub.src.utils.preset_utils import list_subclasses
|
36
|
-
from keras_hub.src.utils.preset_utils import load_serialized_object
|
27
|
+
from keras_hub.src.utils.preset_utils import builtin_presets
|
28
|
+
from keras_hub.src.utils.preset_utils import find_subclass
|
29
|
+
from keras_hub.src.utils.preset_utils import get_preset_loader
|
37
30
|
from keras_hub.src.utils.preset_utils import save_serialized_object
|
38
31
|
from keras_hub.src.utils.python_utils import classproperty
|
39
32
|
|
@@ -56,12 +49,17 @@ class Task(PipelineModel):
|
|
56
49
|
to load a pre-trained config and weights. Calling `from_preset()` on a task
|
57
50
|
will automatically instantiate a `keras_hub.models.Backbone` and
|
58
51
|
`keras_hub.models.Preprocessor`.
|
52
|
+
|
53
|
+
Args:
|
54
|
+
compile: boolean, defaults to `True`. If `True` will compile the model
|
55
|
+
with default parameters on construction. Model can still be
|
56
|
+
recompiled with a new loss, optimizer and metrics before training.
|
59
57
|
"""
|
60
58
|
|
61
59
|
backbone_cls = None
|
62
60
|
preprocessor_cls = None
|
63
61
|
|
64
|
-
def __init__(self, *args, **kwargs):
|
62
|
+
def __init__(self, *args, compile=True, **kwargs):
|
65
63
|
super().__init__(*args, **kwargs)
|
66
64
|
self._functional_layer_ids = set(
|
67
65
|
id(layer) for layer in self._flatten_layers()
|
@@ -69,6 +67,9 @@ class Task(PipelineModel):
|
|
69
67
|
self._initialized = True
|
70
68
|
if self.backbone is not None:
|
71
69
|
self.dtype_policy = self._backbone.dtype_policy
|
70
|
+
if compile:
|
71
|
+
# Default compilation.
|
72
|
+
self.compile()
|
72
73
|
|
73
74
|
def preprocess_samples(self, x, y=None, sample_weight=None):
|
74
75
|
if self.preprocessor is not None:
|
@@ -131,13 +132,7 @@ class Task(PipelineModel):
|
|
131
132
|
@classproperty
|
132
133
|
def presets(cls):
|
133
134
|
"""List built-in presets for a `Task` subclass."""
|
134
|
-
|
135
|
-
# We can also load backbone presets.
|
136
|
-
if cls.backbone_cls is not None:
|
137
|
-
presets.update(cls.backbone_cls.presets)
|
138
|
-
for subclass in list_subclasses(cls):
|
139
|
-
presets.update(subclass.presets)
|
140
|
-
return presets
|
135
|
+
return builtin_presets(cls)
|
141
136
|
|
142
137
|
@classmethod
|
143
138
|
def from_preset(
|
@@ -149,10 +144,10 @@ class Task(PipelineModel):
|
|
149
144
|
"""Instantiate a `keras_hub.models.Task` from a model preset.
|
150
145
|
|
151
146
|
A preset is a directory of configs, weights and other file assets used
|
152
|
-
to save and load a pre-trained model. The `preset` can be passed as
|
147
|
+
to save and load a pre-trained model. The `preset` can be passed as
|
153
148
|
one of:
|
154
149
|
|
155
|
-
1. a built
|
150
|
+
1. a built-in preset identifier like `'bert_base_en'`
|
156
151
|
2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
|
157
152
|
3. a Hugging Face handle like `'hf://user/bert_base_en'`
|
158
153
|
4. a path to a local preset directory like `'./bert_base_en'`
|
@@ -162,16 +157,16 @@ class Task(PipelineModel):
|
|
162
157
|
|
163
158
|
This constructor can be called in one of two ways. Either from a task
|
164
159
|
specific base class like `keras_hub.models.CausalLM.from_preset()`, or
|
165
|
-
from a model class like `keras_hub.models.
|
160
|
+
from a model class like `keras_hub.models.BertTextClassifier.from_preset()`.
|
166
161
|
If calling from a base class, the subclass of the returning object
|
167
162
|
will be inferred from the config in the preset directory.
|
168
163
|
|
169
164
|
Args:
|
170
|
-
preset: string. A built
|
165
|
+
preset: string. A built-in preset identifier, a Kaggle Models
|
171
166
|
handle, a Hugging Face handle, or a path to a local directory.
|
172
|
-
load_weights: bool. If `True`,
|
173
|
-
model architecture. If `False`,
|
174
|
-
initialized.
|
167
|
+
load_weights: bool. If `True`, saved weights will be loaded into
|
168
|
+
the model architecture. If `False`, all weights will be
|
169
|
+
randomly initialized.
|
175
170
|
|
176
171
|
Examples:
|
177
172
|
```python
|
@@ -181,100 +176,37 @@ class Task(PipelineModel):
|
|
181
176
|
)
|
182
177
|
|
183
178
|
# Load a Bert classification task.
|
184
|
-
model = keras_hub.models.
|
179
|
+
model = keras_hub.models.TextClassifier.from_preset(
|
185
180
|
"bert_base_en",
|
186
181
|
num_classes=2,
|
187
182
|
)
|
188
183
|
```
|
189
184
|
"""
|
190
|
-
format = check_format(preset)
|
191
|
-
|
192
|
-
if format == "transformers":
|
193
|
-
if cls.backbone_cls is None:
|
194
|
-
raise ValueError("Backbone class is None")
|
195
|
-
if cls.preprocessor_cls is None:
|
196
|
-
raise ValueError("Preprocessor class is None")
|
197
|
-
|
198
|
-
backbone = cls.backbone_cls.from_preset(preset)
|
199
|
-
preprocessor = cls.preprocessor_cls.from_preset(preset)
|
200
|
-
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
|
201
|
-
|
202
185
|
if cls == Task:
|
203
186
|
raise ValueError(
|
204
187
|
"Do not call `Task.from_preset()` directly. Instead call a "
|
205
188
|
"particular task class, e.g. "
|
206
|
-
"`keras_hub.models.
|
207
|
-
"`keras_hub.models.BertClassifier.from_preset()`."
|
208
|
-
)
|
209
|
-
if "backbone" in kwargs:
|
210
|
-
raise ValueError(
|
211
|
-
"You cannot pass a `backbone` argument to the `from_preset` "
|
212
|
-
f"method. Instead, call the {cls.__name__} default "
|
213
|
-
"constructor with a `backbone` argument. "
|
214
|
-
f"Received: backbone={kwargs['backbone']}."
|
189
|
+
"`keras_hub.models.TextClassifier.from_preset()`."
|
215
190
|
)
|
216
191
|
|
217
|
-
|
218
|
-
|
219
|
-
if
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
jax_memory_cleanup(task)
|
229
|
-
if check_file_exists(preset, TASK_WEIGHTS_FILE):
|
230
|
-
task.load_task_weights(get_file(preset, TASK_WEIGHTS_FILE))
|
231
|
-
task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
|
232
|
-
task.preprocessor.tokenizer.load_preset_assets(preset)
|
233
|
-
return task
|
234
|
-
|
235
|
-
# Backbone case.
|
236
|
-
# If `task.json` doesn't exist or the task preset class is different
|
237
|
-
# from the calling class, create the task based on `config.json`.
|
238
|
-
backbone_preset_cls = check_config_class(preset, CONFIG_FILE)
|
239
|
-
if backbone_preset_cls is not cls.backbone_cls:
|
240
|
-
subclasses = list_subclasses(cls)
|
241
|
-
subclasses = tuple(
|
242
|
-
filter(
|
243
|
-
lambda x: x.backbone_cls == backbone_preset_cls,
|
244
|
-
subclasses,
|
245
|
-
)
|
246
|
-
)
|
247
|
-
if len(subclasses) == 0:
|
248
|
-
raise ValueError(
|
249
|
-
f"No registered subclass of `{cls.__name__}` can load "
|
250
|
-
f"a `{backbone_preset_cls.__name__}`."
|
251
|
-
)
|
252
|
-
if len(subclasses) > 1:
|
253
|
-
names = ", ".join(f"`{x.__name__}`" for x in subclasses)
|
254
|
-
raise ValueError(
|
255
|
-
f"Ambiguous call to `{cls.__name__}.from_preset()`. "
|
256
|
-
f"Found multiple possible subclasses {names}. "
|
257
|
-
"Please call `from_preset` on a subclass directly."
|
258
|
-
)
|
259
|
-
cls = subclasses[0]
|
260
|
-
# Forward dtype to the backbone.
|
261
|
-
backbone_kwargs = {}
|
262
|
-
if "dtype" in kwargs:
|
263
|
-
backbone_kwargs = {"dtype": kwargs.pop("dtype")}
|
264
|
-
backbone = backbone_preset_cls.from_preset(
|
265
|
-
preset, load_weights=load_weights, **backbone_kwargs
|
266
|
-
)
|
267
|
-
if "preprocessor" in kwargs:
|
268
|
-
preprocessor = kwargs.pop("preprocessor")
|
269
|
-
else:
|
270
|
-
preprocessor = cls.preprocessor_cls.from_preset(preset)
|
271
|
-
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
|
192
|
+
loader = get_preset_loader(preset)
|
193
|
+
backbone_cls = loader.check_backbone_class()
|
194
|
+
# Detect the correct subclass if we need to.
|
195
|
+
if cls.backbone_cls != backbone_cls:
|
196
|
+
cls = find_subclass(preset, cls, backbone_cls)
|
197
|
+
# Specifically for classifiers, we never load task weights if
|
198
|
+
# num_classes is supplied. We handle this in the task base class because
|
199
|
+
# it is the same logic for classifiers regardless of modality (text,
|
200
|
+
# images, audio).
|
201
|
+
load_task_weights = "num_classes" not in kwargs
|
202
|
+
return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
|
272
203
|
|
273
204
|
def load_task_weights(self, filepath):
|
274
205
|
"""Load only the tasks specific weights not in the backbone."""
|
275
206
|
if not str(filepath).endswith(".weights.h5"):
|
276
207
|
raise ValueError(
|
277
|
-
"The filename must end in `.weights.h5`.
|
208
|
+
"The filename must end in `.weights.h5`. "
|
209
|
+
f"Received: filepath={filepath}"
|
278
210
|
)
|
279
211
|
backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
|
280
212
|
keras.saving.load_weights(
|
@@ -361,7 +293,9 @@ class Task(PipelineModel):
|
|
361
293
|
print_fn = print_msg
|
362
294
|
|
363
295
|
def highlight_number(x):
    """Wrap `x` in rich color markup.

    `None` is rendered as-is in color 45. Any other value is rendered in
    color 34 using the `,` format spec (thousands separators).
    """
    if x is None:
        # `None` does not support the `,` format spec, so it must be
        # returned here rather than falling through to the numeric branch.
        return f"[color(45)]{x}[/]"
    return f"[color(34)]{x:,}[/]"  # Format number with commas.
|
365
299
|
|
366
300
|
def highlight_symbol(x):
|
367
301
|
return f"[color(33)]{x}[/]"
|
@@ -369,6 +303,10 @@ class Task(PipelineModel):
|
|
369
303
|
def bold_text(x):
|
370
304
|
return f"[bold]{x}[/]"
|
371
305
|
|
306
|
+
def highlight_shape(shape):
    """Render `shape` as "(d0, d1, ...)" with each entry highlighted."""
    return "(" + ", ".join(map(highlight_number, shape)) + ")"
|
309
|
+
|
372
310
|
if self.preprocessor:
|
373
311
|
# Create a rich console for printing. Capture for non-interactive logging.
|
374
312
|
if print_fn:
|
@@ -380,27 +318,44 @@ class Task(PipelineModel):
|
|
380
318
|
console = rich_console.Console(highlight=False)
|
381
319
|
|
382
320
|
column_1 = rich_table.Column(
|
383
|
-
"
|
321
|
+
"Layer (type)",
|
384
322
|
justify="left",
|
385
|
-
width=int(0.
|
323
|
+
width=int(0.6 * line_length),
|
386
324
|
)
|
387
325
|
column_2 = rich_table.Column(
|
388
|
-
"
|
326
|
+
"Config",
|
389
327
|
justify="right",
|
390
|
-
width=int(0.
|
328
|
+
width=int(0.4 * line_length),
|
391
329
|
)
|
392
330
|
table = rich_table.Table(
|
393
331
|
column_1, column_2, width=line_length, show_lines=True
|
394
332
|
)
|
333
|
+
|
334
|
+
def add_layer(layer, info):
    """Append a `name (ClassName)` / `info` row to the summary table.

    Both the layer name and its class name are markup-escaped before being
    placed in the first column; `info` is passed through unchanged.
    """
    escaped_name = markup.escape(layer.name)
    escaped_class = markup.escape(layer.__class__.__name__)
    label = f"{escaped_name} ({highlight_symbol(escaped_class)})"
    table.add_row(label, info)
|
343
|
+
|
395
344
|
tokenizer = self.preprocessor.tokenizer
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
345
|
+
if tokenizer:
|
346
|
+
info = "Vocab size: "
|
347
|
+
info += highlight_number(tokenizer.vocabulary_size())
|
348
|
+
add_layer(tokenizer, info)
|
349
|
+
image_converter = self.preprocessor.image_converter
|
350
|
+
if image_converter:
|
351
|
+
info = "Image size: "
|
352
|
+
info += highlight_shape(image_converter.image_size())
|
353
|
+
add_layer(image_converter, info)
|
354
|
+
audio_converter = self.preprocessor.audio_converter
|
355
|
+
if audio_converter:
|
356
|
+
info = "Audio shape: "
|
357
|
+
info += highlight_shape(audio_converter.audio_shape())
|
358
|
+
add_layer(audio_converter, info)
|
404
359
|
|
405
360
|
# Print the table to the console.
|
406
361
|
preprocessor_name = markup.escape(self.preprocessor.name)
|
@@ -17,25 +17,36 @@ from keras_hub.src.api_export import keras_hub_export
|
|
17
17
|
from keras_hub.src.models.task import Task
|
18
18
|
|
19
19
|
|
20
|
-
@keras_hub_export(
|
21
|
-
|
20
|
+
@keras_hub_export(
|
21
|
+
[
|
22
|
+
"keras_hub.models.TextClassifier",
|
23
|
+
"keras_hub.models.Classifier",
|
24
|
+
]
|
25
|
+
)
|
26
|
+
class TextClassifier(Task):
|
22
27
|
"""Base class for all classification tasks.
|
23
28
|
|
24
|
-
`
|
29
|
+
`TextClassifier` tasks wrap a `keras_hub.models.Backbone` and
|
25
30
|
a `keras_hub.models.Preprocessor` to create a model that can be used for
|
26
|
-
sequence classification. `
|
31
|
+
sequence classification. `TextClassifier` tasks take an additional
|
27
32
|
`num_classes` argument, controlling the number of predicted output classes.
|
28
33
|
|
29
34
|
To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
|
30
35
|
labels where `x` is a string and `y` is an integer from `[0, num_classes)`.
|
31
36
|
|
32
|
-
All `
|
37
|
+
All `TextClassifier` tasks include a `from_preset()` constructor which can be
|
33
38
|
used to load a pre-trained config and weights.
|
34
39
|
|
40
|
+
Some, but not all, classification presets include classification head
|
41
|
+
weights in a `task.weights.h5` file. For these presets, you can omit passing
|
42
|
+
`num_classes` to restore the saved classification head. For all presets, if
|
43
|
+
`num_classes` is passed as a kwarg to `from_preset()`, the classification
|
44
|
+
head will be randomly initialized.
|
45
|
+
|
35
46
|
Example:
|
36
47
|
```python
|
37
48
|
# Load a BERT classifier with pre-trained weights.
|
38
|
-
classifier = keras_hub.models.
|
49
|
+
classifier = keras_hub.models.TextClassifier.from_preset(
|
39
50
|
"bert_base_en",
|
40
51
|
num_classes=2,
|
41
52
|
)
|
@@ -52,11 +63,6 @@ class Classifier(Task):
|
|
52
63
|
```
|
53
64
|
"""
|
54
65
|
|
55
|
-
def __init__(self, *args, **kwargs):
|
56
|
-
super().__init__(*args, **kwargs)
|
57
|
-
# Default compilation.
|
58
|
-
self.compile()
|
59
|
-
|
60
66
|
def compile(
|
61
67
|
self,
|
62
68
|
optimizer="auto",
|
@@ -65,9 +71,9 @@ class Classifier(Task):
|
|
65
71
|
metrics="auto",
|
66
72
|
**kwargs,
|
67
73
|
):
|
68
|
-
"""Configures the `
|
74
|
+
"""Configures the `TextClassifier` task for training.
|
69
75
|
|
70
|
-
The `
|
76
|
+
The `TextClassifier` task extends the default compilation signature of
|
71
77
|
`keras.Model.compile` with defaults for `optimizer`, `loss`, and
|
72
78
|
`metrics`. To override these defaults, pass any value
|
73
79
|
to these arguments during compilation.
|