keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +2 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/api/utils/__init__.py +22 -0
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +46 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +230 -68
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_v3/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py
@@ -0,0 +1,103 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras import layers
+from keras import ops
+
+
+def quick_gelu(x):
+    return x * ops.sigmoid(1.702 * x)
+
+
+class CLIPEncoderBlock(layers.Layer):
+    def __init__(
+        self,
+        hidden_dim,
+        num_heads,
+        intermediate_dim,
+        intermediate_activation="quick_gelu",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if hidden_dim % num_heads != 0:
+            raise ValueError(
+                "`hidden_dim` must be divisible by `num_heads`. "
+                f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
+            )
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.intermediate_dim = intermediate_dim
+        self.intermediate_activation = intermediate_activation
+
+        if intermediate_activation == "quick_gelu":
+            intermediate_activation = quick_gelu
+
+        self.layer_norm_1 = layers.LayerNormalization(
+            epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_1"
+        )
+        self.attention = layers.MultiHeadAttention(
+            num_heads,
+            hidden_dim // num_heads,
+            dtype=self.dtype_policy,
+            name="attention",
+        )
+        self.layer_norm_2 = layers.LayerNormalization(
+            epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_2"
+        )
+        self.dense_1 = layers.Dense(
+            self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
+        )
+        self.activation = layers.Activation(
+            intermediate_activation, dtype=self.dtype_policy, name="activation"
+        )
+        self.dense_2 = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="dense_2"
+        )
+
+    def build(self, input_shape):
+        self.layer_norm_1.build(input_shape)
+        self.attention.build(input_shape, input_shape, input_shape)
+        self.layer_norm_2.build(input_shape)
+        self.dense_1.build(input_shape)
+        input_shape = self.dense_1.compute_output_shape(input_shape)
+        self.dense_2.build(input_shape)
+
+    def compute_output_shape(self, inputs_shape):
+        outputs_shape = list(inputs_shape)
+        outputs_shape[-1] = self.hidden_dim
+        return outputs_shape
+
+    def call(self, x, training=None):
+        residual = x
+        x = self.layer_norm_1(x)
+        x = self.attention(x, x, x, training=training, use_causal_mask=True)
+        x = ops.add(residual, x)
+
+        residual = x
+        x = self.dense_1(self.layer_norm_2(residual))
+        x = self.activation(x)
+        x = self.dense_2(x)
+        x = ops.add(residual, x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_heads": self.num_heads,
+                "intermediate_dim": self.intermediate_dim,
+                "intermediate_activation": self.intermediate_activation,
+            }
+        )
+        return config
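For orientation, here is a minimal smoke test of the `CLIPEncoderBlock` added above. The layer sizes are made up for illustration and are not a Stable Diffusion 3 preset configuration.

```python
import numpy as np

from keras_hub.src.models.stable_diffusion_v3.clip_encoder_block import (
    CLIPEncoderBlock,
)

# Toy sizes: `hidden_dim` must be divisible by `num_heads`, and the input's
# last dimension must equal `hidden_dim` for the residual connections.
block = CLIPEncoderBlock(hidden_dim=64, num_heads=4, intermediate_dim=256)
x = np.zeros((2, 77, 64), dtype="float32")  # (batch, sequence, hidden_dim)
y = block(x)
print(y.shape)  # (2, 77, 64) -- the block preserves the sequence shape
```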
keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py
@@ -0,0 +1,93 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.models.stable_diffusion_v3.clip_tokenizer import (
+    CLIPTokenizer,
+)
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+
+class CLIPPreprocessor(Preprocessor):
+    tokenizer_cls = CLIPTokenizer
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length=77,
+        add_start_token=True,
+        add_end_token=False,
+        to_lower=True,
+        pad_with_end_token=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.tokenizer = tokenizer
+        self.sequence_length = sequence_length
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token
+        self.to_lower = to_lower
+        self.pad_with_end_token = pad_with_end_token
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
+        pad_value = self.tokenizer.pad_token_id
+        if self.pad_with_end_token:
+            pad_value = self.tokenizer.end_token_id
+
+        self.packer = StartEndPacker(
+            start_value=self.tokenizer.start_token_id,
+            end_value=self.tokenizer.end_token_id,
+            pad_value=pad_value,
+            sequence_length=self.sequence_length,
+            return_padding_mask=True,
+        )
+        self.built = True
+
+    @preprocessing_function
+    def call(self, x, y=None, sample_weight=None, sequence_length=None):
+        if self.to_lower:
+            x = tf.strings.lower(x)
+        token_ids, padding_mask = self.packer(
+            self.tokenizer(x),
+            sequence_length=sequence_length or self.sequence_length,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
+        )
+        x = {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "sequence_length": self.sequence_length,
+                "add_start_token": self.add_start_token,
+                "add_end_token": self.add_end_token,
+                "to_lower": self.to_lower,
+                "pad_with_end_token": self.pad_with_end_token,
+            }
+        )
+        return config
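A rough sketch of how the new preprocessor and tokenizer fit together. The toy vocabulary and merge rule below are invented purely for illustration; real usage would load the pretrained CLIP vocabulary and merges.

```python
from keras_hub.src.models.stable_diffusion_v3.clip_preprocessor import (
    CLIPPreprocessor,
)
from keras_hub.src.models.stable_diffusion_v3.clip_tokenizer import (
    CLIPTokenizer,
)

# Hypothetical toy assets: single-character words ending in "</w>", plus the
# special tokens the tokenizer checks for.
vocabulary = {"<|startoftext|>": 0, "<|endoftext|>": 1, "a</w>": 2, "b</w>": 3}
merges = ["a b"]  # placeholder merge rule, unused for single-character words

tokenizer = CLIPTokenizer(vocabulary=vocabulary, merges=merges)
preprocessor = CLIPPreprocessor(tokenizer, sequence_length=8)

x = preprocessor(["a b"])
print(x["token_ids"])     # start token, "a</w>", "b</w>", then end-token padding
print(x["padding_mask"])
```

Note that `pad_with_end_token=True` by default, so the packer pads with the `<|endoftext|>` id rather than with `pad_token_id`.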
keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py
@@ -0,0 +1,149 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+from keras import layers
+from keras import ops
+
+from keras_hub.src.layers.modeling.token_and_position_embedding import (
+    TokenAndPositionEmbedding,
+)
+from keras_hub.src.models.stable_diffusion_v3.clip_encoder_block import (
+    CLIPEncoderBlock,
+)
+
+
+class CLIPTextEncoder(keras.Model):
+    def __init__(
+        self,
+        embedding_dim,
+        hidden_dim,
+        num_layers,
+        num_heads,
+        intermediate_dim,
+        intermediate_activation="quick_gelu",
+        intermediate_output_index=None,
+        vocabulary_size=49408,
+        sequence_length=77,
+        dtype=None,
+        **kwargs,
+    ):
+        if (
+            intermediate_output_index is not None
+            and intermediate_output_index < 0
+        ):
+            intermediate_output_index += num_layers
+
+        # === Layers ===
+        self.embedding = TokenAndPositionEmbedding(
+            vocabulary_size=vocabulary_size,
+            sequence_length=sequence_length,
+            embedding_dim=embedding_dim,
+            dtype=dtype,
+            name="embedding",
+        )
+        self.encoder_layers = [
+            CLIPEncoderBlock(
+                hidden_dim,
+                num_heads,
+                intermediate_dim,
+                intermediate_activation,
+                dtype=dtype,
+            )
+            for _ in range(num_layers)
+        ]
+        self.layer_norm = layers.LayerNormalization(
+            epsilon=0.00001, dtype=dtype, name="layer_norm"
+        )
+        self.text_projection = layers.Dense(
+            hidden_dim,
+            use_bias=False,
+            dtype=dtype,
+            name="text_projection",
+        )
+
+        # === Functional Model ===
+        encoder_token_ids = layers.Input(
+            shape=(sequence_length,), dtype="int32", name="encoder_token_ids"
+        )
+        x = self.embedding(encoder_token_ids)
+        encoder_intermediate_output = None
+        # Encoder.
+        for i, block in enumerate(self.encoder_layers):
+            x = block(x)
+            if i == intermediate_output_index:
+                encoder_intermediate_output = x
+        x = self.layer_norm(x)
+        encoder_output = x
+        if encoder_intermediate_output is not None:
+            encoder_intermediate_output = self.layer_norm(
+                encoder_intermediate_output
+            )
+        # Projection.
+        indices = ops.expand_dims(
+            ops.cast(ops.argmax(encoder_token_ids, axis=-1), "int32"), axis=-1
+        )
+        pooled_output = ops.take_along_axis(x, indices[:, :, None], axis=1)
+        pooled_output = ops.squeeze(pooled_output, axis=1)
+        projection_output = self.text_projection(pooled_output)
+
+        outputs = {
+            "encoder_sequence_output": encoder_output,
+            "encoder_pooled_output": pooled_output,
+            "encoder_projection_output": projection_output,
+        }
+        if intermediate_output_index is not None:
+            outputs["encoder_intermediate_output"] = encoder_intermediate_output
+
+        super().__init__(
+            inputs={"encoder_token_ids": encoder_token_ids},
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.embedding_dim = embedding_dim
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.intermediate_dim = intermediate_dim
+        self.intermediate_activation = intermediate_activation
+        self.intermediate_output_index = intermediate_output_index
+        self.vocabulary_size = vocabulary_size
+        self.sequence_length = sequence_length
+
+        if dtype is not None:
+            try:
+                self.dtype_policy = keras.dtype_policies.get(dtype)
+            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
+            except AttributeError:
+                if isinstance(dtype, keras.DTypePolicy):
+                    dtype = dtype.name
+                self.dtype_policy = keras.DTypePolicy(dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "embedding_dim": self.embedding_dim,
+                "hidden_dim": self.hidden_dim,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "intermediate_dim": self.intermediate_dim,
+                "intermediate_activation": self.intermediate_activation,
+                "intermediate_output_index": self.intermediate_output_index,
+                "vocabulary_size": self.vocabulary_size,
+                "sequence_length": self.sequence_length,
+            }
+        )
+        return config
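A minimal shape check of the `CLIPTextEncoder` functional model; the tiny dimensions are arbitrary and chosen only to exercise the three named outputs.

```python
import numpy as np

from keras_hub.src.models.stable_diffusion_v3.clip_text_encoder import (
    CLIPTextEncoder,
)

encoder = CLIPTextEncoder(
    embedding_dim=64,   # kept equal to `hidden_dim` so the residuals line up
    hidden_dim=64,
    num_layers=2,
    num_heads=4,
    intermediate_dim=256,
    vocabulary_size=100,
    sequence_length=8,
)
token_ids = np.zeros((1, 8), dtype="int32")
outputs = encoder({"encoder_token_ids": token_ids})
print(outputs["encoder_sequence_output"].shape)    # (1, 8, 64)
print(outputs["encoder_pooled_output"].shape)      # (1, 64)
print(outputs["encoder_projection_output"].shape)  # (1, 64)
```

Passing `intermediate_output_index` adds a fourth `"encoder_intermediate_output"` entry to this output dict.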
keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py
@@ -0,0 +1,167 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+from keras_hub.src.tokenizers.byte_pair_tokenizer import convert_to_ragged_batch
+from keras_hub.src.tokenizers.byte_pair_tokenizer import split_strings_for_bpe
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+
+class CLIPTokenizer(BytePairTokenizer):
+    def __init__(self, vocabulary=None, merges=None, **kwargs):
+        self.start_token = "<|startoftext|>"
+        self.end_token = "<|endoftext|>"
+
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            unsplittable_tokens=[self.start_token, self.end_token],
+            **kwargs,
+        )
+
+    def set_vocabulary_and_merges(self, vocabulary, merges):
+        super().set_vocabulary_and_merges(vocabulary, merges)
+
+        if vocabulary is not None:
+            # Check for necessary special tokens.
+            if self.end_token not in self.get_vocabulary():
+                raise ValueError(
+                    f"Cannot find token `'{self.end_token}'` in the provided "
+                    f"`vocabulary`. Please provide `'{self.end_token}'` in "
+                    "your `vocabulary` or use a pretrained `vocabulary` name."
+                )
+
+            self.start_token_id = self.token_to_id(self.start_token)
+            self.end_token_id = self.token_to_id(self.end_token)
+            self.pad_token_id = 0
+        else:
+            self.end_token_id = None
+            self.start_token_id = None
+            self.pad_token_id = None
+
+    def _bpe_merge_and_update_cache(self, tokens):
+        """Process unseen tokens and add to cache."""
+        words = self._transform_bytes(tokens)
+
+        # In StableDiffusionV3, we need to add `</w>` to the last word.
+        words = tf.strings.reduce_join(words, axis=1, separator=" ")
+        words = tf.strings.join([words, "</w>"])
+        words = tf.strings.split(words, sep=" ")
+
+        tokenized_words = self._bpe_merge(words)
+
+        # For each word, join all its token by a whitespace,
+        # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
+        tokenized_words = tf.strings.reduce_join(
+            tokenized_words, axis=1, separator=" "
+        )
+        self.cache.insert(tokens, tokenized_words)
+
+    def tokenize(self, inputs):
+        self._check_vocabulary()
+        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+            inputs = tf.convert_to_tensor(inputs)
+
+        if self.add_prefix_space:
+            inputs = tf.strings.join([" ", inputs])
+
+        scalar_input = inputs.shape.rank == 0
+        if scalar_input:
+            inputs = tf.expand_dims(inputs, 0)
+
+        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
+
+        # Strip and remove empty tokens.
+        raw_tokens = tf.strings.strip(raw_tokens)
+        raw_tokens = tf.ragged.boolean_mask(raw_tokens, raw_tokens != "")
+
+        token_row_splits = raw_tokens.row_splits
+        flat_tokens = raw_tokens.flat_values
+
+        # Check cache.
+        cache_lookup = self.cache.lookup(flat_tokens)
+        cache_mask = cache_lookup == ""
+
+        has_unseen_words = tf.math.reduce_any(
+            (cache_lookup == "") & (flat_tokens != "")
+        )
+
+        def process_unseen_tokens():
+            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
+            self._bpe_merge_and_update_cache(unseen_tokens)
+            return self.cache.lookup(flat_tokens)
+
+        # If `has_unseen_words == True`, it means not all tokens are in cache,
+        # we will process the unseen tokens. Otherwise return the cache lookup.
+        tokenized_words = tf.cond(
+            has_unseen_words,
+            process_unseen_tokens,
+            lambda: cache_lookup,
+        )
+
+        tokens = tf.strings.split(tokenized_words, sep=" ")
+        if self.compute_dtype != tf.string:
+            # Encode merged tokens.
+            tokens = self.token_to_id_map.lookup(tokens)
+
+        # Unflatten to match input.
+        tokens = tf.RaggedTensor.from_row_splits(
+            tokens.flat_values,
+            tf.gather(tokens.row_splits, token_row_splits),
+        )
+
+        # Convert to a dense output if `sequence_length` is set.
+        if self.sequence_length:
+            output_shape = tokens.shape.as_list()
+            output_shape[-1] = self.sequence_length
+            tokens = tokens.to_tensor(shape=output_shape)
+
+        # Convert to a dense output if input in scalar
+        if scalar_input:
+            tokens = tf.squeeze(tokens, 0)
+            tf.ensure_shape(tokens, shape=[self.sequence_length])
+
+        return tokens
+
+    def detokenize(self, inputs):
+        self._check_vocabulary()
+        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
+        inputs = tf.cast(inputs, self.dtype)
+        unicode_text = tf.strings.reduce_join(
+            self.id_to_token_map.lookup(inputs), axis=-1
+        )
+
+        # When detokenizing, we need to remove </w> and extra whitespace.
+        unicode_text = tf.strings.regex_replace(unicode_text, r"</w>", " ")
+        unicode_text = tf.strings.strip(unicode_text)
+
+        split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8")
+        outputs = tf.strings.reduce_join(
+            self.unicode2byte.lookup(split_unicode_text), axis=-1
+        )
+
+        if unbatched:
+            outputs = tf.squeeze(outputs, 0)
+        return outputs
+
+    def get_config(self):
+        config = super().get_config()
+        # In the constructor, we pass the list of special tokens to the
+        # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
+        # delete it from the config here.
+        del config["unsplittable_tokens"]
+        return config
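Finally, a hypothetical pass through `CLIPTokenizer` on its own, reusing the toy vocabulary from the preprocessor sketch above; `detokenize` maps ids back through the vocabulary and strips the `</w>` word-end markers.

```python
from keras_hub.src.models.stable_diffusion_v3.clip_tokenizer import (
    CLIPTokenizer,
)

# Same illustrative toy assets as above; not a real CLIP vocabulary.
vocabulary = {"<|startoftext|>": 0, "<|endoftext|>": 1, "a</w>": 2, "b</w>": 3}
merges = ["a b"]

tokenizer = CLIPTokenizer(vocabulary=vocabulary, merges=merges)

token_ids = tokenizer(["a b"])          # ragged ids: [[2, 3]]
text = tokenizer.detokenize(token_ids)  # ids -> strings, dropping the "</w>" markers
print(token_ids, text)
```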