keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -20,18 +20,11 @@ from keras_hub.src.api_export import keras_hub_export
20
20
  from keras_hub.src.utils.keras_utils import assert_quantization_support
21
21
  from keras_hub.src.utils.preset_utils import CONFIG_FILE
22
22
  from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
23
- from keras_hub.src.utils.preset_utils import check_config_class
24
- from keras_hub.src.utils.preset_utils import check_format
25
- from keras_hub.src.utils.preset_utils import get_file
26
- from keras_hub.src.utils.preset_utils import jax_memory_cleanup
27
- from keras_hub.src.utils.preset_utils import list_presets
28
- from keras_hub.src.utils.preset_utils import list_subclasses
29
- from keras_hub.src.utils.preset_utils import load_serialized_object
23
+ from keras_hub.src.utils.preset_utils import builtin_presets
24
+ from keras_hub.src.utils.preset_utils import get_preset_loader
30
25
  from keras_hub.src.utils.preset_utils import save_metadata
31
26
  from keras_hub.src.utils.preset_utils import save_serialized_object
32
27
  from keras_hub.src.utils.python_utils import classproperty
33
- from keras_hub.src.utils.timm.convert import load_timm_backbone
34
- from keras_hub.src.utils.transformers.convert import load_transformers_backbone
35
28
 
36
29
 
37
30
  @keras_hub_export("keras_hub.models.Backbone")
@@ -147,11 +140,8 @@ class Backbone(keras.Model):
147
140
 
148
141
  @classproperty
149
142
  def presets(cls):
150
- """List built-in presets for a `Task` subclass."""
151
- presets = list_presets(cls)
152
- for subclass in list_subclasses(cls):
153
- presets.update(subclass.presets)
154
- return presets
143
+ """List built-in presets for a `Backbone` subclass."""
144
+ return builtin_presets(cls)
155
145
 
156
146
  @classmethod
157
147
  def from_preset(
@@ -166,7 +156,7 @@ class Backbone(keras.Model):
166
156
  to save and load a pre-trained model. The `preset` can be passed as a
167
157
  one of:
168
158
 
169
- 1. a built in preset identifier like `'bert_base_en'`
159
+ 1. a built-in preset identifier like `'bert_base_en'`
170
160
  2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
171
161
  3. a Hugging Face handle like `'hf://user/bert_base_en'`
172
162
  4. a path to a local preset directory like `'./bert_base_en'`
@@ -181,7 +171,7 @@ class Backbone(keras.Model):
181
171
  all built-in presets available on the class.
182
172
 
183
173
  Args:
184
- preset: string. A built in preset identifier, a Kaggle Models
174
+ preset: string. A built-in preset identifier, a Kaggle Models
185
175
  handle, a Hugging Face handle, or a path to a local directory.
186
176
  load_weights: bool. If `True`, the weights will be loaded into the
187
177
  model architecture. If `False`, the weights will be randomly
@@ -201,27 +191,15 @@ class Backbone(keras.Model):
201
191
  )
202
192
  ```
203
193
  """
204
- format = check_format(preset)
205
-
206
- if format == "transformers":
207
- return load_transformers_backbone(cls, preset, load_weights)
208
- elif format == "timm":
209
- return load_timm_backbone(cls, preset, load_weights, **kwargs)
210
-
211
- preset_cls = check_config_class(preset)
212
- if not issubclass(preset_cls, cls):
194
+ loader = get_preset_loader(preset)
195
+ backbone_cls = loader.check_backbone_class()
196
+ if not issubclass(backbone_cls, cls):
213
197
  raise ValueError(
214
- f"Preset has type `{preset_cls.__name__}` which is not a "
198
+ f"Saved preset has type `{backbone_cls.__name__}` which is not "
215
199
  f"a subclass of calling class `{cls.__name__}`. Call "
216
- f"`from_preset` directly on `{preset_cls.__name__}` instead."
200
+ f"`from_preset` directly on `{backbone_cls.__name__}` instead."
217
201
  )
218
-
219
- backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
220
- if load_weights:
221
- jax_memory_cleanup(backbone)
222
- backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
223
-
224
- return backbone
202
+ return loader.load_backbone(backbone_cls, load_weights, **kwargs)
225
203
 
226
204
  def save_to_preset(self, preset_dir):
227
205
  """Save backbone to a preset directory.
@@ -14,7 +14,6 @@
14
14
 
15
15
  from keras_hub.src.models.bart.bart_backbone import BartBackbone
16
16
  from keras_hub.src.models.bart.bart_presets import backbone_presets
17
- from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
18
17
  from keras_hub.src.utils.preset_utils import register_presets
19
18
 
20
- register_presets(backbone_presets, (BartBackbone, BartTokenizer))
19
+ register_presets(backbone_presets, BartBackbone)
@@ -13,24 +13,15 @@
13
13
  # limitations under the License.
14
14
 
15
15
 
16
- import keras
17
- from absl import logging
18
-
19
16
  from keras_hub.src.api_export import keras_hub_export
20
- from keras_hub.src.models.bart.bart_preprocessor import BartPreprocessor
21
- from keras_hub.src.utils.keras_utils import (
22
- convert_inputs_to_list_of_tensor_segments,
23
- )
24
- from keras_hub.src.utils.tensor_utils import strip_to_ragged
25
-
26
- try:
27
- import tensorflow as tf
28
- except ImportError:
29
- tf = None
17
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
18
+ from keras_hub.src.models.bart.bart_backbone import BartBackbone
19
+ from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
20
+ from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
30
21
 
31
22
 
32
23
  @keras_hub_export("keras_hub.models.BartSeq2SeqLMPreprocessor")
33
- class BartSeq2SeqLMPreprocessor(BartPreprocessor):
24
+ class BartSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor):
34
25
  """BART Seq2Seq LM preprocessor.
35
26
 
36
27
  This layer is used as preprocessor for seq2seq tasks using the BART model.
@@ -125,138 +116,20 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor):
125
116
  ```
126
117
  """
127
118
 
128
- def call(
129
- self,
130
- x,
131
- y=None,
132
- sample_weight=None,
133
- *,
134
- encoder_sequence_length=None,
135
- decoder_sequence_length=None,
136
- # `sequence_length` is an alias for `decoder_sequence_length`
137
- sequence_length=None,
138
- ):
139
- if y is not None or sample_weight is not None:
140
- logging.warning(
141
- "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` "
142
- "from the provided input data, i.e., `x`. However, non-`None`"
143
- "values have been passed for `y` or `sample_weight` or both. "
144
- "These values will be ignored."
145
- )
146
-
147
- if encoder_sequence_length is None:
148
- encoder_sequence_length = self.encoder_sequence_length
149
- decoder_sequence_length = decoder_sequence_length or sequence_length
150
- if decoder_sequence_length is None:
151
- decoder_sequence_length = self.decoder_sequence_length
152
-
153
- x = super().call(
154
- x,
155
- encoder_sequence_length=encoder_sequence_length,
156
- decoder_sequence_length=decoder_sequence_length + 1,
157
- )
158
- decoder_token_ids = x.pop("decoder_token_ids")
159
- decoder_padding_mask = x.pop("decoder_padding_mask")
160
-
161
- # The last token does not have a next token. Hence, we truncate it.
162
- x = {
163
- **x,
164
- "decoder_token_ids": decoder_token_ids[..., :-1],
165
- "decoder_padding_mask": decoder_padding_mask[..., :-1],
166
- }
167
- # Target `y` will be the decoder input sequence shifted one step to the
168
- # left (i.e., the next token).
169
- y = decoder_token_ids[..., 1:]
170
- sample_weight = decoder_padding_mask[..., 1:]
171
- return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
172
-
173
- def generate_preprocess(
174
- self,
175
- x,
176
- *,
177
- encoder_sequence_length=None,
178
- # `sequence_length` is an alias for `decoder_sequence_length`
179
- decoder_sequence_length=None,
180
- sequence_length=None,
181
- ):
182
- """Convert encoder and decoder input strings to integer token inputs for generation.
183
-
184
- Similar to calling the layer for training, this method takes in a dict
185
- containing `"encoder_text"` and `"decoder_text"`, with strings or tensor
186
- strings for values, tokenizes and packs the input, and computes a
187
- padding mask masking all inputs not filled in with a padded value.
188
-
189
- Unlike calling the layer for training, this method does not compute
190
- labels and will never append a tokenizer.end_token_id to the end of
191
- the decoder sequence (as generation is expected to continue at the end
192
- of the inputted decoder prompt).
193
- """
194
- if not self.built:
195
- self.build(None)
196
-
197
- if isinstance(x, dict):
198
- encoder_text = x["encoder_text"]
199
- decoder_text = x["decoder_text"]
200
- else:
201
- encoder_text = x
202
- # Initialize empty prompt for the decoder.
203
- decoder_text = tf.fill((tf.shape(encoder_text)[0],), "")
204
-
205
- if encoder_sequence_length is None:
206
- encoder_sequence_length = self.encoder_sequence_length
207
- decoder_sequence_length = decoder_sequence_length or sequence_length
208
- if decoder_sequence_length is None:
209
- decoder_sequence_length = self.decoder_sequence_length
210
-
211
- # Tokenize and pack the encoder inputs.
212
- # TODO: Remove `[0]` once we have shifted to `MultiSegmentPacker`.
213
- encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)[
214
- 0
215
- ]
216
- encoder_token_ids = self.tokenizer(encoder_text)
217
- encoder_token_ids, encoder_padding_mask = self.encoder_packer(
218
- encoder_token_ids,
219
- sequence_length=encoder_sequence_length,
220
- )
221
-
222
- # Tokenize and pack the decoder inputs.
223
- decoder_text = convert_inputs_to_list_of_tensor_segments(decoder_text)[
224
- 0
225
- ]
226
- decoder_token_ids = self.tokenizer(decoder_text)
227
- decoder_token_ids, decoder_padding_mask = self.decoder_packer(
228
- decoder_token_ids,
229
- sequence_length=decoder_sequence_length,
230
- add_end_value=False,
231
- )
232
-
233
- return {
234
- "encoder_token_ids": encoder_token_ids,
235
- "encoder_padding_mask": encoder_padding_mask,
236
- "decoder_token_ids": decoder_token_ids,
237
- "decoder_padding_mask": decoder_padding_mask,
238
- }
239
-
240
- def generate_postprocess(
241
- self,
242
- x,
243
- ):
244
- """Convert integer token output to strings for generation.
245
-
246
- This method reverses `generate_preprocess()`, by first removing all
247
- padding and start/end tokens, and then converting the integer sequence
248
- back to a string.
249
- """
250
- if not self.built:
251
- self.build(None)
252
-
253
- token_ids, padding_mask = (
254
- x["decoder_token_ids"],
255
- x["decoder_padding_mask"],
256
- )
257
- ids_to_strip = (
258
- self.tokenizer.start_token_id,
259
- self.tokenizer.end_token_id,
119
+ backbone_cls = BartBackbone
120
+ tokenizer_cls = BartTokenizer
121
+
122
+ def build(self, input_shape):
123
+ super().build(input_shape)
124
+ # The decoder is packed a bit differently; the format is as follows:
125
+ # `[end_token_id, start_token_id, tokens..., end_token_id, padding...]`.
126
+ self.decoder_packer = StartEndPacker(
127
+ start_value=[
128
+ self.tokenizer.end_token_id,
129
+ self.tokenizer.start_token_id,
130
+ ],
131
+ end_value=self.tokenizer.end_token_id,
132
+ pad_value=self.tokenizer.pad_token_id,
133
+ sequence_length=self.decoder_sequence_length,
134
+ return_padding_mask=True,
260
135
  )
261
- token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
262
- return self.tokenizer.detokenize(token_ids)
@@ -14,10 +14,16 @@
14
14
 
15
15
 
16
16
  from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.models.bart.bart_backbone import BartBackbone
17
18
  from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
18
19
 
19
20
 
20
- @keras_hub_export("keras_hub.models.BartTokenizer")
21
+ @keras_hub_export(
22
+ [
23
+ "keras_hub.tokenizers.BartTokenizer",
24
+ "keras_hub.models.BartTokenizer",
25
+ ]
26
+ )
21
27
  class BartTokenizer(BytePairTokenizer):
22
28
  """A BART tokenizer using Byte-Pair Encoding subword segmentation.
23
29
 
@@ -73,52 +79,19 @@ class BartTokenizer(BytePairTokenizer):
73
79
  ```
74
80
  """
75
81
 
82
+ backbone_cls = BartBackbone
83
+
76
84
  def __init__(
77
85
  self,
78
86
  vocabulary=None,
79
87
  merges=None,
80
88
  **kwargs,
81
89
  ):
82
- self.start_token = "<s>"
83
- self.pad_token = "<pad>"
84
- self.end_token = "</s>"
85
-
90
+ self._add_special_token("<s>", "start_token")
91
+ self._add_special_token("</s>", "end_token")
92
+ self._add_special_token("<pad>", "pad_token")
86
93
  super().__init__(
87
94
  vocabulary=vocabulary,
88
95
  merges=merges,
89
- unsplittable_tokens=[
90
- self.start_token,
91
- self.pad_token,
92
- self.end_token,
93
- ],
94
96
  **kwargs,
95
97
  )
96
-
97
- def set_vocabulary_and_merges(self, vocabulary, merges):
98
- super().set_vocabulary_and_merges(vocabulary, merges)
99
-
100
- if vocabulary is not None:
101
- # Check for necessary special tokens.
102
- for token in [self.start_token, self.pad_token, self.end_token]:
103
- if token not in self.vocabulary:
104
- raise ValueError(
105
- f"Cannot find token `'{token}'` in the provided "
106
- f"`vocabulary`. Please provide `'{token}'` in your "
107
- "`vocabulary` or use a pretrained `vocabulary` name."
108
- )
109
-
110
- self.start_token_id = self.token_to_id(self.start_token)
111
- self.pad_token_id = self.token_to_id(self.pad_token)
112
- self.end_token_id = self.token_to_id(self.end_token)
113
- else:
114
- self.start_token_id = None
115
- self.pad_token_id = None
116
- self.end_token_id = None
117
-
118
- def get_config(self):
119
- config = super().get_config()
120
- # In the constructor, we pass the list of special tokens to the
121
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
122
- # delete it from the config here.
123
- del config["unsplittable_tokens"]
124
- return config
@@ -13,11 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from keras_hub.src.models.bert.bert_backbone import BertBackbone
16
- from keras_hub.src.models.bert.bert_classifier import BertClassifier
17
16
  from keras_hub.src.models.bert.bert_presets import backbone_presets
18
- from keras_hub.src.models.bert.bert_presets import classifier_presets
19
- from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
20
17
  from keras_hub.src.utils.preset_utils import register_presets
21
18
 
22
- register_presets(backbone_presets, (BertBackbone, BertTokenizer))
23
- register_presets(classifier_presets, (BertClassifier, BertTokenizer))
19
+ register_presets(backbone_presets, BertBackbone)
@@ -12,18 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import keras
16
- from absl import logging
17
-
18
15
  from keras_hub.src.api_export import keras_hub_export
19
- from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import (
20
- MaskedLMMaskGenerator,
21
- )
22
- from keras_hub.src.models.bert.bert_preprocessor import BertPreprocessor
16
+ from keras_hub.src.models.bert.bert_backbone import BertBackbone
17
+ from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
18
+ from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor
23
19
 
24
20
 
25
21
  @keras_hub_export("keras_hub.models.BertMaskedLMPreprocessor")
26
- class BertMaskedLMPreprocessor(BertPreprocessor):
22
+ class BertMaskedLMPreprocessor(MaskedLMPreprocessor):
27
23
  """BERT preprocessing for the masked language modeling task.
28
24
 
29
25
  This preprocessing layer will prepare inputs for a masked language modeling
@@ -117,82 +113,5 @@ class BertMaskedLMPreprocessor(BertPreprocessor):
117
113
  ```
118
114
  """
119
115
 
120
- def __init__(
121
- self,
122
- tokenizer,
123
- sequence_length=512,
124
- truncate="round_robin",
125
- mask_selection_rate=0.15,
126
- mask_selection_length=96,
127
- mask_token_rate=0.8,
128
- random_token_rate=0.1,
129
- **kwargs,
130
- ):
131
- super().__init__(
132
- tokenizer,
133
- sequence_length=sequence_length,
134
- truncate=truncate,
135
- **kwargs,
136
- )
137
- self.mask_selection_rate = mask_selection_rate
138
- self.mask_selection_length = mask_selection_length
139
- self.mask_token_rate = mask_token_rate
140
- self.random_token_rate = random_token_rate
141
- self.masker = None
142
-
143
- def build(self, input_shape):
144
- super().build(input_shape)
145
- # Defer masker creation to `build()` so that we can be sure tokenizer
146
- # assets have loaded when restoring a saved model.
147
- self.masker = MaskedLMMaskGenerator(
148
- mask_selection_rate=self.mask_selection_rate,
149
- mask_selection_length=self.mask_selection_length,
150
- mask_token_rate=self.mask_token_rate,
151
- random_token_rate=self.random_token_rate,
152
- vocabulary_size=self.tokenizer.vocabulary_size(),
153
- mask_token_id=self.tokenizer.mask_token_id,
154
- unselectable_token_ids=[
155
- self.tokenizer.cls_token_id,
156
- self.tokenizer.sep_token_id,
157
- self.tokenizer.pad_token_id,
158
- ],
159
- )
160
-
161
- def call(self, x, y=None, sample_weight=None):
162
- if y is not None or sample_weight is not None:
163
- logging.warning(
164
- f"{self.__class__.__name__} generates `y` and `sample_weight` "
165
- "based on your input data, but your data already contains `y` "
166
- "or `sample_weight`. Your `y` and `sample_weight` will be "
167
- "ignored."
168
- )
169
-
170
- x = super().call(x)
171
-
172
- token_ids, padding_mask, segment_ids = (
173
- x["token_ids"],
174
- x["padding_mask"],
175
- x["segment_ids"],
176
- )
177
- masker_outputs = self.masker(token_ids)
178
- x = {
179
- "token_ids": masker_outputs["token_ids"],
180
- "padding_mask": padding_mask,
181
- "segment_ids": segment_ids,
182
- "mask_positions": masker_outputs["mask_positions"],
183
- }
184
- y = masker_outputs["mask_ids"]
185
- sample_weight = masker_outputs["mask_weights"]
186
- return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
187
-
188
- def get_config(self):
189
- config = super().get_config()
190
- config.update(
191
- {
192
- "mask_selection_rate": self.mask_selection_rate,
193
- "mask_selection_length": self.mask_selection_length,
194
- "mask_token_rate": self.mask_token_rate,
195
- "random_token_rate": self.random_token_rate,
196
- }
197
- )
198
- return config
116
+ backbone_cls = BertBackbone
117
+ tokenizer_cls = BertTokenizer
@@ -129,9 +129,6 @@ backbone_presets = {
129
129
  },
130
130
  "kaggle_handle": "kaggle://keras/bert/keras/bert_large_en/2",
131
131
  },
132
- }
133
-
134
- classifier_presets = {
135
132
  "bert_tiny_en_uncased_sst2": {
136
133
  "metadata": {
137
134
  "description": (
@@ -143,5 +140,5 @@ classifier_presets = {
143
140
  "model_card": "https://github.com/google-research/bert/blob/master/README.md",
144
141
  },
145
142
  "kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/4",
146
- }
143
+ },
147
144
  }
@@ -17,12 +17,19 @@ import keras
17
17
  from keras_hub.src.api_export import keras_hub_export
18
18
  from keras_hub.src.models.bert.bert_backbone import BertBackbone
19
19
  from keras_hub.src.models.bert.bert_backbone import bert_kernel_initializer
20
- from keras_hub.src.models.bert.bert_preprocessor import BertPreprocessor
21
- from keras_hub.src.models.classifier import Classifier
22
-
23
-
24
- @keras_hub_export("keras_hub.models.BertClassifier")
25
- class BertClassifier(Classifier):
20
+ from keras_hub.src.models.bert.bert_text_classifier_preprocessor import (
21
+ BertTextClassifierPreprocessor,
22
+ )
23
+ from keras_hub.src.models.text_classifier import TextClassifier
24
+
25
+
26
+ @keras_hub_export(
27
+ [
28
+ "keras_hub.models.BertTextClassifier",
29
+ "keras_hub.models.BertClassifier",
30
+ ]
31
+ )
32
+ class BertTextClassifier(TextClassifier):
26
33
  """An end-to-end BERT model for classification tasks.
27
34
 
28
35
  This model attaches a classification head to a
@@ -41,7 +48,7 @@ class BertClassifier(Classifier):
41
48
  Args:
42
49
  backbone: A `keras_hub.models.BertBackbone` instance.
43
50
  num_classes: int. Number of classes to predict.
44
- preprocessor: A `keras_hub.models.BertPreprocessor` or `None`. If
51
+ preprocessor: A `keras_hub.models.BertTextClassifierPreprocessor` or `None`. If
45
52
  `None`, this model will not apply preprocessing, and inputs should
46
53
  be preprocessed before calling the model.
47
54
  activation: Optional `str` or callable. The
@@ -59,7 +66,7 @@ class BertClassifier(Classifier):
59
66
  labels = [0, 3]
60
67
 
61
68
  # Pretrained classifier.
62
- classifier = keras_hub.models.BertClassifier.from_preset(
69
+ classifier = keras_hub.models.BertTextClassifier.from_preset(
63
70
  "bert_base_en_uncased",
64
71
  num_classes=4,
65
72
  )
@@ -88,7 +95,7 @@ class BertClassifier(Classifier):
88
95
  labels = [0, 3]
89
96
 
90
97
  # Pretrained classifier without preprocessing.
91
- classifier = keras_hub.models.BertClassifier.from_preset(
98
+ classifier = keras_hub.models.BertTextClassifier.from_preset(
92
99
  "bert_base_en_uncased",
93
100
  num_classes=4,
94
101
  preprocessor=None,
@@ -106,7 +113,7 @@ class BertClassifier(Classifier):
106
113
  tokenizer = keras_hub.models.BertTokenizer(
107
114
  vocabulary=vocab,
108
115
  )
109
- preprocessor = keras_hub.models.BertPreprocessor(
116
+ preprocessor = keras_hub.models.BertTextClassifierPreprocessor(
110
117
  tokenizer=tokenizer,
111
118
  sequence_length=128,
112
119
  )
@@ -118,7 +125,7 @@ class BertClassifier(Classifier):
118
125
  intermediate_dim=512,
119
126
  max_sequence_length=128,
120
127
  )
121
- classifier = keras_hub.models.BertClassifier(
128
+ classifier = keras_hub.models.BertTextClassifier(
122
129
  backbone=backbone,
123
130
  preprocessor=preprocessor,
124
131
  num_classes=4,
@@ -128,7 +135,7 @@ class BertClassifier(Classifier):
128
135
  """
129
136
 
130
137
  backbone_cls = BertBackbone
131
- preprocessor_cls = BertPreprocessor
138
+ preprocessor_cls = BertTextClassifierPreprocessor
132
139
 
133
140
  def __init__(
134
141
  self,