keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198) hide show
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,427 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import keras
17
+ from keras import layers
18
+ from keras import models
19
+ from keras import ops
20
+
21
+ from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
22
+ from keras_hub.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock
23
+ from keras_hub.src.utils.keras_utils import standardize_data_format
24
+
25
+
26
class PatchEmbedding(layers.Layer):
    """Turn a 4D feature map into a sequence of patch embeddings.

    A `Conv2D` whose kernel size and stride both equal `patch_size` projects
    every non-overlapping patch to `hidden_dim` channels. The two spatial
    axes of the result are then flattened into a single sequence axis.

    Args:
        patch_size: int. Side length of each square patch.
        hidden_dim: int. Number of channels produced for every patch.
        data_format: `None` or str. Normalized with
            `standardize_data_format` and forwarded to the convolution.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = int(patch_size)
        self.hidden_dim = int(hidden_dim)
        # Keep the resolved format on the layer so that it can be serialized
        # and consulted in `call`.
        self.data_format = standardize_data_format(data_format)

        self.patch_embedding = layers.Conv2D(
            self.hidden_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="patch_embedding",
        )

    def build(self, input_shape):
        self.patch_embedding.build(input_shape)

    def call(self, inputs):
        """Return patch embeddings shaped `(batch, patches, hidden_dim)`."""
        x = self.patch_embedding(inputs)
        if self.data_format == "channels_first":
            # (b, c, h, w) -> (b, h, w, c) so the flattening below always
            # merges the two spatial axes and leaves channels last.
            x = ops.transpose(x, (0, 2, 3, 1))
        x_shape = ops.shape(x)
        return ops.reshape(
            x, (x_shape[0], x_shape[1] * x_shape[2], x_shape[3])
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "hidden_dim": self.hidden_dim,
                # Without this entry a deserialized layer would fall back to
                # the default data format.
                "data_format": self.data_format,
            }
        )
        return config
60
+
61
+
62
class AdjustablePositionEmbedding(PositionEmbedding):
    """Position embedding stored as a 2D grid that can be center-cropped.

    The parent layer is created with `sequence_length = height * width`. At
    call time the flat embedding table is viewed as a `(height, width,
    feature)` grid and a centered `(height, width)` window of the requested
    size is cut out and flattened again.

    Args:
        height: int. Number of rows of the full embedding grid.
        width: int. Number of columns of the full embedding grid.
        initializer: Initializer forwarded to `PositionEmbedding`.
        **kwargs: Forwarded to `PositionEmbedding`.
    """

    def __init__(
        self,
        height,
        width,
        initializer="glorot_uniform",
        **kwargs,
    ):
        height = int(height)
        width = int(width)
        sequence_length = height * width
        super().__init__(sequence_length, initializer, **kwargs)
        self.height = height
        self.width = width

    def call(self, inputs, height=None, width=None):
        """Return the cropped embedding shaped `(1, height * width, feature)`.

        Args:
            inputs: Tensor. Only its last dimension (the feature size) is
                read; the values are not used.
            height: Optional rows of the crop. Falls back to `self.height`
                when falsy.
            width: Optional columns of the crop. Falls back to `self.width`
                when falsy.
        """
        height = height or self.height
        width = width or self.width
        shape = ops.shape(inputs)
        feature_length = shape[-1]
        # Offsets that center the requested window inside the full grid.
        top = ops.floor_divide(self.height - height, 2)
        left = ops.floor_divide(self.width - width, 2)
        position_embedding = ops.convert_to_tensor(self.position_embeddings)
        position_embedding = ops.reshape(
            position_embedding, (self.height, self.width, feature_length)
        )
        position_embedding = ops.slice(
            position_embedding,
            (top, left, 0),
            (height, width, feature_length),
        )
        position_embedding = ops.reshape(
            position_embedding, (height * width, feature_length)
        )
        # Leading axis of size 1 so the result broadcasts over the batch.
        position_embedding = ops.expand_dims(position_embedding, axis=0)
        return position_embedding

    def get_config(self):
        config = super().get_config()
        # NOTE(review): the parent config presumably carries
        # `sequence_length`; this constructor derives it from height * width
        # and does not accept it, so it must not be passed back in.
        config.pop("sequence_length", None)
        config.update(
            {
                "height": self.height,
                "width": self.width,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
101
+
102
+
103
class TimestepEmbedding(layers.Layer):
    """Embed scalar timesteps with sinusoidal features followed by an MLP.

    The input is multiplied by `frequency_dim // 2` frequencies of the form
    `exp(-log(max_period) * i / (frequency_dim // 2))`. The cosine and sine
    of those products are concatenated (zero-padded by one column when
    `frequency_dim` is odd) and passed through a two-layer MLP.

    Args:
        embedding_dim: int. Width of both MLP layers and of the output.
        frequency_dim: int. Width of the sinusoidal feature vector.
        max_period: float. Controls the lowest frequency of the features.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(
        self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs
    ):
        super().__init__(**kwargs)
        self.embedding_dim = int(embedding_dim)
        self.frequency_dim = int(frequency_dim)
        self.max_period = float(max_period)
        self.half_frequency_dim = self.frequency_dim // 2

        self.mlp = models.Sequential(
            [
                layers.Dense(
                    self.embedding_dim,
                    activation="silu",
                    dtype=self.dtype_policy,
                ),
                layers.Dense(
                    self.embedding_dim,
                    activation=None,
                    dtype=self.dtype_policy,
                ),
            ],
            name="mlp",
        )

    def build(self, inputs_shape):
        # NOTE(review): the leading (batch) entry is dropped and the
        # remaining entries are reused as the prefix of the MLP input shape.
        # This looks written for inputs shaped (batch, 1) - confirm before
        # feeding other ranks.
        embedding_shape = list(inputs_shape)[1:]
        embedding_shape.append(self.frequency_dim)
        self.mlp.build(embedding_shape)

    def _create_timestep_embedding(self, inputs):
        """Return the sinusoidal features for `inputs`."""
        # Do the trigonometry in the promotion of the compute dtype with
        # float32, then cast back at the end.
        compute_dtype = keras.backend.result_type(self.compute_dtype, "float32")
        x = ops.cast(inputs, compute_dtype)
        freqs = ops.exp(
            ops.divide(
                ops.multiply(
                    -math.log(self.max_period),
                    ops.arange(0, self.half_frequency_dim, dtype="float32"),
                ),
                self.half_frequency_dim,
            )
        )
        freqs = ops.cast(freqs, compute_dtype)
        x = ops.multiply(x, ops.expand_dims(freqs, axis=0))
        embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1)
        if self.frequency_dim % 2 != 0:
            # cos/sin give 2 * (frequency_dim // 2) columns; add one zero
            # column so the width equals an odd `frequency_dim`.
            embedding = ops.pad(embedding, [[0, 0], [0, 1]])
        return ops.cast(embedding, self.compute_dtype)

    def call(self, inputs, training=None):
        timestep_embedding = self._create_timestep_embedding(inputs)
        return self.mlp(timestep_embedding, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embedding_dim": self.embedding_dim,
                # Serialized so a non-default value survives a
                # get_config / from_config round-trip.
                "frequency_dim": self.frequency_dim,
                "max_period": self.max_period,
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        # NOTE(review): mirrors `build` - the leading entry is dropped, so
        # for a (batch, 1) input this reports [1, embedding_dim].
        output_shape = list(inputs_shape)[1:]
        output_shape.append(self.embedding_dim)
        return output_shape
167
+
168
+
169
class OutputLayer(layers.Layer):
    """Final modulated normalization and projection.

    A SiLU + Dense stack maps the timestep embedding to a shift and a scale
    vector of width `hidden_dim`. The inputs are layer-normalized (without
    learned scale or offset), modulated as `x * (1 + scale) + shift` and
    projected to `output_dim` by a Dense layer.

    Args:
        hidden_dim: int. Feature width of the inputs.
        output_dim: int. Feature width produced by the final projection.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(self, hidden_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        # Cast like the sibling layers in this file so that the reshape in
        # `call` always receives integers.
        self.hidden_dim = int(hidden_dim)
        self.output_dim = int(output_dim)
        # One shift vector and one scale vector.
        self.num_modulation = 2

        self.adaptive_norm_modulation = models.Sequential(
            [
                layers.Activation("silu", dtype=self.dtype_policy),
                layers.Dense(
                    self.num_modulation * self.hidden_dim,
                    dtype=self.dtype_policy,
                ),
            ],
            name="adaptive_norm_modulation",
        )
        self.norm = layers.LayerNormalization(
            epsilon=1e-6,
            center=False,
            scale=False,
            dtype=self.dtype_policy,
            name="norm",
        )
        self.output_dense = layers.Dense(
            self.output_dim,  # patch_size ** 2 * input_channels
            use_bias=True,
            dtype=self.dtype_policy,
            name="output_dense",
        )

    def build(self, inputs_shape, timestep_embedding_shape):
        self.adaptive_norm_modulation.build(timestep_embedding_shape)
        self.norm.build(inputs_shape)
        self.output_dense.build(inputs_shape)

    def _modulate(self, inputs, shift, scale):
        """Return `inputs * (1 + scale) + shift`.

        `shift` and `scale` get a new axis 1 so they broadcast along that
        axis of `inputs`.
        """
        shift = ops.expand_dims(shift, axis=1)
        scale = ops.expand_dims(scale, axis=1)
        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)

    def call(self, inputs, timestep_embedding, training=None):
        x = inputs
        modulation = self.adaptive_norm_modulation(
            timestep_embedding, training=training
        )
        # Split the Dense output into its shift and scale halves.
        modulation = ops.reshape(
            modulation, (-1, self.num_modulation, self.hidden_dim)
        )
        shift, scale = ops.unstack(modulation, self.num_modulation, axis=1)
        x = self._modulate(self.norm(x), shift, scale)
        x = self.output_dense(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "output_dim": self.output_dim,
            }
        )
        return config
229
+
230
+
231
class Unpatch(layers.Layer):
    """Reassemble a sequence of flattened patches into a 4D feature map.

    Args:
        patch_size: int. Side length of each square patch.
        output_dim: int. Number of channels of the reassembled map.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(self, patch_size, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = int(patch_size)
        self.output_dim = int(output_dim)

    def call(self, inputs, height, width):
        """Return a tensor shaped `(-1, height * p, width * p, output_dim)`.

        Args:
            inputs: Tensor that can be reshaped to
                `(-1, height, width, p, p, output_dim)` with
                `p = self.patch_size`.
            height: Number of patch rows.
            width: Number of patch columns.
        """
        p = self.patch_size
        channels = self.output_dim
        patches = ops.reshape(inputs, (-1, height, width, p, p, channels))
        # Interleave grid and in-patch axes:
        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
        patches = ops.transpose(patches, (0, 1, 3, 2, 4, 5))
        merged_shape = (-1, height * p, width * p, channels)
        return ops.reshape(patches, merged_shape)

    def get_config(self):
        config = super().get_config()
        config["patch_size"] = self.patch_size
        config["output_dim"] = self.output_dim
        return config

    def compute_output_shape(self, inputs_shape):
        # Spatial sizes depend on the call-time height/width, so they are
        # reported as unknown.
        batch_size = list(inputs_shape)[0]
        return [batch_size, None, None, self.output_dim]
264
+
265
+
266
class MMDiT(keras.Model):
    """MMDiT model assembled with the Keras functional API.

    The latent input is patch-embedded and combined with a cropped position
    embedding. The context input is projected to `hidden_dim`. The pooled
    projection and the timestep are embedded and summed into a single
    conditioning vector. A stack of `MMDiTBlock`s then updates the latent
    and context streams, and an `OutputLayer` followed by `Unpatch` maps
    the latent stream back to a spatial tensor.

    Args:
        patch_size: int. Side length of the square patches.
        num_heads: int. Forwarded to every `MMDiTBlock`.
        hidden_dim: int. Feature width used throughout the model.
        depth: int. Number of `MMDiTBlock`s.
        position_size: int. Height and width of the full position
            embedding grid.
        output_dim: int. Number of channels of the output tensor.
        mlp_ratio: float. Forwarded to every `MMDiTBlock`.
        latent_shape: tuple. Shape of the latent input without the batch
            axis; must not contain `None`.
        context_shape: tuple. Shape of the context input without the batch
            axis.
        pooled_projection_shape: tuple. Shape of the pooled projection
            input without the batch axis.
        data_format: `None` or str. Only `"channels_last"` is accepted
            after normalization.
        dtype: Optional dtype or dtype policy applied to the sublayers and
            to the model.
        **kwargs: Forwarded to `keras.Model`.

    Raises:
        ValueError: If `latent_shape` contains `None`.
        NotImplementedError: If the resolved data format is not
            `"channels_last"`.
    """

    def __init__(
        self,
        patch_size,
        num_heads,
        hidden_dim,
        depth,
        position_size,
        output_dim,
        mlp_ratio=4.0,
        latent_shape=(64, 64, 16),
        context_shape=(1024, 4096),
        pooled_projection_shape=(2048,),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        if None in latent_shape:
            raise ValueError(
                "`latent_shape` must be fully specified. "
                f"Received: latent_shape={latent_shape}"
            )
        # Size of the patch grid produced by the patch embedding.
        grid_height = latent_shape[0] // patch_size
        grid_width = latent_shape[1] // patch_size
        # Every token is projected to one flattened output patch.
        flat_patch_dim = patch_size**2 * output_dim
        data_format = standardize_data_format(data_format)
        if data_format != "channels_last":
            raise NotImplementedError(
                "Currently only 'channels_last' is supported."
            )

        # === Layers ===
        self.patch_embedding = PatchEmbedding(
            patch_size,
            hidden_dim,
            data_format=data_format,
            dtype=dtype,
            name="patch_embedding",
        )
        self.position_embedding_add = layers.Add(
            dtype=dtype, name="position_embedding_add"
        )
        self.position_embedding = AdjustablePositionEmbedding(
            position_size, position_size, dtype=dtype, name="position_embedding"
        )
        self.context_embedding = layers.Dense(
            hidden_dim,
            dtype=dtype,
            name="context_embedding",
        )
        self.vector_embedding = models.Sequential(
            [
                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
                layers.Dense(hidden_dim, activation=None, dtype=dtype),
            ],
            name="vector_embedding",
        )
        self.vector_embedding_add = layers.Add(
            dtype=dtype, name="vector_embedding_add"
        )
        self.timestep_embedding = TimestepEmbedding(
            hidden_dim, dtype=dtype, name="timestep_embedding"
        )
        last_block_index = depth - 1
        joint_blocks = []
        for index in range(depth):
            joint_blocks.append(
                MMDiTBlock(
                    num_heads,
                    hidden_dim,
                    mlp_ratio,
                    # Every block except the last one is created with the
                    # context projection enabled.
                    use_context_projection=index != last_block_index,
                    dtype=dtype,
                    name=f"joint_block_{index}",
                )
            )
        self.joint_blocks = joint_blocks
        self.output_layer = OutputLayer(
            hidden_dim, flat_patch_dim, dtype=dtype, name="output_layer"
        )
        self.unpatch = Unpatch(
            patch_size, output_dim, dtype=dtype, name="unpatch"
        )

        # === Functional Model ===
        latent_inputs = layers.Input(shape=latent_shape, name="latent")
        context_inputs = layers.Input(shape=context_shape, name="context")
        pooled_projection_inputs = layers.Input(
            shape=pooled_projection_shape, name="pooled_projection"
        )
        timestep_inputs = layers.Input(shape=(1,), name="timestep")

        # Embeddings.
        x = self.patch_embedding(latent_inputs)
        cropped_positions = self.position_embedding(
            x, height=grid_height, width=grid_width
        )
        x = self.position_embedding_add([x, cropped_positions])
        context = self.context_embedding(context_inputs)
        pooled_projection = self.vector_embedding(pooled_projection_inputs)
        conditioning = self.timestep_embedding(timestep_inputs)
        conditioning = self.vector_embedding_add(
            [conditioning, pooled_projection]
        )

        # Blocks.
        for block in self.joint_blocks:
            block_outputs = block(x, context, conditioning)
            if block.use_context_projection:
                # These blocks return the updated latent and context.
                x, context = block_outputs
            else:
                x = block_outputs

        # Output layer.
        x = self.output_layer(x, conditioning)
        outputs = self.unpatch(x, height=grid_height, width=grid_width)

        model_inputs = {
            "latent": latent_inputs,
            "context": context_inputs,
            "pooled_projection": pooled_projection_inputs,
            "timestep": timestep_inputs,
        }
        super().__init__(inputs=model_inputs, outputs=outputs, **kwargs)

        # === Config ===
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.depth = depth
        self.position_size = position_size
        self.output_dim = output_dim
        self.mlp_ratio = mlp_ratio
        self.latent_shape = latent_shape
        self.context_shape = context_shape
        self.pooled_projection_shape = pooled_projection_shape

        if dtype is not None:
            try:
                self.dtype_policy = keras.dtype_policies.get(dtype)
            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
            except AttributeError:
                if isinstance(dtype, keras.DTypePolicy):
                    dtype = dtype.name
                self.dtype_policy = keras.DTypePolicy(dtype)

    def get_config(self):
        config = super().get_config()
        config.update(
            patch_size=self.patch_size,
            num_heads=self.num_heads,
            hidden_dim=self.hidden_dim,
            depth=self.depth,
            position_size=self.position_size,
            output_dim=self.output_dim,
            mlp_ratio=self.mlp_ratio,
            latent_shape=self.latent_shape,
            context_shape=self.context_shape,
            pooled_projection_shape=self.pooled_projection_shape,
        )
        return config