keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev20240915160609__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. keras_hub/api/__init__.py +1 -0
  2. keras_hub/api/bounding_box/__init__.py +36 -0
  3. keras_hub/api/layers/__init__.py +14 -0
  4. keras_hub/api/models/__init__.py +97 -48
  5. keras_hub/api/tokenizers/__init__.py +30 -0
  6. keras_hub/src/bounding_box/__init__.py +13 -0
  7. keras_hub/src/bounding_box/converters.py +529 -0
  8. keras_hub/src/bounding_box/formats.py +162 -0
  9. keras_hub/src/bounding_box/iou.py +263 -0
  10. keras_hub/src/bounding_box/to_dense.py +95 -0
  11. keras_hub/src/bounding_box/to_ragged.py +99 -0
  12. keras_hub/src/bounding_box/utils.py +194 -0
  13. keras_hub/src/bounding_box/validate_format.py +99 -0
  14. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  15. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  16. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  17. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  18. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  19. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  20. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  21. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  22. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  23. keras_hub/src/models/albert/__init__.py +1 -2
  24. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  25. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  26. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  27. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  28. keras_hub/src/models/backbone.py +12 -34
  29. keras_hub/src/models/bart/__init__.py +1 -2
  30. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  31. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  32. keras_hub/src/models/bert/__init__.py +1 -5
  33. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  34. keras_hub/src/models/bert/bert_presets.py +1 -4
  35. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  36. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  37. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  38. keras_hub/src/models/bloom/__init__.py +1 -2
  39. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  40. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  41. keras_hub/src/models/causal_lm.py +10 -29
  42. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  43. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  44. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  45. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  46. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  47. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  48. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  49. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  50. keras_hub/src/models/distil_bert/__init__.py +1 -4
  51. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  52. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  53. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  54. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  55. keras_hub/src/models/efficientnet/__init__.py +13 -0
  56. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  57. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  58. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  59. keras_hub/src/models/electra/__init__.py +1 -2
  60. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  61. keras_hub/src/models/f_net/__init__.py +1 -2
  62. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  63. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  64. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  65. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  66. keras_hub/src/models/falcon/__init__.py +1 -2
  67. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  68. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  69. keras_hub/src/models/gemma/__init__.py +1 -2
  70. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  71. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  72. keras_hub/src/models/gpt2/__init__.py +1 -2
  73. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  74. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  75. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  76. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  77. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  78. keras_hub/src/models/image_classifier.py +0 -5
  79. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  80. keras_hub/src/models/llama/__init__.py +1 -2
  81. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  82. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  83. keras_hub/src/models/llama3/__init__.py +1 -2
  84. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  85. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  86. keras_hub/src/models/masked_lm.py +0 -2
  87. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  88. keras_hub/src/models/mistral/__init__.py +1 -2
  89. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  90. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  91. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  92. keras_hub/src/models/mobilenet/__init__.py +13 -0
  93. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  94. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  95. keras_hub/src/models/opt/__init__.py +1 -2
  96. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  97. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  98. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  99. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  100. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  101. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  102. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  103. keras_hub/src/models/phi3/__init__.py +1 -2
  104. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  105. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  106. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  107. keras_hub/src/models/preprocessor.py +72 -83
  108. keras_hub/src/models/resnet/__init__.py +6 -0
  109. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  110. keras_hub/src/models/resnet/resnet_image_classifier.py +24 -3
  111. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  112. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  113. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  114. keras_hub/src/models/roberta/__init__.py +1 -2
  115. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  116. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  117. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  118. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  119. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  120. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  121. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  122. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  123. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  124. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  125. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  126. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  127. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  128. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  129. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  130. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  131. keras_hub/src/models/t5/__init__.py +1 -2
  132. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  133. keras_hub/src/models/task.py +71 -116
  134. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  135. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  136. keras_hub/src/models/whisper/__init__.py +1 -2
  137. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  138. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  139. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  140. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  141. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  142. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  143. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  144. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  145. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  146. keras_hub/src/tests/test_case.py +38 -0
  147. keras_hub/src/tokenizers/byte_pair_tokenizer.py +29 -17
  148. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  149. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +19 -7
  150. keras_hub/src/tokenizers/tokenizer.py +67 -32
  151. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  152. keras_hub/src/tokenizers/word_piece_tokenizer.py +33 -47
  153. keras_hub/src/utils/keras_utils.py +0 -50
  154. keras_hub/src/utils/preset_utils.py +220 -67
  155. keras_hub/src/utils/tensor_utils.py +187 -69
  156. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  157. keras_hub/src/utils/timm/preset_loader.py +66 -0
  158. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  159. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  160. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  161. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  162. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  163. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  164. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  165. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  166. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  167. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  168. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  169. keras_hub/src/version_utils.py +1 -1
  170. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/METADATA +1 -2
  171. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/RECORD +173 -143
  172. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/WHEEL +1 -1
  173. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  174. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  175. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  176. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  177. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  178. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  179. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  180. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  181. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  182. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  183. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  184. keras_hub/src/utils/timm/convert.py +0 -37
  185. keras_hub/src/utils/transformers/convert.py +0 -101
  186. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,269 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+
16
+ from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
18
+ from keras_hub.src.models.preprocessor import Preprocessor
19
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
20
+ from keras_hub.src.utils.tensor_utils import strip_to_ragged
21
+
22
+ try:
23
+ import tensorflow as tf
24
+ except ImportError:
25
+ tf = None
26
+
27
+
28
@keras_hub_export("keras_hub.models.Seq2SeqLMPreprocessor")
class Seq2SeqLMPreprocessor(Preprocessor):
    """Base class for seq2seq language modeling preprocessing layers.

    `Seq2SeqLMPreprocessor` tasks wrap a `keras_hub.tokenizers.Tokenizer` to
    create a preprocessing layer for seq2seq language modeling tasks. It is
    intended to be paired with a `keras_hub.models.Seq2SeqLM` task.

    All `Seq2SeqLMPreprocessor` layers take as input a dictionary with keys
    `"encoder_text"` and `"decoder_text"`.

    This layer will always output a `(x, y, sample_weight)` tuple, where `x`
    is a dictionary with the tokenized inputs, `y` contains the tokens from `x`
    offset by 1, and `sample_weight` marks where `y` contains padded
    values. The exact contents of `x` will vary depending on the model being
    used.

    A `Seq2SeqLMPreprocessor` contains two extra methods, `generate_preprocess`
    and `generate_postprocess` for use with generation. See examples below.

    All `Seq2SeqLMPreprocessor` tasks include a `from_preset()` constructor
    which can be used to load a pre-trained config and vocabularies. You can
    call the `from_preset()` constructor directly on this base class, in which
    case the correct class for your model will be automatically instantiated.

    Args:
        tokenizer: The tokenizer used to map strings to token ids. It must
            expose `start_token_id`, `end_token_id`, `pad_token_id`,
            `special_token_ids` and `detokenize()`.
        encoder_sequence_length: int. The padded length of encoder inputs.
        decoder_sequence_length: int. The padded length of decoder inputs.
        **kwargs: Forwarded to the `Preprocessor` base class.

    Examples.
    ```python
    preprocessor = keras_hub.models.Seq2SeqLMPreprocessor.from_preset(
        "bart_base_en",
        encoder_sequence_length=256,
        decoder_sequence_length=256,
    )

    # Tokenize, mask and pack a single sentence.
    x = {
        "encoder_text": "The fox was sleeping.",
        "decoder_text": "The fox was awake.",
    }
    x, y, sample_weight = preprocessor(x)

    # Tokenize and pad/truncate a batch of labeled sentences.
    x = {
        "encoder_text": ["The fox was sleeping."],
        "decoder_text": ["The fox was awake."],
    }
    x, y, sample_weight = preprocessor(x)

    # With a `tf.data.Dataset`.
    ds = tf.data.Dataset.from_tensor_slices(x)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Generate preprocess and postprocess.
    x = preprocessor.generate_preprocess(x)  # Tokenized numeric inputs.
    x = preprocessor.generate_postprocess(x)  # Detokenized string outputs.
    ```
    """

    def __init__(
        self,
        tokenizer,
        encoder_sequence_length=1024,
        decoder_sequence_length=1024,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        # The packers must exist (as `None`) before the sequence length
        # properties are assigned, because the setters inspect them.
        self.encoder_packer = None
        self.decoder_packer = None
        self.encoder_sequence_length = encoder_sequence_length
        self.decoder_sequence_length = decoder_sequence_length

    def build(self, input_shape):
        """Create the encoder and decoder packers. `input_shape` is unused."""
        # Defer packer creation to `build()` so that we can be sure tokenizer
        # assets have loaded when restoring a saved model.
        self.encoder_packer = StartEndPacker(
            start_value=self.tokenizer.start_token_id,
            end_value=self.tokenizer.end_token_id,
            pad_value=self.tokenizer.pad_token_id,
            sequence_length=self.encoder_sequence_length,
            return_padding_mask=True,
        )
        self.decoder_packer = StartEndPacker(
            start_value=self.tokenizer.start_token_id,
            end_value=self.tokenizer.end_token_id,
            pad_value=self.tokenizer.pad_token_id,
            sequence_length=self.decoder_sequence_length,
            return_padding_mask=True,
        )
        self.built = True

    @preprocessing_function
    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        *,
        encoder_sequence_length=None,
        decoder_sequence_length=None,
        # `sequence_length` is an alias for `decoder_sequence_length`
        sequence_length=None,
    ):
        """Tokenize and pack `x` into training features, labels and weights.

        Args:
            x: dict with `"encoder_text"` and `"decoder_text"` entries.
            y: Ignored; labels are always derived from the decoder tokens.
            sample_weight: Ignored; weights are always derived from the
                decoder padding mask.
            encoder_sequence_length: Optional per-call override of the
                encoder length.
            decoder_sequence_length: Optional per-call override of the
                decoder length.
            sequence_length: Alias for `decoder_sequence_length`, used only
                when `decoder_sequence_length` is not given.

        Returns:
            The result of `keras.utils.pack_x_y_sample_weight` applied to the
            feature dict, the shifted decoder tokens and the shifted decoder
            padding mask.
        """
        if encoder_sequence_length is None:
            encoder_sequence_length = self.encoder_sequence_length
        decoder_sequence_length = decoder_sequence_length or sequence_length
        if decoder_sequence_length is None:
            decoder_sequence_length = self.decoder_sequence_length

        encoder_inputs = self.tokenizer(x["encoder_text"])
        encoder_token_ids, encoder_padding_mask = self.encoder_packer(
            encoder_inputs,
            sequence_length=encoder_sequence_length,
        )
        decoder_inputs = self.tokenizer(x["decoder_text"])
        # Pack one extra position so that, after slicing off one token for
        # the inputs and one for the labels, both have the requested length.
        decoder_token_ids, decoder_padding_mask = self.decoder_packer(
            decoder_inputs,
            sequence_length=decoder_sequence_length + 1,
        )
        x = {
            "encoder_token_ids": encoder_token_ids,
            "encoder_padding_mask": encoder_padding_mask,
            "decoder_token_ids": decoder_token_ids[..., :-1],
            "decoder_padding_mask": decoder_padding_mask[..., :-1],
        }
        # Target `y` will be the decoder input sequence shifted one step to the
        # left (i.e., the next token).
        y = decoder_token_ids[..., 1:]
        sample_weight = decoder_padding_mask[..., 1:]
        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

    @preprocessing_function
    def generate_preprocess(
        self,
        x,
        *,
        encoder_sequence_length=None,
        decoder_sequence_length=None,
        # `sequence_length` is an alias for `decoder_sequence_length`
        sequence_length=None,
    ):
        """Convert encoder and decoder strings to token inputs for generation.

        Similar to calling the layer for training, this method takes in a dict
        containing `"encoder_text"` and `"decoder_text"`, with strings or tensor
        strings for values, tokenizes and packs the input, and computes a
        padding mask masking all inputs not filled in with a padded value.

        If `x` is not a dict it is treated as the encoder text, and an empty
        decoder prompt is used for every entry.

        Unlike calling the layer for training, this method does not compute
        labels and will never append a tokenizer.end_token_id to the end of
        the decoder sequence (as generation is expected to continue at the end
        of the inputted decoder prompt).
        """
        if not self.built:
            self.build(None)

        if isinstance(x, dict):
            encoder_text = x["encoder_text"]
            decoder_text = x["decoder_text"]
        else:
            encoder_text = x
            # Initialize empty prompt for the decoder.
            # NOTE(review): this indexes dimension 0 of the encoder text, so
            # it assumes batched (rank >= 1) input -- confirm that scalar
            # strings never reach this branch.
            decoder_text = tf.fill((tf.shape(encoder_text)[0],), "")

        if encoder_sequence_length is None:
            encoder_sequence_length = self.encoder_sequence_length
        decoder_sequence_length = decoder_sequence_length or sequence_length
        if decoder_sequence_length is None:
            decoder_sequence_length = self.decoder_sequence_length

        # Tokenize and pack the encoder inputs.
        encoder_token_ids = self.tokenizer(encoder_text)
        encoder_token_ids, encoder_padding_mask = self.encoder_packer(
            encoder_token_ids,
            sequence_length=encoder_sequence_length,
        )

        # Tokenize and pack the decoder inputs.
        decoder_token_ids = self.tokenizer(decoder_text)
        decoder_token_ids, decoder_padding_mask = self.decoder_packer(
            decoder_token_ids,
            sequence_length=decoder_sequence_length,
            add_end_value=False,
        )

        return {
            "encoder_token_ids": encoder_token_ids,
            "encoder_padding_mask": encoder_padding_mask,
            "decoder_token_ids": decoder_token_ids,
            "decoder_padding_mask": decoder_padding_mask,
        }

    @preprocessing_function
    def generate_postprocess(
        self,
        x,
    ):
        """Convert integer token output to strings for generation.

        This method reverses `generate_preprocess()`, by first removing all
        padding and start/end tokens, and then converting the integer sequence
        back to a string.
        """
        if not self.built:
            self.build(None)

        token_ids, padding_mask = (
            x["decoder_token_ids"],
            x["decoder_padding_mask"],
        )
        ids_to_strip = self.tokenizer.special_token_ids
        token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
        return self.tokenizer.detokenize(token_ids)

    def get_config(self):
        """Return the layer config, including both sequence lengths.

        The lengths are constructor arguments, so they are stored here to let
        a layer rebuilt from its config keep them instead of the constructor
        defaults.
        """
        config = super().get_config()
        config.update(
            {
                "encoder_sequence_length": self.encoder_sequence_length,
                "decoder_sequence_length": self.decoder_sequence_length,
            }
        )
        return config

    @property
    def encoder_sequence_length(self):
        """The padded length of encoder input sequences."""
        return self._encoder_sequence_length

    @encoder_sequence_length.setter
    def encoder_sequence_length(self, value):
        self._encoder_sequence_length = value
        # Keep an already-built packer in sync with the new length.
        if self.encoder_packer is not None:
            self.encoder_packer.sequence_length = value

    @property
    def decoder_sequence_length(self):
        """The padded length of decoder input sequences."""
        return self._decoder_sequence_length

    @decoder_sequence_length.setter
    def decoder_sequence_length(self, value):
        self._decoder_sequence_length = value
        # Keep an already-built packer in sync with the new length.
        if self.decoder_packer is not None:
            self.decoder_packer.sequence_length = value

    @property
    def sequence_length(self):
        """Alias for `decoder_sequence_length`."""
        return self.decoder_sequence_length

    @sequence_length.setter
    def sequence_length(self, value):
        self.decoder_sequence_length = value
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,103 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from keras import layers
15
+ from keras import ops
16
+
17
+
18
def quick_gelu(x):
    """Return `x` scaled by the sigmoid of `1.702 * x`."""
    gate = ops.sigmoid(1.702 * x)
    return x * gate
20
+
21
+
22
class CLIPEncoderBlock(layers.Layer):
    """Transformer block: causal self-attention followed by a two-layer MLP.

    Layer normalization is applied before each of the two sub-layers, and
    each sub-layer output is added back onto its input.

    Args:
        hidden_dim: int. Number of units of the second dense layer; the
            per-head attention size is `hidden_dim // num_heads`.
        num_heads: int. Number of attention heads. Must divide `hidden_dim`.
        intermediate_dim: int. Number of units of the first dense layer.
        intermediate_activation: Activation applied between the two dense
            layers. The string `"quick_gelu"` selects the module-level
            `quick_gelu` function; any other value is handed to
            `layers.Activation` unchanged.
        **kwargs: Forwarded to `layers.Layer`.

    Raises:
        ValueError: If `hidden_dim` is not divisible by `num_heads`.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        intermediate_dim,
        intermediate_activation="quick_gelu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if hidden_dim % num_heads:
            raise ValueError(
                "`hidden_dim` must be divisible by `num_heads`. "
                f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.intermediate_dim = intermediate_dim
        # Keep the value exactly as passed in so `get_config()` reports it.
        self.intermediate_activation = intermediate_activation

        activation_fn = (
            quick_gelu
            if intermediate_activation == "quick_gelu"
            else intermediate_activation
        )
        policy = self.dtype_policy
        norm_epsilon = 0.00001

        self.layer_norm_1 = layers.LayerNormalization(
            epsilon=norm_epsilon, dtype=policy, name="layer_norm_1"
        )
        self.attention = layers.MultiHeadAttention(
            num_heads,
            hidden_dim // num_heads,
            dtype=policy,
            name="attention",
        )
        self.layer_norm_2 = layers.LayerNormalization(
            epsilon=norm_epsilon, dtype=policy, name="layer_norm_2"
        )
        self.dense_1 = layers.Dense(
            self.intermediate_dim, dtype=policy, name="dense_1"
        )
        self.activation = layers.Activation(
            activation_fn, dtype=policy, name="activation"
        )
        self.dense_2 = layers.Dense(
            self.hidden_dim, dtype=policy, name="dense_2"
        )

    def build(self, input_shape):
        """Build every weighted sub-layer from the block's input shape."""
        self.layer_norm_1.build(input_shape)
        self.attention.build(input_shape, input_shape, input_shape)
        self.layer_norm_2.build(input_shape)
        self.dense_1.build(input_shape)
        # The second dense layer consumes the output of the first one.
        expanded_shape = self.dense_1.compute_output_shape(input_shape)
        self.dense_2.build(expanded_shape)

    def compute_output_shape(self, inputs_shape):
        """Return `inputs_shape` as a list with last entry `hidden_dim`."""
        result = list(inputs_shape)
        result[-1] = self.hidden_dim
        return result

    def call(self, x, training=None):
        """Apply the attention and MLP sub-layers with residual additions."""
        # Self-attention sub-layer.
        normed = self.layer_norm_1(x)
        attended = self.attention(
            normed, normed, normed, training=training, use_causal_mask=True
        )
        x = ops.add(x, attended)

        # Feed-forward sub-layer.
        mlp_out = self.dense_1(self.layer_norm_2(x))
        mlp_out = self.activation(mlp_out)
        mlp_out = self.dense_2(mlp_out)
        return ops.add(x, mlp_out)

    def get_config(self):
        """Return the base layer config plus the constructor arguments."""
        return {
            **super().get_config(),
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "intermediate_dim": self.intermediate_dim,
            "intermediate_activation": self.intermediate_activation,
        }
@@ -0,0 +1,93 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+
16
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
17
+ from keras_hub.src.models.preprocessor import Preprocessor
18
+ from keras_hub.src.models.stable_diffusion_v3.clip_tokenizer import (
19
+ CLIPTokenizer,
20
+ )
21
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
22
+
23
+ try:
24
+ import tensorflow as tf
25
+ except ImportError:
26
+ tf = None
27
+
28
+
29
class CLIPPreprocessor(Preprocessor):
    """Preprocessing layer that tokenizes and packs text for CLIP.

    Args:
        tokenizer: The tokenizer used to turn strings into token ids. It must
            expose `start_token_id`, `end_token_id` and `pad_token_id`.
        sequence_length: int. The length of the packed output.
        add_start_token: bool. Passed to the packer as `add_start_value`.
        add_end_token: bool. Passed to the packer as `add_end_value`.
        to_lower: bool. If true, inputs are lowercased with
            `tf.strings.lower` before tokenization.
        pad_with_end_token: bool. If true, the packer pads with the
            tokenizer's `end_token_id` instead of its `pad_token_id`.
        **kwargs: Forwarded to the `Preprocessor` base class.
    """

    tokenizer_cls = CLIPTokenizer

    def __init__(
        self,
        tokenizer,
        sequence_length=77,
        add_start_token=True,
        add_end_token=False,
        to_lower=True,
        pad_with_end_token=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        self.add_start_token = add_start_token
        self.add_end_token = add_end_token
        self.to_lower = to_lower
        self.pad_with_end_token = pad_with_end_token

    def build(self, input_shape):
        """Create the packer. `input_shape` is unused."""
        # Packer creation waits until `build()` so that tokenizer assets are
        # sure to have loaded when a saved model is being restored.
        tokenizer = self.tokenizer
        pad_value = tokenizer.pad_token_id
        if self.pad_with_end_token:
            pad_value = tokenizer.end_token_id

        self.packer = StartEndPacker(
            start_value=tokenizer.start_token_id,
            end_value=tokenizer.end_token_id,
            pad_value=pad_value,
            sequence_length=self.sequence_length,
            return_padding_mask=True,
        )
        self.built = True

    @preprocessing_function
    def call(self, x, y=None, sample_weight=None, sequence_length=None):
        """Tokenize and pack `x`; pass `y` and `sample_weight` through.

        `sequence_length`, when given and truthy, overrides the configured
        length for this call only.
        """
        if self.to_lower:
            x = tf.strings.lower(x)
        tokens = self.tokenizer(x)
        target_length = sequence_length or self.sequence_length
        token_ids, padding_mask = self.packer(
            tokens,
            sequence_length=target_length,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
        features = {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }
        return keras.utils.pack_x_y_sample_weight(features, y, sample_weight)

    def get_config(self):
        """Return the base config plus this layer's packing options."""
        return {
            **super().get_config(),
            "sequence_length": self.sequence_length,
            "add_start_token": self.add_start_token,
            "add_end_token": self.add_end_token,
            "to_lower": self.to_lower,
            "pad_with_end_token": self.pad_with_end_token,
        }
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+ from keras import layers
16
+ from keras import ops
17
+
18
+ from keras_hub.src.layers.modeling.token_and_position_embedding import (
19
+ TokenAndPositionEmbedding,
20
+ )
21
+ from keras_hub.src.models.stable_diffusion_v3.clip_encoder_block import (
22
+ CLIPEncoderBlock,
23
+ )
24
+
25
+
26
class CLIPTextEncoder(keras.Model):
    """CLIP text encoder assembled as a functional Keras model.

    The model takes a dict with an int32 `"encoder_token_ids"` entry of
    length `sequence_length` and returns a dict with the keys
    `"encoder_sequence_output"`, `"encoder_pooled_output"` and
    `"encoder_projection_output"`. When `intermediate_output_index` is given,
    the dict also contains `"encoder_intermediate_output"`.

    Args:
        embedding_dim: int. Output size of the token and position embedding.
        hidden_dim: int. Hidden size of every `CLIPEncoderBlock` and number
            of units of the text projection.
        num_layers: int. Number of encoder blocks.
        num_heads: int. Number of attention heads per encoder block.
        intermediate_dim: int. MLP width inside every encoder block.
        intermediate_activation: Activation forwarded to every encoder block.
        intermediate_output_index: Optional int. Index of the encoder block
            whose output is additionally returned, after the final layer
            normalization. Negative values count from the last block.
        vocabulary_size: int. Size of the token vocabulary.
        sequence_length: int. Length of the input token sequence.
        dtype: Optional dtype or dtype policy used for all sub-layers.
        **kwargs: Forwarded to `keras.Model`.
    """

    def __init__(
        self,
        embedding_dim,
        hidden_dim,
        num_layers,
        num_heads,
        intermediate_dim,
        intermediate_activation="quick_gelu",
        intermediate_output_index=None,
        vocabulary_size=49408,
        sequence_length=77,
        dtype=None,
        **kwargs,
    ):
        wants_intermediate = intermediate_output_index is not None
        if wants_intermediate and intermediate_output_index < 0:
            # Translate a negative index into its non-negative equivalent.
            intermediate_output_index += num_layers

        # === Layers ===
        self.embedding = TokenAndPositionEmbedding(
            vocabulary_size=vocabulary_size,
            sequence_length=sequence_length,
            embedding_dim=embedding_dim,
            dtype=dtype,
            name="embedding",
        )
        self.encoder_layers = [
            CLIPEncoderBlock(
                hidden_dim,
                num_heads,
                intermediate_dim,
                intermediate_activation,
                dtype=dtype,
            )
            for _ in range(num_layers)
        ]
        self.layer_norm = layers.LayerNormalization(
            epsilon=0.00001, dtype=dtype, name="layer_norm"
        )
        self.text_projection = layers.Dense(
            hidden_dim,
            use_bias=False,
            dtype=dtype,
            name="text_projection",
        )

        # === Functional Model ===
        token_id_input = layers.Input(
            shape=(sequence_length,), dtype="int32", name="encoder_token_ids"
        )
        hidden = self.embedding(token_id_input)
        intermediate = None
        # Encoder.
        for block_index, encoder_block in enumerate(self.encoder_layers):
            hidden = encoder_block(hidden)
            if block_index == intermediate_output_index:
                intermediate = hidden
        sequence_output = self.layer_norm(hidden)
        if intermediate is not None:
            # The tapped block output goes through the same final norm.
            intermediate = self.layer_norm(intermediate)
        # Projection. The pooled vector is taken at the position holding the
        # largest token id of each sequence.
        # NOTE(review): presumably that is the end-of-text token -- confirm
        # against the tokenizer's vocabulary.
        largest_id_position = ops.cast(
            ops.argmax(token_id_input, axis=-1), "int32"
        )
        gather_index = ops.expand_dims(largest_id_position, axis=-1)
        pooled = ops.take_along_axis(
            sequence_output, gather_index[:, :, None], axis=1
        )
        pooled = ops.squeeze(pooled, axis=1)
        projected = self.text_projection(pooled)

        model_outputs = {
            "encoder_sequence_output": sequence_output,
            "encoder_pooled_output": pooled,
            "encoder_projection_output": projected,
        }
        if wants_intermediate:
            model_outputs["encoder_intermediate_output"] = intermediate

        super().__init__(
            inputs={"encoder_token_ids": token_id_input},
            outputs=model_outputs,
            **kwargs,
        )

        # === Config ===
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.intermediate_dim = intermediate_dim
        self.intermediate_activation = intermediate_activation
        self.intermediate_output_index = intermediate_output_index
        self.vocabulary_size = vocabulary_size
        self.sequence_length = sequence_length

        if dtype is not None:
            try:
                self.dtype_policy = keras.dtype_policies.get(dtype)
            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
            except AttributeError:
                if isinstance(dtype, keras.DTypePolicy):
                    dtype = dtype.name
                self.dtype_policy = keras.DTypePolicy(dtype)

    def get_config(self):
        """Return the base model config plus the constructor arguments.

        `intermediate_output_index` is reported after negative values have
        been converted to their non-negative equivalent.
        """
        return {
            **super().get_config(),
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "intermediate_dim": self.intermediate_dim,
            "intermediate_activation": self.intermediate_activation,
            "intermediate_output_index": self.intermediate_output_index,
            "vocabulary_size": self.vocabulary_size,
            "sequence_length": self.sequence_length,
        }