keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev20240915160609__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. keras_hub/api/__init__.py +1 -0
  2. keras_hub/api/bounding_box/__init__.py +36 -0
  3. keras_hub/api/layers/__init__.py +14 -0
  4. keras_hub/api/models/__init__.py +97 -48
  5. keras_hub/api/tokenizers/__init__.py +30 -0
  6. keras_hub/src/bounding_box/__init__.py +13 -0
  7. keras_hub/src/bounding_box/converters.py +529 -0
  8. keras_hub/src/bounding_box/formats.py +162 -0
  9. keras_hub/src/bounding_box/iou.py +263 -0
  10. keras_hub/src/bounding_box/to_dense.py +95 -0
  11. keras_hub/src/bounding_box/to_ragged.py +99 -0
  12. keras_hub/src/bounding_box/utils.py +194 -0
  13. keras_hub/src/bounding_box/validate_format.py +99 -0
  14. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  15. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  16. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  17. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  18. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  19. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  20. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  21. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  22. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  23. keras_hub/src/models/albert/__init__.py +1 -2
  24. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  25. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  26. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  27. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  28. keras_hub/src/models/backbone.py +12 -34
  29. keras_hub/src/models/bart/__init__.py +1 -2
  30. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  31. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  32. keras_hub/src/models/bert/__init__.py +1 -5
  33. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  34. keras_hub/src/models/bert/bert_presets.py +1 -4
  35. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  36. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  37. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  38. keras_hub/src/models/bloom/__init__.py +1 -2
  39. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  40. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  41. keras_hub/src/models/causal_lm.py +10 -29
  42. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  43. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  44. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  45. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  46. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  47. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  48. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  49. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  50. keras_hub/src/models/distil_bert/__init__.py +1 -4
  51. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  52. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  53. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  54. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  55. keras_hub/src/models/efficientnet/__init__.py +13 -0
  56. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  57. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  58. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  59. keras_hub/src/models/electra/__init__.py +1 -2
  60. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  61. keras_hub/src/models/f_net/__init__.py +1 -2
  62. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  63. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  64. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  65. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  66. keras_hub/src/models/falcon/__init__.py +1 -2
  67. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  68. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  69. keras_hub/src/models/gemma/__init__.py +1 -2
  70. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  71. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  72. keras_hub/src/models/gpt2/__init__.py +1 -2
  73. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  74. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  75. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  76. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  77. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  78. keras_hub/src/models/image_classifier.py +0 -5
  79. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  80. keras_hub/src/models/llama/__init__.py +1 -2
  81. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  82. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  83. keras_hub/src/models/llama3/__init__.py +1 -2
  84. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  85. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  86. keras_hub/src/models/masked_lm.py +0 -2
  87. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  88. keras_hub/src/models/mistral/__init__.py +1 -2
  89. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  90. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  91. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  92. keras_hub/src/models/mobilenet/__init__.py +13 -0
  93. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  94. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  95. keras_hub/src/models/opt/__init__.py +1 -2
  96. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  97. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  98. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  99. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  100. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  101. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  102. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  103. keras_hub/src/models/phi3/__init__.py +1 -2
  104. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  105. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  106. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  107. keras_hub/src/models/preprocessor.py +72 -83
  108. keras_hub/src/models/resnet/__init__.py +6 -0
  109. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  110. keras_hub/src/models/resnet/resnet_image_classifier.py +24 -3
  111. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  112. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  113. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  114. keras_hub/src/models/roberta/__init__.py +1 -2
  115. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  116. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  117. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  118. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  119. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  120. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  121. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  122. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  123. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  124. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  125. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  126. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  127. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  128. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  129. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  130. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  131. keras_hub/src/models/t5/__init__.py +1 -2
  132. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  133. keras_hub/src/models/task.py +71 -116
  134. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  135. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  136. keras_hub/src/models/whisper/__init__.py +1 -2
  137. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  138. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  139. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  140. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  141. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  142. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  143. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  144. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  145. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  146. keras_hub/src/tests/test_case.py +38 -0
  147. keras_hub/src/tokenizers/byte_pair_tokenizer.py +29 -17
  148. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  149. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +19 -7
  150. keras_hub/src/tokenizers/tokenizer.py +67 -32
  151. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  152. keras_hub/src/tokenizers/word_piece_tokenizer.py +33 -47
  153. keras_hub/src/utils/keras_utils.py +0 -50
  154. keras_hub/src/utils/preset_utils.py +220 -67
  155. keras_hub/src/utils/tensor_utils.py +187 -69
  156. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  157. keras_hub/src/utils/timm/preset_loader.py +66 -0
  158. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  159. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  160. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  161. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  162. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  163. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  164. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  165. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  166. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  167. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  168. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  169. keras_hub/src/version_utils.py +1 -1
  170. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/METADATA +1 -2
  171. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/RECORD +173 -143
  172. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/WHEEL +1 -1
  173. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  174. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  175. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  176. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  177. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  178. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  179. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  180. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  181. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  182. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  183. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  184. keras_hub/src/utils/timm/convert.py +0 -37
  185. keras_hub/src/utils/transformers/convert.py +0 -101
  186. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,167 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
15
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import convert_to_ragged_batch
16
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import split_strings_for_bpe
17
+
18
+ try:
19
+ import tensorflow as tf
20
+ except ImportError:
21
+ tf = None
22
+
23
+
24
class CLIPTokenizer(BytePairTokenizer):
    """Byte-pair tokenizer that appends `</w>` to every word before merging.

    `<|startoftext|>` and `<|endoftext|>` are registered as unsplittable
    tokens with the parent `BytePairTokenizer`.

    Args:
        vocabulary: forwarded unchanged to `BytePairTokenizer`.
        merges: forwarded unchanged to `BytePairTokenizer`.
        **kwargs: forwarded unchanged to `BytePairTokenizer`.
    """

    def __init__(self, vocabulary=None, merges=None, **kwargs):
        # The special tokens must exist before `super().__init__()` runs,
        # because they are passed to it as `unsplittable_tokens`.
        self.start_token = "<|startoftext|>"
        self.end_token = "<|endoftext|>"

        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            unsplittable_tokens=[self.start_token, self.end_token],
            **kwargs,
        )

    def set_vocabulary_and_merges(self, vocabulary, merges):
        """Set the vocabulary/merges and resolve the special token ids.

        Raises:
            ValueError: if `vocabulary` is given but lacks the start or the
                end token.
        """
        super().set_vocabulary_and_merges(vocabulary, merges)

        if vocabulary is not None:
            # Check for necessary special tokens. Both ids are looked up
            # right below, so both tokens must be present. The end token is
            # checked first so that its error message takes precedence.
            known_tokens = self.get_vocabulary()
            for token in (self.end_token, self.start_token):
                if token not in known_tokens:
                    raise ValueError(
                        f"Cannot find token `'{token}'` in the provided "
                        f"`vocabulary`. Please provide `'{token}'` in "
                        "your `vocabulary` or use a pretrained `vocabulary` name."
                    )

            self.start_token_id = self.token_to_id(self.start_token)
            self.end_token_id = self.token_to_id(self.end_token)
            self.pad_token_id = 0
        else:
            self.end_token_id = None
            self.start_token_id = None
            self.pad_token_id = None

    def _bpe_merge_and_update_cache(self, tokens):
        """Process unseen tokens and add to cache."""
        words = self._transform_bytes(tokens)

        # In StableDiffusionV3, we need to add `</w>` to the last word.
        words = tf.strings.reduce_join(words, axis=1, separator=" ")
        words = tf.strings.join([words, "</w>"])
        words = tf.strings.split(words, sep=" ")

        tokenized_words = self._bpe_merge(words)

        # For each word, join all its token by a whitespace,
        # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
        tokenized_words = tf.strings.reduce_join(
            tokenized_words, axis=1, separator=" "
        )
        self.cache.insert(tokens, tokenized_words)

    def tokenize(self, inputs):
        """Tokenize a string scalar or batch of strings.

        Returns token ids, or string tokens when `self.compute_dtype` is
        `tf.string`. The output is ragged unless `self.sequence_length` is
        set, in which case it is converted to a dense tensor whose last
        dimension is `sequence_length`. A scalar input yields an unbatched
        result.
        """
        self._check_vocabulary()
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)

        if self.add_prefix_space:
            inputs = tf.strings.join([" ", inputs])

        scalar_input = inputs.shape.rank == 0
        if scalar_input:
            inputs = tf.expand_dims(inputs, 0)

        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)

        # Strip and remove empty tokens.
        raw_tokens = tf.strings.strip(raw_tokens)
        raw_tokens = tf.ragged.boolean_mask(raw_tokens, raw_tokens != "")

        token_row_splits = raw_tokens.row_splits
        flat_tokens = raw_tokens.flat_values

        # Check cache. An empty lookup result marks a word not cached yet.
        cache_lookup = self.cache.lookup(flat_tokens)
        cache_mask = cache_lookup == ""

        has_unseen_words = tf.math.reduce_any(
            cache_mask & (flat_tokens != "")
        )

        def process_unseen_tokens():
            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
            self._bpe_merge_and_update_cache(unseen_tokens)
            return self.cache.lookup(flat_tokens)

        # If `has_unseen_words == True`, it means not all tokens are in cache,
        # we will process the unseen tokens. Otherwise return the cache lookup.
        tokenized_words = tf.cond(
            has_unseen_words,
            process_unseen_tokens,
            lambda: cache_lookup,
        )

        tokens = tf.strings.split(tokenized_words, sep=" ")
        if self.compute_dtype != tf.string:
            # Encode merged tokens.
            tokens = self.token_to_id_map.lookup(tokens)

        # Unflatten to match input.
        tokens = tf.RaggedTensor.from_row_splits(
            tokens.flat_values,
            tf.gather(tokens.row_splits, token_row_splits),
        )

        # Convert to a dense output if `sequence_length` is set.
        if self.sequence_length:
            output_shape = tokens.shape.as_list()
            output_shape[-1] = self.sequence_length
            tokens = tokens.to_tensor(shape=output_shape)

        # Drop the batch axis again if the input was a scalar.
        if scalar_input:
            tokens = tf.squeeze(tokens, 0)
            # `tf.ensure_shape` returns the constrained tensor; the result
            # must be kept, otherwise the static shape is not attached.
            tokens = tf.ensure_shape(tokens, shape=[self.sequence_length])

        return tokens

    def detokenize(self, inputs):
        """Convert token ids back to strings.

        Ids are mapped to tokens and joined; every `</w>` marker becomes a
        space, the text is stripped, and the unicode characters are mapped
        back to bytes through `self.unicode2byte`.
        """
        self._check_vocabulary()
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
        inputs = tf.cast(inputs, self.dtype)
        unicode_text = tf.strings.reduce_join(
            self.id_to_token_map.lookup(inputs), axis=-1
        )

        # When detokenizing, we need to remove </w> and extra whitespace.
        unicode_text = tf.strings.regex_replace(unicode_text, r"</w>", " ")
        unicode_text = tf.strings.strip(unicode_text)

        split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8")
        outputs = tf.strings.reduce_join(
            self.unicode2byte.lookup(split_unicode_text), axis=-1
        )

        if unbatched:
            outputs = tf.squeeze(outputs, 0)
        return outputs

    def get_config(self):
        """Return the parent config without `unsplittable_tokens`."""
        config = super().get_config()
        # In the constructor, we pass the list of special tokens to the
        # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
        # delete it from the config here.
        del config["unsplittable_tokens"]
        return config
@@ -0,0 +1,427 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import keras
17
+ from keras import layers
18
+ from keras import models
19
+ from keras import ops
20
+
21
+ from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
22
+ from keras_hub.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock
23
+ from keras_hub.src.utils.keras_utils import standardize_data_format
24
+
25
+
26
class PatchEmbedding(layers.Layer):
    """Embed an image into a sequence of patch vectors.

    A `Conv2D` with `kernel_size == strides == patch_size` projects each
    patch to `hidden_dim` features. The two axes after the batch axis are
    then flattened into a single sequence axis.

    Args:
        patch_size: int. Kernel size and stride of the convolution.
        hidden_dim: int. Number of convolution filters.
        data_format: passed through `standardize_data_format`, then to the
            convolution. NOTE(review): `call` flattens axes 1 and 2, which
            matches a channels-last conv output only — confirm before
            using `"channels_first"`.
        **kwargs: forwarded to `layers.Layer`.
    """

    def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = int(patch_size)
        self.hidden_dim = int(hidden_dim)
        # Keep the resolved value so that `get_config` can serialize it.
        self.data_format = standardize_data_format(data_format)

        self.patch_embedding = layers.Conv2D(
            hidden_dim,
            kernel_size=patch_size,
            strides=patch_size,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="patch_embedding",
        )

    def build(self, input_shape):
        self.patch_embedding.build(input_shape)

    def call(self, inputs):
        x = self.patch_embedding(inputs)
        x_shape = ops.shape(x)
        # Collapse axes 1 and 2 into one sequence axis; keep the last axis.
        x = ops.reshape(x, (x_shape[0], x_shape[1] * x_shape[2], x_shape[3]))
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "hidden_dim": self.hidden_dim,
                # Without this, an explicitly chosen data format would be
                # replaced by the default when the layer is re-created.
                "data_format": self.data_format,
            }
        )
        return config
60
+
61
+
62
class AdjustablePositionEmbedding(PositionEmbedding):
    """Position embedding over a 2D grid that can be center-cropped.

    The parent layer is created with `sequence_length = height * width`.
    At call time the embedding table is viewed as a `(height, width)` grid
    and a centered window of the requested size is returned.

    Args:
        height: int. Number of rows in the full position grid.
        width: int. Number of columns in the full position grid.
        initializer: forwarded to `PositionEmbedding`.
        **kwargs: forwarded to `PositionEmbedding`.
    """

    def __init__(
        self,
        height,
        width,
        initializer="glorot_uniform",
        **kwargs,
    ):
        height = int(height)
        width = int(width)
        sequence_length = height * width
        super().__init__(sequence_length, initializer, **kwargs)
        self.height = height
        self.width = width

    def call(self, inputs, height=None, width=None):
        """Return the cropped embedding, shape `(1, height * width, dim)`.

        `inputs` is only used to read the size of its last axis. `height`
        and `width` fall back to the full grid size when falsy.
        """
        height = height or self.height
        width = width or self.width
        shape = ops.shape(inputs)
        feature_length = shape[-1]
        # Offsets that center the requested window inside the full grid.
        top = ops.floor_divide(self.height - height, 2)
        left = ops.floor_divide(self.width - width, 2)
        position_embedding = ops.convert_to_tensor(self.position_embeddings)
        position_embedding = ops.reshape(
            position_embedding, (self.height, self.width, feature_length)
        )
        position_embedding = ops.slice(
            position_embedding,
            (top, left, 0),
            (height, width, feature_length),
        )
        position_embedding = ops.reshape(
            position_embedding, (height * width, feature_length)
        )
        position_embedding = ops.expand_dims(position_embedding, axis=0)
        return position_embedding

    def compute_output_shape(self, input_shape):
        # NOTE(review): `call` returns a leading axis of size 1, not the
        # batch size of `input_shape`; left unchanged to keep the
        # symbolic shapes callers already see.
        return input_shape

    def get_config(self):
        """Return a config that matches this class' constructor.

        The constructor takes `height` and `width` and derives the
        sequence length itself, so the inherited config alone cannot
        re-create the layer.
        """
        config = super().get_config()
        # Presumably the parent serializes `sequence_length`; it is not a
        # constructor argument here, so drop it if present.
        config.pop("sequence_length", None)
        config.update(
            {
                "height": self.height,
                "width": self.width,
            }
        )
        return config
98
+
99
+ def compute_output_shape(self, input_shape):
100
+ return input_shape
101
+
102
+
103
class TimestepEmbedding(layers.Layer):
    """Sinusoidal timestep features followed by a two-layer MLP.

    Args:
        embedding_dim: int. Units of both `Dense` layers of the MLP.
        frequency_dim: int. Width of the sinusoidal feature vector fed to
            the MLP. Defaults to `256`.
        max_period: float. Base of the frequency schedule. Defaults to
            `10000`.
        **kwargs: forwarded to `layers.Layer`.
    """

    def __init__(
        self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs
    ):
        super().__init__(**kwargs)
        self.embedding_dim = int(embedding_dim)
        self.frequency_dim = int(frequency_dim)
        self.max_period = float(max_period)
        self.half_frequency_dim = self.frequency_dim // 2

        self.mlp = models.Sequential(
            [
                layers.Dense(
                    embedding_dim, activation="silu", dtype=self.dtype_policy
                ),
                layers.Dense(
                    embedding_dim, activation=None, dtype=self.dtype_policy
                ),
            ],
            name="mlp",
        )

    def build(self, inputs_shape):
        # The MLP sees the sinusoidal features: same leading axes as the
        # input, with the last axis replaced by `frequency_dim`.
        embedding_shape = list(inputs_shape)[:-1]
        embedding_shape.append(self.frequency_dim)
        self.mlp.build(embedding_shape)

    def _create_timestep_embedding(self, inputs):
        """Build `[cos(t * f), sin(t * f)]` features for the timesteps.

        The frequencies are `exp(-log(max_period) * i / half)` for
        `i in [0, half)`, where `half = frequency_dim // 2`. When
        `frequency_dim` is odd, one zero column is appended.
        """
        # Do the trigonometry in the dtype resulting from promoting the
        # compute dtype with float32, then cast back at the end.
        compute_dtype = keras.backend.result_type(self.compute_dtype, "float32")
        x = ops.cast(inputs, compute_dtype)
        freqs = ops.exp(
            ops.divide(
                ops.multiply(
                    -math.log(self.max_period),
                    ops.arange(0, self.half_frequency_dim, dtype="float32"),
                ),
                self.half_frequency_dim,
            )
        )
        freqs = ops.cast(freqs, compute_dtype)
        x = ops.multiply(x, ops.expand_dims(freqs, axis=0))
        embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1)
        if self.frequency_dim % 2 != 0:
            embedding = ops.pad(embedding, [[0, 0], [0, 1]])
        return ops.cast(embedding, self.compute_dtype)

    def call(self, inputs, training=None):
        timestep_embedding = self._create_timestep_embedding(inputs)
        return self.mlp(timestep_embedding, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embedding_dim": self.embedding_dim,
                # Serialized so a non-default value survives a round trip.
                "frequency_dim": self.frequency_dim,
                "max_period": self.max_period,
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        # Keep the leading (batch) axes and replace the last axis with the
        # MLP width, matching what `call` actually returns.
        output_shape = list(inputs_shape)[:-1]
        output_shape.append(self.embedding_dim)
        return output_shape
167
+
168
+
169
class OutputLayer(layers.Layer):
    """Final modulated normalization and projection.

    The timestep embedding goes through `silu` and a `Dense` layer of
    `2 * hidden_dim` units to produce a shift and a scale. The inputs are
    layer-normalized (no learned center/scale), modulated as
    `x * (1 + scale) + shift`, and projected to `output_dim`.

    Args:
        hidden_dim: size of the shift and scale vectors.
        output_dim: units of the final `Dense` projection.
        **kwargs: forwarded to `layers.Layer`.
    """

    def __init__(self, hidden_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        policy = self.dtype_policy
        # One shift vector plus one scale vector.
        modulation_units = 2 * hidden_dim

        self.adaptive_norm_modulation = models.Sequential(
            [
                layers.Activation("silu", dtype=policy),
                layers.Dense(modulation_units, dtype=policy),
            ],
            name="adaptive_norm_modulation",
        )
        self.norm = layers.LayerNormalization(
            epsilon=1e-6,
            center=False,
            scale=False,
            dtype=policy,
            name="norm",
        )
        # `output_dim` is patch_size ** 2 * input_channels.
        self.output_dense = layers.Dense(
            output_dim,
            use_bias=True,
            dtype=policy,
            name="output_dense",
        )

    def build(self, inputs_shape, timestep_embedding_shape):
        self.adaptive_norm_modulation.build(timestep_embedding_shape)
        self.norm.build(inputs_shape)
        self.output_dense.build(inputs_shape)

    def _modulate(self, inputs, shift, scale):
        """Return `inputs * (1 + scale) + shift`, broadcast over axis 1."""
        scale = ops.expand_dims(scale, axis=1)
        shift = ops.expand_dims(shift, axis=1)
        scaled = ops.multiply(inputs, ops.add(scale, 1.0))
        return ops.add(scaled, shift)

    def call(self, inputs, timestep_embedding, training=None):
        modulation = self.adaptive_norm_modulation(
            timestep_embedding, training=training
        )
        # Split the modulation into its shift and scale halves.
        modulation = ops.reshape(modulation, (-1, 2, self.hidden_dim))
        shift, scale = ops.unstack(modulation, 2, axis=1)
        normalized = self.norm(inputs)
        modulated = self._modulate(normalized, shift, scale)
        return self.output_dense(modulated, training=training)

    def get_config(self):
        config = super().get_config()
        config["hidden_dim"] = self.hidden_dim
        config["output_dim"] = self.output_dim
        return config
229
+
230
+
231
class Unpatch(layers.Layer):
    """Rearrange a sequence of flattened patches back into an image.

    Args:
        patch_size: int. Side length of each square patch.
        output_dim: int. Number of channels of the output image.
        **kwargs: forwarded to `layers.Layer`.
    """

    def __init__(self, patch_size, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = int(patch_size)
        self.output_dim = int(output_dim)

    def call(self, inputs, height, width):
        """Return a `(batch, height * p, width * p, output_dim)` tensor.

        `height` and `width` are the patch-grid dimensions.
        """
        p = self.patch_size
        channels = self.output_dim
        patches = ops.reshape(inputs, (-1, height, width, p, p, channels))
        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
        patches = ops.transpose(patches, (0, 1, 3, 2, 4, 5))
        image_shape = (-1, height * p, width * p, channels)
        return ops.reshape(patches, image_shape)

    def get_config(self):
        config = super().get_config()
        config["patch_size"] = self.patch_size
        config["output_dim"] = self.output_dim
        return config

    def compute_output_shape(self, inputs_shape):
        # The spatial size depends on the call-time `height`/`width`, so it
        # is reported as unknown.
        batch_size = list(inputs_shape)[0]
        return [batch_size, None, None, self.output_dim]
264
+
265
+
266
class MMDiT(keras.Model):
    """Functional model that stacks `MMDiTBlock`s over latent patches.

    The model takes a dict with keys `"latent"`, `"context"`,
    `"pooled_projection"` and `"timestep"` and returns the unpatched
    output image.

    Args:
        patch_size: int. Patch size used for embedding and unpatching.
        num_heads: passed to every `MMDiTBlock`.
        hidden_dim: width of the embeddings and blocks.
        depth: int. Number of `MMDiTBlock`s.
        position_size: side length of the square position embedding grid.
        output_dim: number of channels of the output image.
        mlp_ratio: passed to every `MMDiTBlock`. Defaults to `4.0`.
        latent_shape: shape of the `"latent"` input without the batch
            axis. Must be fully specified.
        context_shape: shape of the `"context"` input without the batch
            axis.
        pooled_projection_shape: shape of the `"pooled_projection"` input
            without the batch axis.
        data_format: only `"channels_last"` is supported.
        dtype: dtype or dtype policy used for all layers.
        **kwargs: forwarded to `keras.Model`.

    Raises:
        ValueError: if `latent_shape` contains `None`, or if the patch grid
            does not fit inside the `position_size` grid.
        NotImplementedError: if the data format is not `"channels_last"`.
    """

    def __init__(
        self,
        patch_size,
        num_heads,
        hidden_dim,
        depth,
        position_size,
        output_dim,
        mlp_ratio=4.0,
        latent_shape=(64, 64, 16),
        context_shape=(1024, 4096),
        pooled_projection_shape=(2048,),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        if None in latent_shape:
            raise ValueError(
                "`latent_shape` must be fully specified. "
                f"Received: latent_shape={latent_shape}"
            )
        # Size of the patch grid (number of patches along each axis).
        image_height = latent_shape[0] // patch_size
        image_width = latent_shape[1] // patch_size
        output_dim_in_final = patch_size**2 * output_dim
        data_format = standardize_data_format(data_format)
        if data_format != "channels_last":
            raise NotImplementedError(
                "Currently only 'channels_last' is supported."
            )
        # The position embedding is center-cropped from a
        # `position_size x position_size` grid, so the patch grid must fit
        # inside it. Fail early with a clear message instead of an opaque
        # slicing error during graph construction.
        if image_height > position_size or image_width > position_size:
            raise ValueError(
                "The patch grid derived from `latent_shape` and `patch_size` "
                "must not exceed `position_size`. Received: "
                f"latent_shape={latent_shape}, patch_size={patch_size}, "
                f"position_size={position_size}"
            )

        # === Layers ===
        self.patch_embedding = PatchEmbedding(
            patch_size,
            hidden_dim,
            data_format=data_format,
            dtype=dtype,
            name="patch_embedding",
        )
        self.position_embedding_add = layers.Add(
            dtype=dtype, name="position_embedding_add"
        )
        self.position_embedding = AdjustablePositionEmbedding(
            position_size, position_size, dtype=dtype, name="position_embedding"
        )
        self.context_embedding = layers.Dense(
            hidden_dim,
            dtype=dtype,
            name="context_embedding",
        )
        self.vector_embedding = models.Sequential(
            [
                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
                layers.Dense(hidden_dim, activation=None, dtype=dtype),
            ],
            name="vector_embedding",
        )
        self.vector_embedding_add = layers.Add(
            dtype=dtype, name="vector_embedding_add"
        )
        self.timestep_embedding = TimestepEmbedding(
            hidden_dim, dtype=dtype, name="timestep_embedding"
        )
        # Every block except the last one is created with
        # `use_context_projection=True`.
        self.joint_blocks = [
            MMDiTBlock(
                num_heads,
                hidden_dim,
                mlp_ratio,
                use_context_projection=not (i == depth - 1),
                dtype=dtype,
                name=f"joint_block_{i}",
            )
            for i in range(depth)
        ]
        self.output_layer = OutputLayer(
            hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
        )
        self.unpatch = Unpatch(
            patch_size, output_dim, dtype=dtype, name="unpatch"
        )

        # === Functional Model ===
        latent_inputs = layers.Input(shape=latent_shape, name="latent")
        context_inputs = layers.Input(shape=context_shape, name="context")
        pooled_projection_inputs = layers.Input(
            shape=pooled_projection_shape, name="pooled_projection"
        )
        timestep_inputs = layers.Input(shape=(1,), name="timestep")

        # Embeddings.
        x = self.patch_embedding(latent_inputs)
        position_embedding = self.position_embedding(
            x, height=image_height, width=image_width
        )
        x = self.position_embedding_add([x, position_embedding])
        context = self.context_embedding(context_inputs)
        pooled_projection = self.vector_embedding(pooled_projection_inputs)
        timestep_embedding = self.timestep_embedding(timestep_inputs)
        timestep_embedding = self.vector_embedding_add(
            [timestep_embedding, pooled_projection]
        )

        # Blocks. A block with context projection returns both streams;
        # one without it returns only `x`.
        for block in self.joint_blocks:
            if block.use_context_projection:
                x, context = block(x, context, timestep_embedding)
            else:
                x = block(x, context, timestep_embedding)

        # Output layer.
        x = self.output_layer(x, timestep_embedding)
        outputs = self.unpatch(x, height=image_height, width=image_width)

        super().__init__(
            inputs={
                "latent": latent_inputs,
                "context": context_inputs,
                "pooled_projection": pooled_projection_inputs,
                "timestep": timestep_inputs,
            },
            outputs=outputs,
            **kwargs,
        )

        # === Config ===
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.depth = depth
        self.position_size = position_size
        self.output_dim = output_dim
        self.mlp_ratio = mlp_ratio
        self.latent_shape = latent_shape
        self.context_shape = context_shape
        self.pooled_projection_shape = pooled_projection_shape
        # Stored so `get_config` can serialize the resolved data format.
        self.data_format = data_format

        if dtype is not None:
            try:
                self.dtype_policy = keras.dtype_policies.get(dtype)
            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
            except AttributeError:
                if isinstance(dtype, keras.DTypePolicy):
                    dtype = dtype.name
                self.dtype_policy = keras.DTypePolicy(dtype)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "depth": self.depth,
                "position_size": self.position_size,
                "output_dim": self.output_dim,
                "mlp_ratio": self.mlp_ratio,
                "latent_shape": self.latent_shape,
                "context_shape": self.context_shape,
                "pooled_projection_shape": self.pooled_projection_shape,
                # Without this, a model built with an explicit data format
                # is re-created using the default one.
                "data_format": self.data_format,
            }
        )
        return config