keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.15.0.dev20240911134614__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (188)
  1. keras_hub/api/__init__.py +1 -0
  2. keras_hub/api/bounding_box/__init__.py +36 -0
  3. keras_hub/api/layers/__init__.py +14 -0
  4. keras_hub/api/models/__init__.py +75 -31
  5. keras_hub/api/tokenizers/__init__.py +30 -0
  6. keras_hub/src/bounding_box/__init__.py +13 -0
  7. keras_hub/src/bounding_box/converters.py +529 -0
  8. keras_hub/src/bounding_box/formats.py +162 -0
  9. keras_hub/src/bounding_box/iou.py +263 -0
  10. keras_hub/src/bounding_box/to_dense.py +95 -0
  11. keras_hub/src/bounding_box/to_ragged.py +99 -0
  12. keras_hub/src/bounding_box/utils.py +194 -0
  13. keras_hub/src/bounding_box/validate_format.py +99 -0
  14. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  15. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  16. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  17. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  18. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  19. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  20. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  21. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  22. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  23. keras_hub/src/models/albert/__init__.py +1 -2
  24. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  25. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +29 -10
  26. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  27. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  28. keras_hub/src/models/backbone.py +12 -34
  29. keras_hub/src/models/bart/__init__.py +1 -2
  30. keras_hub/src/models/bart/bart_preprocessor.py +6 -18
  31. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  32. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  33. keras_hub/src/models/bert/__init__.py +1 -5
  34. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  35. keras_hub/src/models/bert/bert_presets.py +1 -4
  36. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +12 -10
  37. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  38. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  39. keras_hub/src/models/bloom/__init__.py +1 -2
  40. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  41. keras_hub/src/models/bloom/bloom_preprocessor.py +5 -12
  42. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  43. keras_hub/src/models/causal_lm.py +10 -29
  44. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  45. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  46. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  47. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  48. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +11 -11
  49. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  50. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  51. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  52. keras_hub/src/models/distil_bert/__init__.py +1 -4
  53. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  54. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +12 -12
  55. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  56. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  57. keras_hub/src/models/efficientnet/__init__.py +13 -0
  58. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  59. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  60. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  61. keras_hub/src/models/electra/__init__.py +1 -2
  62. keras_hub/src/models/electra/electra_preprocessor.py +6 -5
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +10 -8
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_preprocessor.py +5 -12
  72. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  73. keras_hub/src/models/gemma/__init__.py +1 -2
  74. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  75. keras_hub/src/models/gemma/gemma_preprocessor.py +5 -12
  76. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  77. keras_hub/src/models/gpt2/__init__.py +1 -2
  78. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  79. keras_hub/src/models/gpt2/gpt2_preprocessor.py +5 -12
  80. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  82. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +5 -12
  83. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  84. keras_hub/src/models/image_classifier.py +0 -5
  85. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  86. keras_hub/src/models/llama/__init__.py +1 -2
  87. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  88. keras_hub/src/models/llama/llama_preprocessor.py +5 -12
  89. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  90. keras_hub/src/models/llama3/__init__.py +1 -2
  91. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  92. keras_hub/src/models/llama3/llama3_preprocessor.py +2 -0
  93. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  94. keras_hub/src/models/masked_lm.py +0 -2
  95. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  96. keras_hub/src/models/mistral/__init__.py +1 -2
  97. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  98. keras_hub/src/models/mistral/mistral_preprocessor.py +5 -12
  99. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  100. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  101. keras_hub/src/models/mobilenet/__init__.py +13 -0
  102. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  103. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  104. keras_hub/src/models/opt/__init__.py +1 -2
  105. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  106. keras_hub/src/models/opt/opt_preprocessor.py +5 -12
  107. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  108. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  109. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  110. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  111. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  112. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +10 -2
  113. keras_hub/src/models/phi3/__init__.py +1 -2
  114. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  115. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  116. keras_hub/src/models/phi3/phi3_preprocessor.py +5 -12
  117. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  118. keras_hub/src/models/preprocessor.py +76 -83
  119. keras_hub/src/models/resnet/__init__.py +6 -0
  120. keras_hub/src/models/resnet/resnet_backbone.py +387 -26
  121. keras_hub/src/models/resnet/resnet_image_classifier.py +7 -3
  122. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  123. keras_hub/src/models/resnet/resnet_image_converter.py +23 -0
  124. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  125. keras_hub/src/models/roberta/__init__.py +1 -2
  126. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  127. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +11 -11
  128. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  129. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  130. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  131. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  133. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  134. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  135. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  136. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  137. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  138. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  139. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  140. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  141. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  142. keras_hub/src/models/t5/__init__.py +1 -2
  143. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  144. keras_hub/src/models/task.py +71 -116
  145. keras_hub/src/models/{classifier.py → text_classifier.py} +8 -13
  146. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  147. keras_hub/src/models/whisper/__init__.py +1 -2
  148. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  149. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  150. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  151. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  152. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  154. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +11 -11
  155. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  156. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  157. keras_hub/src/tests/test_case.py +25 -0
  158. keras_hub/src/tokenizers/byte_pair_tokenizer.py +29 -17
  159. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +19 -7
  161. keras_hub/src/tokenizers/tokenizer.py +67 -32
  162. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  163. keras_hub/src/tokenizers/word_piece_tokenizer.py +33 -47
  164. keras_hub/src/utils/keras_utils.py +0 -50
  165. keras_hub/src/utils/preset_utils.py +238 -67
  166. keras_hub/src/utils/tensor_utils.py +187 -69
  167. keras_hub/src/utils/timm/convert_resnet.py +20 -16
  168. keras_hub/src/utils/timm/preset_loader.py +67 -0
  169. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  170. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  171. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  172. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  173. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  174. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  175. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  176. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  177. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  178. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  179. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  180. keras_hub/src/version_utils.py +1 -1
  181. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.15.0.dev20240911134614.dist-info}/METADATA +1 -2
  182. keras_hub_nightly-0.15.0.dev20240911134614.dist-info/RECORD +338 -0
  183. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.15.0.dev20240911134614.dist-info}/WHEEL +1 -1
  184. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  185. keras_hub/src/utils/timm/convert.py +0 -37
  186. keras_hub/src/utils/transformers/convert.py +0 -101
  187. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  188. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.15.0.dev20240911134614.dist-info}/top_level.txt +0 -0
@@ -20,18 +20,11 @@ from keras_hub.src.api_export import keras_hub_export
20
20
  from keras_hub.src.utils.keras_utils import assert_quantization_support
21
21
  from keras_hub.src.utils.preset_utils import CONFIG_FILE
22
22
  from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
23
- from keras_hub.src.utils.preset_utils import check_config_class
24
- from keras_hub.src.utils.preset_utils import check_format
25
- from keras_hub.src.utils.preset_utils import get_file
26
- from keras_hub.src.utils.preset_utils import jax_memory_cleanup
27
- from keras_hub.src.utils.preset_utils import list_presets
28
- from keras_hub.src.utils.preset_utils import list_subclasses
29
- from keras_hub.src.utils.preset_utils import load_serialized_object
23
+ from keras_hub.src.utils.preset_utils import builtin_presets
24
+ from keras_hub.src.utils.preset_utils import get_preset_loader
30
25
  from keras_hub.src.utils.preset_utils import save_metadata
31
26
  from keras_hub.src.utils.preset_utils import save_serialized_object
32
27
  from keras_hub.src.utils.python_utils import classproperty
33
- from keras_hub.src.utils.timm.convert import load_timm_backbone
34
- from keras_hub.src.utils.transformers.convert import load_transformers_backbone
35
28
 
36
29
 
37
30
  @keras_hub_export("keras_hub.models.Backbone")
@@ -147,11 +140,8 @@ class Backbone(keras.Model):
147
140
 
148
141
  @classproperty
149
142
  def presets(cls):
150
- """List built-in presets for a `Task` subclass."""
151
- presets = list_presets(cls)
152
- for subclass in list_subclasses(cls):
153
- presets.update(subclass.presets)
154
- return presets
143
+ """List built-in presets for a `Backbone` subclass."""
144
+ return builtin_presets(cls)
155
145
 
156
146
  @classmethod
157
147
  def from_preset(
@@ -166,7 +156,7 @@ class Backbone(keras.Model):
166
156
  to save and load a pre-trained model. The `preset` can be passed as a
167
157
  one of:
168
158
 
169
- 1. a built in preset identifier like `'bert_base_en'`
159
+ 1. a built-in preset identifier like `'bert_base_en'`
170
160
  2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
171
161
  3. a Hugging Face handle like `'hf://user/bert_base_en'`
172
162
  4. a path to a local preset directory like `'./bert_base_en'`
@@ -181,7 +171,7 @@ class Backbone(keras.Model):
181
171
  all built-in presets available on the class.
182
172
 
183
173
  Args:
184
- preset: string. A built in preset identifier, a Kaggle Models
174
+ preset: string. A built-in preset identifier, a Kaggle Models
185
175
  handle, a Hugging Face handle, or a path to a local directory.
186
176
  load_weights: bool. If `True`, the weights will be loaded into the
187
177
  model architecture. If `False`, the weights will be randomly
@@ -201,27 +191,15 @@ class Backbone(keras.Model):
201
191
  )
202
192
  ```
203
193
  """
204
- format = check_format(preset)
205
-
206
- if format == "transformers":
207
- return load_transformers_backbone(cls, preset, load_weights)
208
- elif format == "timm":
209
- return load_timm_backbone(cls, preset, load_weights, **kwargs)
210
-
211
- preset_cls = check_config_class(preset)
212
- if not issubclass(preset_cls, cls):
194
+ loader = get_preset_loader(preset)
195
+ backbone_cls = loader.check_backbone_class()
196
+ if not issubclass(backbone_cls, cls):
213
197
  raise ValueError(
214
- f"Preset has type `{preset_cls.__name__}` which is not a "
198
+ f"Saved preset has type `{backbone_cls.__name__}` which is not "
215
199
  f"a subclass of calling class `{cls.__name__}`. Call "
216
- f"`from_preset` directly on `{preset_cls.__name__}` instead."
200
+ f"`from_preset` directly on `{backbone_cls.__name__}` instead."
217
201
  )
218
-
219
- backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
220
- if load_weights:
221
- jax_memory_cleanup(backbone)
222
- backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
223
-
224
- return backbone
202
+ return loader.load_backbone(backbone_cls, load_weights, **kwargs)
225
203
 
226
204
  def save_to_preset(self, preset_dir):
227
205
  """Save backbone to a preset directory.
@@ -14,7 +14,6 @@
14
14
 
15
15
  from keras_hub.src.models.bart.bart_backbone import BartBackbone
16
16
  from keras_hub.src.models.bart.bart_presets import backbone_presets
17
- from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
18
17
  from keras_hub.src.utils.preset_utils import register_presets
19
18
 
20
- register_presets(backbone_presets, (BartBackbone, BartTokenizer))
19
+ register_presets(backbone_presets, BartBackbone)
@@ -17,11 +17,10 @@ import keras
17
17
 
18
18
  from keras_hub.src.api_export import keras_hub_export
19
19
  from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
20
+ from keras_hub.src.models.bart.bart_backbone import BartBackbone
20
21
  from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
21
22
  from keras_hub.src.models.preprocessor import Preprocessor
22
- from keras_hub.src.utils.keras_utils import (
23
- convert_inputs_to_list_of_tensor_segments,
24
- )
23
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
25
24
 
26
25
 
27
26
  @keras_hub_export("keras_hub.models.BartPreprocessor")
@@ -129,6 +128,7 @@ class BartPreprocessor(Preprocessor):
129
128
  ```
130
129
  """
131
130
 
131
+ backbone_cls = BartBackbone
132
132
  tokenizer_cls = BartTokenizer
133
133
 
134
134
  def __init__(
@@ -174,6 +174,7 @@ class BartPreprocessor(Preprocessor):
174
174
  )
175
175
  self.built = True
176
176
 
177
+ @preprocessing_function
177
178
  def call(
178
179
  self,
179
180
  x,
@@ -200,26 +201,13 @@ class BartPreprocessor(Preprocessor):
200
201
  if decoder_sequence_length is None:
201
202
  decoder_sequence_length = self.decoder_sequence_length
202
203
 
203
- encoder_text = x["encoder_text"]
204
- decoder_text = x["decoder_text"]
205
-
206
- encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)
207
- decoder_text = convert_inputs_to_list_of_tensor_segments(decoder_text)
208
-
209
- if len(encoder_text) > 1 or len(decoder_text) > 1:
210
- raise ValueError(
211
- '`BARTPreprocessor` requires both `"encoder_text"` and '
212
- f'`"decoder_text"` to contain only one segment, but received '
213
- f"{len(encoder_text)} and {len(decoder_text)}, respectively."
214
- )
215
-
216
- encoder_inputs = self.tokenizer(encoder_text[0])
204
+ encoder_inputs = self.tokenizer(x["encoder_text"])
217
205
  encoder_token_ids, encoder_padding_mask = self.encoder_packer(
218
206
  encoder_inputs,
219
207
  sequence_length=encoder_sequence_length,
220
208
  )
221
209
 
222
- decoder_inputs = self.tokenizer(decoder_text[0])
210
+ decoder_inputs = self.tokenizer(x["decoder_text"])
223
211
  decoder_token_ids, decoder_padding_mask = self.decoder_packer(
224
212
  decoder_inputs,
225
213
  sequence_length=decoder_sequence_length,
@@ -13,24 +13,15 @@
13
13
  # limitations under the License.
14
14
 
15
15
 
16
- import keras
17
- from absl import logging
18
-
19
16
  from keras_hub.src.api_export import keras_hub_export
20
- from keras_hub.src.models.bart.bart_preprocessor import BartPreprocessor
21
- from keras_hub.src.utils.keras_utils import (
22
- convert_inputs_to_list_of_tensor_segments,
23
- )
24
- from keras_hub.src.utils.tensor_utils import strip_to_ragged
25
-
26
- try:
27
- import tensorflow as tf
28
- except ImportError:
29
- tf = None
17
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
18
+ from keras_hub.src.models.bart.bart_backbone import BartBackbone
19
+ from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
20
+ from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
30
21
 
31
22
 
32
23
  @keras_hub_export("keras_hub.models.BartSeq2SeqLMPreprocessor")
33
- class BartSeq2SeqLMPreprocessor(BartPreprocessor):
24
+ class BartSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor):
34
25
  """BART Seq2Seq LM preprocessor.
35
26
 
36
27
  This layer is used as preprocessor for seq2seq tasks using the BART model.
@@ -125,138 +116,20 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor):
125
116
  ```
126
117
  """
127
118
 
128
- def call(
129
- self,
130
- x,
131
- y=None,
132
- sample_weight=None,
133
- *,
134
- encoder_sequence_length=None,
135
- decoder_sequence_length=None,
136
- # `sequence_length` is an alias for `decoder_sequence_length`
137
- sequence_length=None,
138
- ):
139
- if y is not None or sample_weight is not None:
140
- logging.warning(
141
- "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` "
142
- "from the provided input data, i.e., `x`. However, non-`None`"
143
- "values have been passed for `y` or `sample_weight` or both. "
144
- "These values will be ignored."
145
- )
146
-
147
- if encoder_sequence_length is None:
148
- encoder_sequence_length = self.encoder_sequence_length
149
- decoder_sequence_length = decoder_sequence_length or sequence_length
150
- if decoder_sequence_length is None:
151
- decoder_sequence_length = self.decoder_sequence_length
152
-
153
- x = super().call(
154
- x,
155
- encoder_sequence_length=encoder_sequence_length,
156
- decoder_sequence_length=decoder_sequence_length + 1,
157
- )
158
- decoder_token_ids = x.pop("decoder_token_ids")
159
- decoder_padding_mask = x.pop("decoder_padding_mask")
160
-
161
- # The last token does not have a next token. Hence, we truncate it.
162
- x = {
163
- **x,
164
- "decoder_token_ids": decoder_token_ids[..., :-1],
165
- "decoder_padding_mask": decoder_padding_mask[..., :-1],
166
- }
167
- # Target `y` will be the decoder input sequence shifted one step to the
168
- # left (i.e., the next token).
169
- y = decoder_token_ids[..., 1:]
170
- sample_weight = decoder_padding_mask[..., 1:]
171
- return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
172
-
173
- def generate_preprocess(
174
- self,
175
- x,
176
- *,
177
- encoder_sequence_length=None,
178
- # `sequence_length` is an alias for `decoder_sequence_length`
179
- decoder_sequence_length=None,
180
- sequence_length=None,
181
- ):
182
- """Convert encoder and decoder input strings to integer token inputs for generation.
183
-
184
- Similar to calling the layer for training, this method takes in a dict
185
- containing `"encoder_text"` and `"decoder_text"`, with strings or tensor
186
- strings for values, tokenizes and packs the input, and computes a
187
- padding mask masking all inputs not filled in with a padded value.
188
-
189
- Unlike calling the layer for training, this method does not compute
190
- labels and will never append a tokenizer.end_token_id to the end of
191
- the decoder sequence (as generation is expected to continue at the end
192
- of the inputted decoder prompt).
193
- """
194
- if not self.built:
195
- self.build(None)
196
-
197
- if isinstance(x, dict):
198
- encoder_text = x["encoder_text"]
199
- decoder_text = x["decoder_text"]
200
- else:
201
- encoder_text = x
202
- # Initialize empty prompt for the decoder.
203
- decoder_text = tf.fill((tf.shape(encoder_text)[0],), "")
204
-
205
- if encoder_sequence_length is None:
206
- encoder_sequence_length = self.encoder_sequence_length
207
- decoder_sequence_length = decoder_sequence_length or sequence_length
208
- if decoder_sequence_length is None:
209
- decoder_sequence_length = self.decoder_sequence_length
210
-
211
- # Tokenize and pack the encoder inputs.
212
- # TODO: Remove `[0]` once we have shifted to `MultiSegmentPacker`.
213
- encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)[
214
- 0
215
- ]
216
- encoder_token_ids = self.tokenizer(encoder_text)
217
- encoder_token_ids, encoder_padding_mask = self.encoder_packer(
218
- encoder_token_ids,
219
- sequence_length=encoder_sequence_length,
220
- )
221
-
222
- # Tokenize and pack the decoder inputs.
223
- decoder_text = convert_inputs_to_list_of_tensor_segments(decoder_text)[
224
- 0
225
- ]
226
- decoder_token_ids = self.tokenizer(decoder_text)
227
- decoder_token_ids, decoder_padding_mask = self.decoder_packer(
228
- decoder_token_ids,
229
- sequence_length=decoder_sequence_length,
230
- add_end_value=False,
231
- )
232
-
233
- return {
234
- "encoder_token_ids": encoder_token_ids,
235
- "encoder_padding_mask": encoder_padding_mask,
236
- "decoder_token_ids": decoder_token_ids,
237
- "decoder_padding_mask": decoder_padding_mask,
238
- }
239
-
240
- def generate_postprocess(
241
- self,
242
- x,
243
- ):
244
- """Convert integer token output to strings for generation.
245
-
246
- This method reverses `generate_preprocess()`, by first removing all
247
- padding and start/end tokens, and then converting the integer sequence
248
- back to a string.
249
- """
250
- if not self.built:
251
- self.build(None)
252
-
253
- token_ids, padding_mask = (
254
- x["decoder_token_ids"],
255
- x["decoder_padding_mask"],
256
- )
257
- ids_to_strip = (
258
- self.tokenizer.start_token_id,
259
- self.tokenizer.end_token_id,
119
+ backbone_cls = BartBackbone
120
+ tokenizer_cls = BartTokenizer
121
+
122
+ def build(self, input_shape):
123
+ super().build(input_shape)
124
+ # The decoder is packed a bit differently; the format is as follows:
125
+ # `[end_token_id, start_token_id, tokens..., end_token_id, padding...]`.
126
+ self.decoder_packer = StartEndPacker(
127
+ start_value=[
128
+ self.tokenizer.end_token_id,
129
+ self.tokenizer.start_token_id,
130
+ ],
131
+ end_value=self.tokenizer.end_token_id,
132
+ pad_value=self.tokenizer.pad_token_id,
133
+ sequence_length=self.decoder_sequence_length,
134
+ return_padding_mask=True,
260
135
  )
261
- token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
262
- return self.tokenizer.detokenize(token_ids)
@@ -14,10 +14,16 @@
14
14
 
15
15
 
16
16
  from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.models.bart.bart_backbone import BartBackbone
17
18
  from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
18
19
 
19
20
 
20
- @keras_hub_export("keras_hub.models.BartTokenizer")
21
+ @keras_hub_export(
22
+ [
23
+ "keras_hub.tokenizers.BartTokenizer",
24
+ "keras_hub.models.BartTokenizer",
25
+ ]
26
+ )
21
27
  class BartTokenizer(BytePairTokenizer):
22
28
  """A BART tokenizer using Byte-Pair Encoding subword segmentation.
23
29
 
@@ -73,52 +79,19 @@ class BartTokenizer(BytePairTokenizer):
73
79
  ```
74
80
  """
75
81
 
82
+ backbone_cls = BartBackbone
83
+
76
84
  def __init__(
77
85
  self,
78
86
  vocabulary=None,
79
87
  merges=None,
80
88
  **kwargs,
81
89
  ):
82
- self.start_token = "<s>"
83
- self.pad_token = "<pad>"
84
- self.end_token = "</s>"
85
-
90
+ self._add_special_token("<s>", "start_token")
91
+ self._add_special_token("</s>", "end_token")
92
+ self._add_special_token("<pad>", "pad_token")
86
93
  super().__init__(
87
94
  vocabulary=vocabulary,
88
95
  merges=merges,
89
- unsplittable_tokens=[
90
- self.start_token,
91
- self.pad_token,
92
- self.end_token,
93
- ],
94
96
  **kwargs,
95
97
  )
96
-
97
- def set_vocabulary_and_merges(self, vocabulary, merges):
98
- super().set_vocabulary_and_merges(vocabulary, merges)
99
-
100
- if vocabulary is not None:
101
- # Check for necessary special tokens.
102
- for token in [self.start_token, self.pad_token, self.end_token]:
103
- if token not in self.vocabulary:
104
- raise ValueError(
105
- f"Cannot find token `'{token}'` in the provided "
106
- f"`vocabulary`. Please provide `'{token}'` in your "
107
- "`vocabulary` or use a pretrained `vocabulary` name."
108
- )
109
-
110
- self.start_token_id = self.token_to_id(self.start_token)
111
- self.pad_token_id = self.token_to_id(self.pad_token)
112
- self.end_token_id = self.token_to_id(self.end_token)
113
- else:
114
- self.start_token_id = None
115
- self.pad_token_id = None
116
- self.end_token_id = None
117
-
118
- def get_config(self):
119
- config = super().get_config()
120
- # In the constructor, we pass the list of special tokens to the
121
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
122
- # delete it from the config here.
123
- del config["unsplittable_tokens"]
124
- return config
@@ -13,11 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from keras_hub.src.models.bert.bert_backbone import BertBackbone
16
- from keras_hub.src.models.bert.bert_classifier import BertClassifier
17
16
  from keras_hub.src.models.bert.bert_presets import backbone_presets
18
- from keras_hub.src.models.bert.bert_presets import classifier_presets
19
- from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
20
17
  from keras_hub.src.utils.preset_utils import register_presets
21
18
 
22
- register_presets(backbone_presets, (BertBackbone, BertTokenizer))
23
- register_presets(classifier_presets, (BertClassifier, BertTokenizer))
19
+ register_presets(backbone_presets, BertBackbone)
@@ -12,18 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import keras
16
- from absl import logging
17
-
18
15
  from keras_hub.src.api_export import keras_hub_export
19
- from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import (
20
- MaskedLMMaskGenerator,
21
- )
22
- from keras_hub.src.models.bert.bert_preprocessor import BertPreprocessor
16
+ from keras_hub.src.models.bert.bert_backbone import BertBackbone
17
+ from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
18
+ from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor
23
19
 
24
20
 
25
21
  @keras_hub_export("keras_hub.models.BertMaskedLMPreprocessor")
26
- class BertMaskedLMPreprocessor(BertPreprocessor):
22
+ class BertMaskedLMPreprocessor(MaskedLMPreprocessor):
27
23
  """BERT preprocessing for the masked language modeling task.
28
24
 
29
25
  This preprocessing layer will prepare inputs for a masked language modeling
@@ -117,82 +113,5 @@ class BertMaskedLMPreprocessor(BertPreprocessor):
117
113
  ```
118
114
  """
119
115
 
120
- def __init__(
121
- self,
122
- tokenizer,
123
- sequence_length=512,
124
- truncate="round_robin",
125
- mask_selection_rate=0.15,
126
- mask_selection_length=96,
127
- mask_token_rate=0.8,
128
- random_token_rate=0.1,
129
- **kwargs,
130
- ):
131
- super().__init__(
132
- tokenizer,
133
- sequence_length=sequence_length,
134
- truncate=truncate,
135
- **kwargs,
136
- )
137
- self.mask_selection_rate = mask_selection_rate
138
- self.mask_selection_length = mask_selection_length
139
- self.mask_token_rate = mask_token_rate
140
- self.random_token_rate = random_token_rate
141
- self.masker = None
142
-
143
- def build(self, input_shape):
144
- super().build(input_shape)
145
- # Defer masker creation to `build()` so that we can be sure tokenizer
146
- # assets have loaded when restoring a saved model.
147
- self.masker = MaskedLMMaskGenerator(
148
- mask_selection_rate=self.mask_selection_rate,
149
- mask_selection_length=self.mask_selection_length,
150
- mask_token_rate=self.mask_token_rate,
151
- random_token_rate=self.random_token_rate,
152
- vocabulary_size=self.tokenizer.vocabulary_size(),
153
- mask_token_id=self.tokenizer.mask_token_id,
154
- unselectable_token_ids=[
155
- self.tokenizer.cls_token_id,
156
- self.tokenizer.sep_token_id,
157
- self.tokenizer.pad_token_id,
158
- ],
159
- )
160
-
161
- def call(self, x, y=None, sample_weight=None):
162
- if y is not None or sample_weight is not None:
163
- logging.warning(
164
- f"{self.__class__.__name__} generates `y` and `sample_weight` "
165
- "based on your input data, but your data already contains `y` "
166
- "or `sample_weight`. Your `y` and `sample_weight` will be "
167
- "ignored."
168
- )
169
-
170
- x = super().call(x)
171
-
172
- token_ids, padding_mask, segment_ids = (
173
- x["token_ids"],
174
- x["padding_mask"],
175
- x["segment_ids"],
176
- )
177
- masker_outputs = self.masker(token_ids)
178
- x = {
179
- "token_ids": masker_outputs["token_ids"],
180
- "padding_mask": padding_mask,
181
- "segment_ids": segment_ids,
182
- "mask_positions": masker_outputs["mask_positions"],
183
- }
184
- y = masker_outputs["mask_ids"]
185
- sample_weight = masker_outputs["mask_weights"]
186
- return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
187
-
188
- def get_config(self):
189
- config = super().get_config()
190
- config.update(
191
- {
192
- "mask_selection_rate": self.mask_selection_rate,
193
- "mask_selection_length": self.mask_selection_length,
194
- "mask_token_rate": self.mask_token_rate,
195
- "random_token_rate": self.random_token_rate,
196
- }
197
- )
198
- return config
116
+ backbone_cls = BertBackbone
117
+ tokenizer_cls = BertTokenizer
@@ -129,9 +129,6 @@ backbone_presets = {
129
129
  },
130
130
  "kaggle_handle": "kaggle://keras/bert/keras/bert_large_en/2",
131
131
  },
132
- }
133
-
134
- classifier_presets = {
135
132
  "bert_tiny_en_uncased_sst2": {
136
133
  "metadata": {
137
134
  "description": (
@@ -143,5 +140,5 @@ classifier_presets = {
143
140
  "model_card": "https://github.com/google-research/bert/blob/master/README.md",
144
141
  },
145
142
  "kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/4",
146
- }
143
+ },
147
144
  }
@@ -17,12 +17,14 @@ import keras
17
17
  from keras_hub.src.api_export import keras_hub_export
18
18
  from keras_hub.src.models.bert.bert_backbone import BertBackbone
19
19
  from keras_hub.src.models.bert.bert_backbone import bert_kernel_initializer
20
- from keras_hub.src.models.bert.bert_preprocessor import BertPreprocessor
21
- from keras_hub.src.models.classifier import Classifier
20
+ from keras_hub.src.models.bert.bert_text_classifier_preprocessor import (
21
+ BertTextClassifierPreprocessor,
22
+ )
23
+ from keras_hub.src.models.text_classifier import TextClassifier
22
24
 
23
25
 
24
- @keras_hub_export("keras_hub.models.BertClassifier")
25
- class BertClassifier(Classifier):
26
+ @keras_hub_export("keras_hub.models.BertTextClassifier")
27
+ class BertTextClassifier(TextClassifier):
26
28
  """An end-to-end BERT model for classification tasks.
27
29
 
28
30
  This model attaches a classification head to a
@@ -41,7 +43,7 @@ class BertClassifier(Classifier):
41
43
  Args:
42
44
  backbone: A `keras_hub.models.BertBackbone` instance.
43
45
  num_classes: int. Number of classes to predict.
44
- preprocessor: A `keras_hub.models.BertPreprocessor` or `None`. If
46
+ preprocessor: A `keras_hub.models.BertTextClassifierPreprocessor` or `None`. If
45
47
  `None`, this model will not apply preprocessing, and inputs should
46
48
  be preprocessed before calling the model.
47
49
  activation: Optional `str` or callable. The
@@ -59,7 +61,7 @@ class BertClassifier(Classifier):
59
61
  labels = [0, 3]
60
62
 
61
63
  # Pretrained classifier.
62
- classifier = keras_hub.models.BertClassifier.from_preset(
64
+ classifier = keras_hub.models.BertTextClassifier.from_preset(
63
65
  "bert_base_en_uncased",
64
66
  num_classes=4,
65
67
  )
@@ -88,7 +90,7 @@ class BertClassifier(Classifier):
88
90
  labels = [0, 3]
89
91
 
90
92
  # Pretrained classifier without preprocessing.
91
- classifier = keras_hub.models.BertClassifier.from_preset(
93
+ classifier = keras_hub.models.BertTextClassifier.from_preset(
92
94
  "bert_base_en_uncased",
93
95
  num_classes=4,
94
96
  preprocessor=None,
@@ -106,7 +108,7 @@ class BertClassifier(Classifier):
106
108
  tokenizer = keras_hub.models.BertTokenizer(
107
109
  vocabulary=vocab,
108
110
  )
109
- preprocessor = keras_hub.models.BertPreprocessor(
111
+ preprocessor = keras_hub.models.BertTextClassifierPreprocessor(
110
112
  tokenizer=tokenizer,
111
113
  sequence_length=128,
112
114
  )
@@ -118,7 +120,7 @@ class BertClassifier(Classifier):
118
120
  intermediate_dim=512,
119
121
  max_sequence_length=128,
120
122
  )
121
- classifier = keras_hub.models.BertClassifier(
123
+ classifier = keras_hub.models.BertTextClassifier(
122
124
  backbone=backbone,
123
125
  preprocessor=preprocessor,
124
126
  num_classes=4,
@@ -128,7 +130,7 @@ class BertClassifier(Classifier):
128
130
  """
129
131
 
130
132
  backbone_cls = BertBackbone
131
- preprocessor_cls = BertPreprocessor
133
+ preprocessor_cls = BertTextClassifierPreprocessor
132
134
 
133
135
  def __init__(
134
136
  self,