keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198) hide show
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,317 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ from keras import layers
17
+ from keras import models
18
+ from keras import ops
19
+
20
+ from keras_hub.src.utils.keras_utils import gelu_approximate
21
+
22
+
23
class DismantledBlock(layers.Layer):
    """One stream of an MMDiT block, split around the attention op.

    The layer is invoked twice per forward pass through `call()`:

    1. With `pre_attention=True` it derives shift/scale/gate modulations
       from `timestep_embedding`, normalizes and modulates `inputs`, and
       projects them to per-head query/key/value tensors.
    2. With `pre_attention=False` it takes the attention result together
       with the intermediates returned by step 1 and applies the gated
       output projection and the gated MLP residual branch.

    The attention itself is computed by the caller in between.

    Args:
        num_heads: int. Number of attention heads.
        hidden_dim: int. Width of the stream. Must be divisible by
            `num_heads`.
        mlp_ratio: float. The MLP hidden width is
            `int(hidden_dim * mlp_ratio)`.
        use_projection: bool. When `True`, six modulation vectors are
            produced and the post-attention path (output projection, second
            norm and MLP) is available. When `False`, only shift/scale for
            the pre-attention norm are produced and only step 1 may be
            used.
        **kwargs: Forwarded to `keras.layers.Layer`.

    Raises:
        ValueError: If `hidden_dim` is not divisible by `num_heads`.
    """

    def __init__(
        self,
        num_heads,
        hidden_dim,
        mlp_ratio=4.0,
        use_projection=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # A truncated `head_dim` would make the qkv reshape either fail or
        # silently change the sequence length, so reject it up front.
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                "`hidden_dim` must be divisible by a positive `num_heads`. "
                f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
            )
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.mlp_ratio = mlp_ratio
        self.use_projection = use_projection

        head_dim = hidden_dim // num_heads
        self.head_dim = head_dim
        mlp_hidden_dim = int(hidden_dim * mlp_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        # shift/scale/gate for attention and MLP, or only shift/scale for
        # the pre-attention norm when there is no post-attention path.
        num_modulations = 6 if use_projection else 2
        self.num_modulations = num_modulations

        self.adaptive_norm_modulation = models.Sequential(
            [
                layers.Activation("silu", dtype=self.dtype_policy),
                layers.Dense(
                    num_modulations * hidden_dim, dtype=self.dtype_policy
                ),
            ],
            name="adaptive_norm_modulation",
        )
        self.norm1 = layers.LayerNormalization(
            epsilon=1e-6,
            center=False,
            scale=False,
            dtype=self.dtype_policy,
            name="norm1",
        )
        self.attention_qkv = layers.Dense(
            hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
        )
        if use_projection:
            # These layers are only reached from `_compute_post_attention`.
            self.attention_proj = layers.Dense(
                hidden_dim, dtype=self.dtype_policy, name="attention_proj"
            )
            self.norm2 = layers.LayerNormalization(
                epsilon=1e-6,
                center=False,
                scale=False,
                dtype=self.dtype_policy,
                name="norm2",
            )
            self.mlp = models.Sequential(
                [
                    layers.Dense(
                        mlp_hidden_dim,
                        activation=gelu_approximate,
                        dtype=self.dtype_policy,
                    ),
                    layers.Dense(
                        hidden_dim,
                        dtype=self.dtype_policy,
                    ),
                ],
                name="mlp",
            )

    def build(self, inputs_shape, timestep_embedding):
        """Builds all sublayers.

        Args:
            inputs_shape: Shape of the stream tensor passed as `inputs`.
            timestep_embedding: Shape of the conditioning tensor (the
                argument carries a shape here, not a tensor).
        """
        self.adaptive_norm_modulation.build(timestep_embedding)
        self.attention_qkv.build(inputs_shape)
        self.norm1.build(inputs_shape)
        if self.use_projection:
            self.attention_proj.build(inputs_shape)
            self.norm2.build(inputs_shape)
            self.mlp.build(inputs_shape)

    def _modulate(self, inputs, shift, scale):
        """Returns `inputs * (1 + scale) + shift`, broadcast over axis 1."""
        shift = ops.expand_dims(shift, axis=1)
        scale = ops.expand_dims(scale, axis=1)
        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)

    def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
        """Produces per-head q/k/v and, if projecting, the intermediates.

        Returns:
            `(q, k, v)` when `use_projection` is `False`; otherwise
            `((q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp,
            gate_mlp))`, where the second tuple must be handed back to the
            post-attention step.
        """
        batch_size = ops.shape(inputs)[0]
        modulation = self.adaptive_norm_modulation(
            timestep_embedding, training=training
        )
        modulation = ops.reshape(
            modulation, (batch_size, self.num_modulations, self.hidden_dim)
        )
        modulations = ops.unstack(modulation, self.num_modulations, axis=1)
        # The first two vectors are always the attention shift and scale.
        shift_msa, scale_msa = modulations[0], modulations[1]

        qkv = self.attention_qkv(
            self._modulate(self.norm1(inputs), shift_msa, scale_msa),
            training=training,
        )
        qkv = ops.reshape(
            qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
        )
        q, k, v = ops.unstack(qkv, 3, axis=2)
        if not self.use_projection:
            return (q, k, v)

        gate_msa, shift_mlp, scale_mlp, gate_mlp = modulations[2:]
        return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)

    def _compute_post_attention(
        self, inputs, inputs_intermediates, training=None
    ):
        """Applies the gated attention projection and gated MLP residuals.

        Args:
            inputs: Attention output for this stream.
            inputs_intermediates: The second tuple returned by
                `_compute_pre_attention`.
            training: Forwarded to the sublayers.

        Raises:
            ValueError: If the layer was created with
                `use_projection=False`.
        """
        if not self.use_projection:
            raise ValueError(
                "The post-attention step is unavailable when "
                "`use_projection=False`."
            )
        x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
        attn = self.attention_proj(inputs, training=training)
        x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
        x = ops.add(
            x,
            ops.multiply(
                ops.expand_dims(gate_mlp, axis=1),
                self.mlp(
                    self._modulate(self.norm2(x), shift_mlp, scale_mlp),
                    training=training,
                ),
            ),
        )
        return x

    def call(
        self,
        inputs,
        timestep_embedding=None,
        inputs_intermediates=None,
        pre_attention=True,
        training=None,
    ):
        """Dispatches to the pre- or post-attention half of the block.

        Args:
            inputs: The stream tensor (pre-attention) or the attention
                output for this stream (post-attention).
            timestep_embedding: Conditioning tensor; used only when
                `pre_attention=True`.
            inputs_intermediates: Intermediates from the pre-attention
                step; used only when `pre_attention=False`.
            pre_attention: bool. Selects which half to run.
            training: Forwarded to the sublayers.
        """
        if pre_attention:
            return self._compute_pre_attention(
                inputs, timestep_embedding, training=training
            )
        return self._compute_post_attention(
            inputs, inputs_intermediates, training=training
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "mlp_ratio": self.mlp_ratio,
                "use_projection": self.use_projection,
            }
        )
        return config
194
+
195
+
196
class MMDiTBlock(layers.Layer):
    """Joint-attention block over an image stream and a context stream.

    Each stream owns a `DismantledBlock`. Both streams are projected to
    query/key/value, concatenated along the sequence axis (context first),
    attended over jointly, split back apart and finished by their own
    post-attention step.

    Args:
        num_heads: int. Number of attention heads.
        hidden_dim: int. Width of both streams.
        mlp_ratio: float. Forwarded to both `DismantledBlock`s.
        use_context_projection: bool. When `False` the context stream only
            contributes q/k/v and the layer returns just the image stream.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(
        self,
        num_heads,
        hidden_dim,
        mlp_ratio=4.0,
        use_context_projection=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.mlp_ratio = mlp_ratio
        self.use_context_projection = use_context_projection

        self.head_dim = hidden_dim // num_heads
        self._inverse_sqrt_key_dim = 1.0 / math.sqrt(self.head_dim)
        # a: batch, b: query position, e: key position, c: head, d: head dim.
        self._dot_product_equation = "aecd,abcd->acbe"
        self._combine_equation = "acbe,aecd->abcd"

        shared_kwargs = dict(
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_ratio=mlp_ratio,
            dtype=self.dtype_policy,
        )
        self.x_block = DismantledBlock(
            use_projection=True, name="x_block", **shared_kwargs
        )
        self.context_block = DismantledBlock(
            use_projection=use_context_projection,
            name="context_block",
            **shared_kwargs,
        )

    def build(self, inputs_shape, context_shape, timestep_embedding_shape):
        self.x_block.build(inputs_shape, timestep_embedding_shape)
        self.context_block.build(context_shape, timestep_embedding_shape)

    def _compute_attention(self, query, key, value):
        """Scaled dot-product attention with heads merged in the output."""
        scale = ops.cast(self._inverse_sqrt_key_dim, query.dtype)
        scaled_query = ops.multiply(query, scale)
        weights = ops.einsum(self._dot_product_equation, key, scaled_query)
        weights = ops.nn.softmax(weights, axis=-1)
        attended = ops.einsum(self._combine_equation, weights, value)
        merged_shape = (
            ops.shape(attended)[0],
            -1,
            self.num_heads * self.head_dim,
        )
        return ops.reshape(attended, merged_shape)

    def call(self, inputs, context, timestep_embedding, training=None):
        # Pre-attention: the context block returns intermediates only when
        # it also has a post-attention path.
        context_outputs = self.context_block(
            context,
            timestep_embedding=timestep_embedding,
            training=training,
        )
        if self.use_context_projection:
            context_qkv, context_intermediates = context_outputs
        else:
            context_qkv = context_outputs
        context_len = ops.shape(context_qkv[0])[1]
        x_qkv, x_intermediates = self.x_block(
            inputs, timestep_embedding=timestep_embedding, training=training
        )

        # Joint attention over the concatenated [context, x] sequence.
        q, k, v = (
            ops.concatenate([context_part, x_part], axis=1)
            for context_part, x_part in zip(context_qkv, x_qkv)
        )
        attention = self._compute_attention(q, k, v)
        context_attention = attention[:, :context_len]
        x_attention = attention[:, context_len:]

        # Post-attention.
        x = self.x_block(
            x_attention,
            inputs_intermediates=x_intermediates,
            pre_attention=False,
            training=training,
        )
        if not self.use_context_projection:
            return x
        context = self.context_block(
            context_attention,
            inputs_intermediates=context_intermediates,
            pre_attention=False,
            training=training,
        )
        return x, context

    def get_config(self):
        config = super().get_config()
        config["num_heads"] = self.num_heads
        config["hidden_dim"] = self.hidden_dim
        config["mlp_ratio"] = self.mlp_ratio
        config["use_context_projection"] = self.use_context_projection
        return config

    def compute_output_shape(
        self, inputs_shape, context_shape, timestep_embedding_shape
    ):
        if not self.use_context_projection:
            return inputs_shape
        return inputs_shape, context_shape
@@ -0,0 +1,74 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+
16
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
17
+ from keras_hub.src.models.preprocessor import Preprocessor
18
+ from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
19
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
20
+
21
+
22
class T5XXLPreprocessor(Preprocessor):
    """Tokenizes text with a `T5Tokenizer` and packs it with a
    `StartEndPacker`.

    Args:
        tokenizer: The tokenizer instance applied to the raw inputs.
        sequence_length: int. Default packed length.
        add_start_token: bool. Passed to the packer as `add_start_value`.
        add_end_token: bool. Passed to the packer as `add_end_value`.
        **kwargs: Forwarded to `Preprocessor`.
    """

    tokenizer_cls = T5Tokenizer

    def __init__(
        self,
        tokenizer,
        sequence_length=256,
        add_start_token=False,
        add_end_token=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        self.add_start_token = add_start_token
        self.add_end_token = add_end_token

    def build(self, input_shape):
        # The packer is created lazily so that the tokenizer's special
        # token ids are available even when restoring a saved model.
        tokenizer = self.tokenizer
        self.packer = StartEndPacker(
            start_value=tokenizer.start_token_id,
            end_value=tokenizer.end_token_id,
            pad_value=tokenizer.pad_token_id,
            sequence_length=self.sequence_length,
            return_padding_mask=True,
        )
        self.built = True

    @preprocessing_function
    def call(self, x, y=None, sample_weight=None, sequence_length=None):
        """Returns `{"token_ids", "padding_mask"}` packed with `y` and
        `sample_weight`.

        A falsy `sequence_length` falls back to `self.sequence_length`.
        """
        tokens = self.tokenizer(x)
        token_ids, padding_mask = self.packer(
            tokens,
            sequence_length=sequence_length or self.sequence_length,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
        features = {"token_ids": token_ids, "padding_mask": padding_mask}
        return keras.utils.pack_x_y_sample_weight(features, y, sample_weight)

    def get_config(self):
        config = super().get_config()
        config["sequence_length"] = self.sequence_length
        config["add_start_token"] = self.add_start_token
        config["add_end_token"] = self.add_end_token
        return config
@@ -0,0 +1,155 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+
16
+ from keras_hub.src.layers.modeling.reversible_embedding import (
17
+ ReversibleEmbedding,
18
+ )
19
+ from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
20
+ from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer
21
+
22
+
23
class T5XXLTextEncoder(keras.Model):
    """Encoder-only T5 stack built as a functional `keras.Model`.

    The model takes a dict with `"encoder_token_ids"` and
    `"encoder_padding_mask"` (both int32, shape `(batch, sequence)`) and
    returns the layer-normalized output of the last transformer layer.

    Args:
        vocabulary_size: int. Size of the token embedding table.
        num_layers: int. Number of `T5TransformerLayer`s.
        num_heads: int. Number of attention heads per layer.
        hidden_dim: int. Embedding and transformer width.
        intermediate_dim: int. Feed-forward width inside each layer.
        key_value_dim: int or None. Per-head key/value width. A falsy value
            falls back to `hidden_dim // num_heads`.
        dropout: float. Rate shared by all dropout layers.
        activation: Activation identifier for the feed-forward network.
        use_gated_activation: bool. Forwarded to each transformer layer.
        layer_norm_epsilon: float. Epsilon for all layer norms.
        tie_embedding_weights: bool. Forwarded to `ReversibleEmbedding`.
        dtype: Optional dtype or dtype policy for every sublayer.
        **kwargs: Forwarded to `keras.Model`.
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        key_value_dim=None,
        dropout=0.1,
        activation="relu",
        use_gated_activation=True,
        layer_norm_epsilon=1e-06,
        tie_embedding_weights=True,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            tie_weights=tie_embedding_weights,
            embeddings_initializer=keras.initializers.TruncatedNormal(1.0),
            dtype=dtype,
            name="token_embedding",
        )
        self.encoder_embedding_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="encoder_embedding_dropout",
        )
        self.encoder_transformer_layers = []
        for i in range(num_layers):
            layer = T5TransformerLayer(
                is_decoder=False,
                hidden_dim=hidden_dim,
                intermediate_dim=intermediate_dim,
                key_value_dim=key_value_dim or hidden_dim // num_heads,
                dropout=dropout,
                activation=activation,
                layer_norm_epsilon=layer_norm_epsilon,
                num_heads=num_heads,
                use_gated_activation=use_gated_activation,
                # Only the first layer is asked to create the relative
                # attention bias; later layers receive it as an input.
                use_relative_attention_bias=bool(i == 0),
                dtype=dtype,
                name=f"transformer_encoder_layer_{i}",
            )
            self.encoder_transformer_layers.append(layer)
        self.encoder_layer_norm = T5LayerNorm(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="encoder_output_layer_norm",
        )
        self.encoder_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="encoder_output_dropout",
        )

        # === Functional Model ===
        encoder_token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="encoder_token_ids"
        )
        encoder_padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="encoder_padding_mask"
        )
        # Encoder.
        x = self.token_embedding(encoder_token_id_input)
        x = self.encoder_embedding_dropout(x)
        encoder_attention_mask = encoder_padding_mask_input[:, None, :]
        position_bias = None
        for transformer_layer in self.encoder_transformer_layers:
            output = transformer_layer(
                x,
                attention_mask=encoder_attention_mask,
                position_bias=position_bias,
                use_causal_mask=False,
            )
            if isinstance(output, tuple):
                x, position_bias = output
            else:
                # A layer that returns only the hidden states must still
                # advance `x`; otherwise its computation is silently lost.
                x = output
        x = self.encoder_layer_norm(x)
        x = self.encoder_dropout(x)
        encoder_output = x

        super().__init__(
            {
                "encoder_token_ids": encoder_token_id_input,
                "encoder_padding_mask": encoder_padding_mask_input,
            },
            outputs=encoder_output,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.activation = keras.activations.get(activation)
        self.key_value_dim = key_value_dim
        self.dropout = dropout
        self.use_gated_activation = use_gated_activation
        self.layer_norm_epsilon = layer_norm_epsilon
        self.tie_embedding_weights = tie_embedding_weights

        if dtype is not None:
            try:
                self.dtype_policy = keras.dtype_policies.get(dtype)
            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
            except AttributeError:
                if isinstance(dtype, keras.DTypePolicy):
                    dtype = dtype.name
                self.dtype_policy = keras.DTypePolicy(dtype)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "activation": keras.activations.serialize(self.activation),
                "key_value_dim": self.key_value_dim,
                "dropout": self.dropout,
                "use_gated_activation": self.use_gated_activation,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "tie_embedding_weights": self.tie_embedding_weights,
            }
        )
        return config
@@ -0,0 +1,126 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ from keras import layers
17
+ from keras import ops
18
+
19
+ from keras_hub.src.utils.keras_utils import standardize_data_format
20
+
21
+
22
class VAEAttention(layers.Layer):
    """Single-head self-attention over the spatial positions of a feature
    map, with a residual connection.

    The input is group-normalized, projected to query/key/value with 1x1
    convolutions, flattened to `(batch, height * width, filters)` tokens,
    attended over, projected back with a 1x1 convolution and added to the
    input. The residual add and the reshape back to the input layout mean
    `filters` has to equal the number of input channels.

    Args:
        filters: int. Number of channels of the q/k/v/output projections.
        groups: int. Number of groups for `GroupNormalization`.
        data_format: `"channels_last"`, `"channels_first"` or `None`.
            Resolved through `standardize_data_format`.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(self, filters, groups=32, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.data_format = standardize_data_format(data_format)
        gn_axis = -1 if self.data_format == "channels_last" else 1

        self.group_norm = layers.GroupNormalization(
            groups=groups,
            axis=gn_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="group_norm",
        )
        self.query_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="query_conv2d",
        )
        self.key_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="key_conv2d",
        )
        self.value_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="value_conv2d",
        )
        # Softmax is pinned to float32 regardless of the layer's policy;
        # the result is cast back to the compute dtype in `call()`.
        self.softmax = layers.Softmax(dtype="float32")
        self.output_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="output_conv2d",
        )

        self.groups = groups
        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))

    def build(self, input_shape):
        self.group_norm.build(input_shape)
        self.query_conv2d.build(input_shape)
        self.key_conv2d.build(input_shape)
        self.value_conv2d.build(input_shape)
        self.output_conv2d.build(input_shape)

    def call(self, inputs, training=None):
        channels_first = self.data_format == "channels_first"

        x = self.group_norm(inputs)
        query = self.query_conv2d(x)
        key = self.key_conv2d(x)
        value = self.value_conv2d(x)

        if channels_first:
            # Move channels last so that flattening yields one token per
            # spatial position.
            query = ops.transpose(query, (0, 2, 3, 1))
            key = ops.transpose(key, (0, 2, 3, 1))
            value = ops.transpose(value, (0, 2, 3, 1))
        shape = ops.shape(inputs)
        b = shape[0]
        query = ops.reshape(query, (b, -1, self.filters))
        key = ops.reshape(key, (b, -1, self.filters))
        value = ops.reshape(value, (b, -1, self.filters))

        # Compute attention.
        query = ops.multiply(
            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
        )
        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
        attention_scores = ops.einsum("abc,adc->abd", query, key)
        attention_scores = ops.cast(
            self.softmax(attention_scores), self.compute_dtype
        )
        # [B, H1 * W1, C], [B, H0 * W0, H1 * W1] -> [B, H0 * W0, C]
        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)

        if channels_first:
            # Tokens are laid out as (H, W) with channels last. Restore the
            # spatial grid first, then move channels back to axis 1 so the
            # channels-first output convolution and the residual add see
            # the same layout as `inputs`.
            x = ops.reshape(
                attention_output, (b, shape[2], shape[3], self.filters)
            )
            x = ops.transpose(x, (0, 3, 1, 2))
        else:
            x = ops.reshape(attention_output, shape)

        x = self.output_conv2d(x)
        x = ops.add(x, inputs)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "groups": self.groups,
                # Persist the resolved layout so a reloaded layer does not
                # depend on the global image data format.
                "data_format": self.data_format,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape