keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev20240915160609__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/__init__.py +1 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +24 -3
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +38 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +29 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +19 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +33 -47
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +220 -67
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/METADATA +1 -2
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/RECORD +173 -143
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,167 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
|
15
|
+
from keras_hub.src.tokenizers.byte_pair_tokenizer import convert_to_ragged_batch
|
16
|
+
from keras_hub.src.tokenizers.byte_pair_tokenizer import split_strings_for_bpe
|
17
|
+
|
18
|
+
try:
|
19
|
+
import tensorflow as tf
|
20
|
+
except ImportError:
|
21
|
+
tf = None
|
22
|
+
|
23
|
+
|
24
|
+
class CLIPTokenizer(BytePairTokenizer):
    """Byte-pair tokenizer variant that marks the end of each word with `</w>`.

    The tokenizer registers `<|startoftext|>` and `<|endoftext|>` as
    unsplittable tokens on the parent `BytePairTokenizer` and overrides the
    merge, tokenize and detokenize steps so that `</w>` is appended to the
    last piece of every word before merging and removed again when
    detokenizing.

    Args:
        vocabulary: Forwarded unchanged to `BytePairTokenizer`.
        merges: Forwarded unchanged to `BytePairTokenizer`.
        **kwargs: Additional keyword arguments forwarded to
            `BytePairTokenizer`.
    """

    def __init__(self, vocabulary=None, merges=None, **kwargs):
        # The special tokens must exist before `super().__init__()` runs,
        # because they are passed to it as `unsplittable_tokens`.
        self.start_token = "<|startoftext|>"
        self.end_token = "<|endoftext|>"

        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            unsplittable_tokens=[self.start_token, self.end_token],
            **kwargs,
        )

    def set_vocabulary_and_merges(self, vocabulary, merges):
        """Set the vocabulary/merges and resolve the special token ids.

        Raises:
            ValueError: If `vocabulary` is provided but does not contain the
                start token or the end token.
        """
        super().set_vocabulary_and_merges(vocabulary, merges)

        if vocabulary is not None:
            # Check for necessary special tokens. Both ids are looked up
            # below, so both tokens must be present; the end token is checked
            # first to keep the pre-existing error unchanged.
            known_tokens = self.get_vocabulary()
            for token in (self.end_token, self.start_token):
                if token not in known_tokens:
                    raise ValueError(
                        f"Cannot find token `'{token}'` in the provided "
                        f"`vocabulary`. Please provide `'{token}'` in "
                        "your `vocabulary` or use a pretrained `vocabulary` name."
                    )

            self.start_token_id = self.token_to_id(self.start_token)
            self.end_token_id = self.token_to_id(self.end_token)
            self.pad_token_id = 0
        else:
            self.end_token_id = None
            self.start_token_id = None
            self.pad_token_id = None

    def _bpe_merge_and_update_cache(self, tokens):
        """Process unseen tokens and add to cache."""
        words = self._transform_bytes(tokens)

        # In StableDiffusionV3, we need to add `</w>` to the last word.
        # Joining with a space, appending `</w>` and splitting again attaches
        # the marker to the final piece of every token.
        words = tf.strings.reduce_join(words, axis=1, separator=" ")
        words = tf.strings.join([words, "</w>"])
        words = tf.strings.split(words, sep=" ")

        tokenized_words = self._bpe_merge(words)

        # For each word, join all its token by a whitespace,
        # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
        tokenized_words = tf.strings.reduce_join(
            tokenized_words, axis=1, separator=" "
        )
        self.cache.insert(tokens, tokenized_words)

    def tokenize(self, inputs):
        """Tokenize a scalar or batch of strings.

        Args:
            inputs: A `tf.Tensor`, `tf.RaggedTensor`, or anything accepted by
                `tf.convert_to_tensor`.

        Returns:
            The merged tokens, mapped through `token_to_id_map` unless
            `compute_dtype` is `tf.string`. The result is converted to a
            dense tensor when `sequence_length` is set, and the leading axis
            is squeezed away for scalar input.
        """
        self._check_vocabulary()
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)

        if self.add_prefix_space:
            inputs = tf.strings.join([" ", inputs])

        scalar_input = inputs.shape.rank == 0
        if scalar_input:
            inputs = tf.expand_dims(inputs, 0)

        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)

        # Strip and remove empty tokens.
        raw_tokens = tf.strings.strip(raw_tokens)
        raw_tokens = tf.ragged.boolean_mask(raw_tokens, raw_tokens != "")

        token_row_splits = raw_tokens.row_splits
        flat_tokens = raw_tokens.flat_values

        # Check cache. An empty lookup result marks a token not yet cached.
        cache_lookup = self.cache.lookup(flat_tokens)
        cache_mask = cache_lookup == ""

        has_unseen_words = tf.math.reduce_any(
            (cache_lookup == "") & (flat_tokens != "")
        )

        def process_unseen_tokens():
            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
            self._bpe_merge_and_update_cache(unseen_tokens)
            return self.cache.lookup(flat_tokens)

        # If `has_unseen_words == True`, it means not all tokens are in cache,
        # we will process the unseen tokens. Otherwise return the cache lookup.
        tokenized_words = tf.cond(
            has_unseen_words,
            process_unseen_tokens,
            lambda: cache_lookup,
        )

        tokens = tf.strings.split(tokenized_words, sep=" ")
        if self.compute_dtype != tf.string:
            # Encode merged tokens.
            tokens = self.token_to_id_map.lookup(tokens)

        # Unflatten to match input.
        tokens = tf.RaggedTensor.from_row_splits(
            tokens.flat_values,
            tf.gather(tokens.row_splits, token_row_splits),
        )

        # Convert to a dense output if `sequence_length` is set.
        if self.sequence_length:
            output_shape = tokens.shape.as_list()
            output_shape[-1] = self.sequence_length
            tokens = tokens.to_tensor(shape=output_shape)

        # Convert to a dense output if input in scalar
        if scalar_input:
            tokens = tf.squeeze(tokens, 0)
            # Keep the returned tensor: `tf.ensure_shape` applies its check
            # to the tensor it returns, so discarding it drops the check.
            tokens = tf.ensure_shape(tokens, shape=[self.sequence_length])

        return tokens

    def detokenize(self, inputs):
        """Convert token ids back to strings.

        Args:
            inputs: Token ids, batched or unbatched, as accepted by
                `convert_to_ragged_batch`.

        Returns:
            The decoded strings; a scalar when `inputs` was unbatched.
        """
        self._check_vocabulary()
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
        inputs = tf.cast(inputs, self.dtype)
        unicode_text = tf.strings.reduce_join(
            self.id_to_token_map.lookup(inputs), axis=-1
        )

        # When detokenizing, we need to remove </w> and extra whitespace.
        unicode_text = tf.strings.regex_replace(unicode_text, r"</w>", " ")
        unicode_text = tf.strings.strip(unicode_text)

        split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8")
        outputs = tf.strings.reduce_join(
            self.unicode2byte.lookup(split_unicode_text), axis=-1
        )

        if unbatched:
            outputs = tf.squeeze(outputs, 0)
        return outputs

    def get_config(self):
        """Return the parent config without `unsplittable_tokens`."""
        config = super().get_config()
        # In the constructor, we pass the list of special tokens to the
        # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
        # delete it from the config here.
        del config["unsplittable_tokens"]
        return config
|
@@ -0,0 +1,427 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import math
|
15
|
+
|
16
|
+
import keras
|
17
|
+
from keras import layers
|
18
|
+
from keras import models
|
19
|
+
from keras import ops
|
20
|
+
|
21
|
+
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
|
22
|
+
from keras_hub.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock
|
23
|
+
from keras_hub.src.utils.keras_utils import standardize_data_format
|
24
|
+
|
25
|
+
|
26
|
+
class PatchEmbedding(layers.Layer):
    """Embed an image-like input into a flat sequence of patch vectors.

    A `Conv2D` with `kernel_size == strides == patch_size` projects each
    patch to `hidden_dim` features, and the two axes after the batch axis
    are then flattened into a single sequence axis.

    Args:
        patch_size: Kernel size and stride of the patch convolution.
        hidden_dim: Number of convolution filters.
        data_format: Forwarded to `standardize_data_format` and then to the
            convolution.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = int(patch_size)
        self.hidden_dim = int(hidden_dim)
        # Keep the resolved data format so `get_config()` can reproduce it.
        self.data_format = standardize_data_format(data_format)

        self.patch_embedding = layers.Conv2D(
            hidden_dim,
            kernel_size=patch_size,
            strides=patch_size,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="patch_embedding",
        )

    def build(self, input_shape):
        self.patch_embedding.build(input_shape)

    def call(self, inputs):
        x = self.patch_embedding(inputs)
        x_shape = ops.shape(x)
        # Merge axes 1 and 2 into one sequence axis and keep axis 3 as the
        # feature axis.
        # NOTE(review): this indexing matches a channels_last conv output;
        # confirm before using this layer with channels_first.
        x = ops.reshape(x, (x_shape[0], x_shape[1] * x_shape[2], x_shape[3]))
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "hidden_dim": self.hidden_dim,
                "data_format": self.data_format,
            }
        )
        return config
|
60
|
+
|
61
|
+
|
62
|
+
class AdjustablePositionEmbedding(PositionEmbedding):
    """Position embedding over a 2D grid that can be center-cropped.

    The parent `PositionEmbedding` is created with a sequence length of
    `height * width`. At call time the table is viewed as a
    `(height, width, features)` grid and a centered window of the requested
    size is sliced out and flattened again.

    Args:
        height: Number of rows in the full position grid.
        width: Number of columns in the full position grid.
        initializer: Forwarded to `PositionEmbedding`.
        **kwargs: Forwarded to `PositionEmbedding`.
    """

    def __init__(
        self,
        height,
        width,
        initializer="glorot_uniform",
        **kwargs,
    ):
        height = int(height)
        width = int(width)
        sequence_length = height * width
        super().__init__(sequence_length, initializer, **kwargs)
        self.height = height
        self.width = width

    def call(self, inputs, height=None, width=None):
        """Return the cropped position embedding with a leading axis of 1.

        Args:
            inputs: Tensor whose last-axis size is used as the feature
                length; its values are not read.
            height: Rows of the window to crop. Falls back to `self.height`
                when falsy.
            width: Columns of the window to crop. Falls back to `self.width`
                when falsy.
        """
        height = height or self.height
        width = width or self.width
        shape = ops.shape(inputs)
        feature_length = shape[-1]
        # Offsets that center the requested window inside the full grid.
        top = ops.floor_divide(self.height - height, 2)
        left = ops.floor_divide(self.width - width, 2)
        position_embedding = ops.convert_to_tensor(self.position_embeddings)
        position_embedding = ops.reshape(
            position_embedding, (self.height, self.width, feature_length)
        )
        position_embedding = ops.slice(
            position_embedding,
            (top, left, 0),
            (height, width, feature_length),
        )
        position_embedding = ops.reshape(
            position_embedding, (height * width, feature_length)
        )
        position_embedding = ops.expand_dims(position_embedding, axis=0)
        return position_embedding

    def get_config(self):
        """Return a config that matches this class's constructor.

        `__init__` derives `sequence_length` from `height * width` and passes
        it positionally to the parent, so that key must not be replayed via
        `**kwargs`; `height` and `width` are stored instead.
        """
        config = super().get_config()
        config.pop("sequence_length", None)
        config.update(
            {
                "height": self.height,
                "width": self.width,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
|
101
|
+
|
102
|
+
|
103
|
+
class TimestepEmbedding(layers.Layer):
    """Sinusoidal timestep features followed by a two-layer MLP.

    Args:
        embedding_dim: Units of both `Dense` layers in the MLP.
        frequency_dim: Size of the sinusoidal feature vector fed to the MLP.
        max_period: Base of the frequency schedule; frequencies are
            `exp(-log(max_period) * i / (frequency_dim // 2))`.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(
        self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs
    ):
        super().__init__(**kwargs)
        self.embedding_dim = int(embedding_dim)
        self.frequency_dim = int(frequency_dim)
        self.max_period = float(max_period)
        self.half_frequency_dim = self.frequency_dim // 2

        self.mlp = models.Sequential(
            [
                layers.Dense(
                    embedding_dim, activation="silu", dtype=self.dtype_policy
                ),
                layers.Dense(
                    embedding_dim, activation=None, dtype=self.dtype_policy
                ),
            ],
            name="mlp",
        )

    def build(self, inputs_shape):
        # NOTE(review): the first axis of `inputs_shape` is dropped here, so
        # the MLP is built with the remaining axes plus `frequency_dim`;
        # confirm this is the intended build shape.
        embedding_shape = list(inputs_shape)[1:]
        embedding_shape.append(self.frequency_dim)
        self.mlp.build(embedding_shape)

    def _create_timestep_embedding(self, inputs):
        """Build the `[cos, sin]` features for `inputs`."""
        # Compute in at least float32, then cast back to the compute dtype.
        compute_dtype = keras.backend.result_type(self.compute_dtype, "float32")
        x = ops.cast(inputs, compute_dtype)
        freqs = ops.exp(
            ops.divide(
                ops.multiply(
                    -math.log(self.max_period),
                    ops.arange(0, self.half_frequency_dim, dtype="float32"),
                ),
                self.half_frequency_dim,
            )
        )
        freqs = ops.cast(freqs, compute_dtype)
        x = ops.multiply(x, ops.expand_dims(freqs, axis=0))
        embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1)
        if self.frequency_dim % 2 != 0:
            # cos/sin give 2 * (frequency_dim // 2) features; pad one zero
            # column so an odd `frequency_dim` is matched.
            embedding = ops.pad(embedding, [[0, 0], [0, 1]])
        return ops.cast(embedding, self.compute_dtype)

    def call(self, inputs, training=None):
        timestep_embedding = self._create_timestep_embedding(inputs)
        return self.mlp(timestep_embedding, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embedding_dim": self.embedding_dim,
                # Serialized so a non-default value survives a round-trip.
                "frequency_dim": self.frequency_dim,
                "max_period": self.max_period,
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        # NOTE(review): like `build`, this drops the first axis of
        # `inputs_shape` before appending `embedding_dim`; verify against the
        # actual output of `call`.
        output_shape = list(inputs_shape)[1:]
        output_shape.append(self.embedding_dim)
        return output_shape
|
167
|
+
|
168
|
+
|
169
|
+
class OutputLayer(layers.Layer):
    """Final modulated normalization and dense projection.

    A SiLU + `Dense` stack turns the timestep embedding into a shift and a
    scale of `hidden_dim` features each. The inputs are layer-normalized
    (without learned center or scale), modulated as
    `x * (scale + 1) + shift`, and projected to `output_dim`.

    Args:
        hidden_dim: Feature size of each modulation vector.
        output_dim: Units of the final `Dense` projection.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self, hidden_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # One shift vector and one scale vector.
        modulation_count = 2
        modulation_layers = [
            layers.Activation("silu", dtype=self.dtype_policy),
            layers.Dense(
                modulation_count * hidden_dim, dtype=self.dtype_policy
            ),
        ]
        self.adaptive_norm_modulation = models.Sequential(
            modulation_layers,
            name="adaptive_norm_modulation",
        )
        self.norm = layers.LayerNormalization(
            epsilon=1e-6,
            center=False,
            scale=False,
            dtype=self.dtype_policy,
            name="norm",
        )
        # `output_dim` is patch_size ** 2 * input_channels.
        self.output_dense = layers.Dense(
            output_dim,
            use_bias=True,
            dtype=self.dtype_policy,
            name="output_dense",
        )

    def build(self, inputs_shape, timestep_embedding_shape):
        self.adaptive_norm_modulation.build(timestep_embedding_shape)
        self.norm.build(inputs_shape)
        self.output_dense.build(inputs_shape)

    def _modulate(self, inputs, shift, scale):
        """Return `inputs * (scale + 1) + shift`, broadcasting over axis 1."""
        shift = ops.expand_dims(shift, axis=1)
        scale = ops.expand_dims(scale, axis=1)
        scaled = ops.multiply(inputs, ops.add(scale, 1.0))
        return ops.add(scaled, shift)

    def call(self, inputs, timestep_embedding, training=None):
        modulation = self.adaptive_norm_modulation(
            timestep_embedding, training=training
        )
        # Split the projection into its shift half and its scale half.
        modulation = ops.reshape(modulation, (-1, 2, self.hidden_dim))
        shift, scale = ops.unstack(modulation, 2, axis=1)
        normalized = self.norm(inputs)
        modulated = self._modulate(normalized, shift, scale)
        return self.output_dense(modulated, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "output_dim": self.output_dim,
            }
        )
        return config
|
229
|
+
|
230
|
+
|
231
|
+
class Unpatch(layers.Layer):
    """Rearrange a sequence of flattened patches back into an image grid.

    Args:
        patch_size: Side length of each square patch.
        output_dim: Number of channels in the reconstructed output.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self, patch_size, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = int(patch_size)
        self.output_dim = int(output_dim)

    def call(self, inputs, height, width):
        """Fold `inputs` into `(-1, height * p, width * p, output_dim)`.

        Args:
            inputs: Tensor reshapeable to
                `(-1, height, width, p, p, output_dim)` with
                `p = self.patch_size`.
            height: Number of patch rows.
            width: Number of patch columns.
        """
        p = self.patch_size
        channels = self.output_dim
        grid_shape = (-1, height, width, p, p, channels)
        image_shape = (-1, height * p, width * p, channels)

        patches = ops.reshape(inputs, grid_shape)
        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
        patches = ops.transpose(patches, (0, 1, 3, 2, 4, 5))
        return ops.reshape(patches, image_shape)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "output_dim": self.output_dim,
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        # Only the batch size and the channel count are reported; both
        # spatial axes are left as `None`.
        batch_size = list(inputs_shape)[0]
        return [batch_size, None, None, self.output_dim]
|
264
|
+
|
265
|
+
|
266
|
+
class MMDiT(keras.Model):
    """Functional model wiring patch, context and timestep embeddings
    through a stack of `MMDiTBlock`s.

    The model takes a dict with the keys `"latent"`, `"context"`,
    `"pooled_projection"` and `"timestep"` and returns the unpatched output
    of the final `OutputLayer`.

    Args:
        patch_size: Patch size used by the patch embedding and `Unpatch`.
        num_heads: Forwarded to every `MMDiTBlock`.
        hidden_dim: Hidden size shared by the embeddings and the blocks.
        depth: Number of `MMDiTBlock`s. The last block is created with
            `use_context_projection=False`.
        position_size: Height and width of the position-embedding grid.
        output_dim: Channels of the final unpatched output.
        mlp_ratio: Forwarded to every `MMDiTBlock`.
        latent_shape: Shape of the `"latent"` input, without the batch axis.
            Must be fully specified.
        context_shape: Shape of the `"context"` input, without the batch
            axis.
        pooled_projection_shape: Shape of the `"pooled_projection"` input,
            without the batch axis.
        data_format: Only `"channels_last"` is supported.
        dtype: Dtype or dtype policy applied to the layers and the model.
        **kwargs: Forwarded to `keras.Model`.

    Raises:
        ValueError: If `latent_shape` contains `None`, or if the patch grid
            derived from `latent_shape` and `patch_size` is larger than
            `position_size`.
        NotImplementedError: If the data format is not `"channels_last"`.
    """

    def __init__(
        self,
        patch_size,
        num_heads,
        hidden_dim,
        depth,
        position_size,
        output_dim,
        mlp_ratio=4.0,
        latent_shape=(64, 64, 16),
        context_shape=(1024, 4096),
        pooled_projection_shape=(2048,),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        if None in latent_shape:
            raise ValueError(
                "`latent_shape` must be fully specified. "
                f"Received: latent_shape={latent_shape}"
            )
        image_height = latent_shape[0] // patch_size
        image_width = latent_shape[1] // patch_size
        # The position embedding is cropped from a `position_size` square
        # grid; a larger patch grid would only fail later inside the crop,
        # so reject it here with a clear message.
        if max(image_height, image_width) > int(position_size):
            raise ValueError(
                "`latent_shape` divided by `patch_size` must not exceed "
                "`position_size` along height or width. Received: "
                f"latent_shape={latent_shape}, patch_size={patch_size}, "
                f"position_size={position_size}"
            )
        output_dim_in_final = patch_size**2 * output_dim
        data_format = standardize_data_format(data_format)
        if data_format != "channels_last":
            raise NotImplementedError(
                "Currently only 'channels_last' is supported."
            )

        # === Layers ===
        self.patch_embedding = PatchEmbedding(
            patch_size,
            hidden_dim,
            data_format=data_format,
            dtype=dtype,
            name="patch_embedding",
        )
        self.position_embedding_add = layers.Add(
            dtype=dtype, name="position_embedding_add"
        )
        self.position_embedding = AdjustablePositionEmbedding(
            position_size, position_size, dtype=dtype, name="position_embedding"
        )
        self.context_embedding = layers.Dense(
            hidden_dim,
            dtype=dtype,
            name="context_embedding",
        )
        self.vector_embedding = models.Sequential(
            [
                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
                layers.Dense(hidden_dim, activation=None, dtype=dtype),
            ],
            name="vector_embedding",
        )
        self.vector_embedding_add = layers.Add(
            dtype=dtype, name="vector_embedding_add"
        )
        self.timestep_embedding = TimestepEmbedding(
            hidden_dim, dtype=dtype, name="timestep_embedding"
        )
        self.joint_blocks = [
            MMDiTBlock(
                num_heads,
                hidden_dim,
                mlp_ratio,
                # The last block does not project the context.
                use_context_projection=not (i == depth - 1),
                dtype=dtype,
                name=f"joint_block_{i}",
            )
            for i in range(depth)
        ]
        self.output_layer = OutputLayer(
            hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
        )
        self.unpatch = Unpatch(
            patch_size, output_dim, dtype=dtype, name="unpatch"
        )

        # === Functional Model ===
        latent_inputs = layers.Input(shape=latent_shape, name="latent")
        context_inputs = layers.Input(shape=context_shape, name="context")
        pooled_projection_inputs = layers.Input(
            shape=pooled_projection_shape, name="pooled_projection"
        )
        timestep_inputs = layers.Input(shape=(1,), name="timestep")

        # Embeddings.
        x = self.patch_embedding(latent_inputs)
        position_embedding = self.position_embedding(
            x, height=image_height, width=image_width
        )
        x = self.position_embedding_add([x, position_embedding])
        context = self.context_embedding(context_inputs)
        pooled_projection = self.vector_embedding(pooled_projection_inputs)
        timestep_embedding = self.timestep_embedding(timestep_inputs)
        timestep_embedding = self.vector_embedding_add(
            [timestep_embedding, pooled_projection]
        )

        # Blocks.
        for block in self.joint_blocks:
            if block.use_context_projection:
                x, context = block(x, context, timestep_embedding)
            else:
                x = block(x, context, timestep_embedding)

        # Output layer.
        x = self.output_layer(x, timestep_embedding)
        outputs = self.unpatch(x, height=image_height, width=image_width)

        super().__init__(
            inputs={
                "latent": latent_inputs,
                "context": context_inputs,
                "pooled_projection": pooled_projection_inputs,
                "timestep": timestep_inputs,
            },
            outputs=outputs,
            **kwargs,
        )

        # === Config ===
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.depth = depth
        self.position_size = position_size
        self.output_dim = output_dim
        self.mlp_ratio = mlp_ratio
        self.latent_shape = latent_shape
        self.context_shape = context_shape
        self.pooled_projection_shape = pooled_projection_shape

        if dtype is not None:
            try:
                self.dtype_policy = keras.dtype_policies.get(dtype)
            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
            except AttributeError:
                if isinstance(dtype, keras.DTypePolicy):
                    dtype = dtype.name
                self.dtype_policy = keras.DTypePolicy(dtype)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "depth": self.depth,
                "position_size": self.position_size,
                "output_dim": self.output_dim,
                "mlp_ratio": self.mlp_ratio,
                "latent_shape": self.latent_shape,
                "context_shape": self.context_shape,
                "pooled_projection_shape": self.pooled_projection_shape,
            }
        )
        return config
|