keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,138 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+
16
+ from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.layers.preprocessing.multi_segment_packer import (
18
+ MultiSegmentPacker,
19
+ )
20
+ from keras_hub.src.models.preprocessor import Preprocessor
21
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
22
+
23
+
24
+ @keras_hub_export("keras_hub.models.TextClassifierPreprocessor")
25
+ class TextClassifierPreprocessor(Preprocessor):
26
+ """Base class for text classification preprocessing layers.
27
+
28
+ `TextClassifierPreprocessor` tasks wrap a `keras_hub.tokenizers.Tokenizer` to
29
+ create a preprocessing layer for text classification tasks. It is intended
30
+ to be paired with a `keras_hub.models.TextClassifier` task.
31
+
32
+ All `TextClassifierPreprocessor` tasks take three ordered inputs, `x`, `y`,
33
+ and `sample_weight`. `x`, the first input, should always be included. It can
34
+ be a single string, a batch of strings, or a tuple of batches of string
35
+ segments that should be combined into a single sequence. See examples below.
36
+ `y` and `sample_weight` are optional inputs that will be passed through
37
+ unaltered. Usually, `y` will be the classification label, and
38
+ `sample_weight` will not be provided.
39
+
40
+ The layer will output either `x`, an `(x, y)` tuple if labels were provided,
41
+ or an `(x, y, sample_weight)` tuple if labels and sample weight were
42
+ provided. `x` will be a dictionary of tokenized inputs; the exact contents
43
+ of the dictionary will depend on the model being used.
44
+
45
+ All `TextClassifierPreprocessor` tasks include a `from_preset()` constructor
46
+ which can be used to load a pre-trained config and vocabularies. You can
47
+ call the `from_preset()` constructor directly on this base class, in which
48
+ case the correct class for your model will be automatically instantiated.
49
+
50
+ Examples:
51
+ ```python
52
+ preprocessor = keras_hub.models.TextClassifierPreprocessor.from_preset(
53
+ "bert_base_en_uncased",
54
+ sequence_length=256, # Optional.
55
+ )
56
+
57
+ # Tokenize and pad/truncate a single sentence.
58
+ x = "The quick brown fox jumped."
59
+ x = preprocessor(x)
60
+
61
+ # Tokenize and pad/truncate a labeled sentence.
62
+ x, y = "The quick brown fox jumped.", 1
63
+ x, y = preprocessor(x, y)
64
+
65
+ # Tokenize and pad/truncate a batch of labeled sentences.
66
+ x, y = ["The quick brown fox jumped.", "Call me Ishmael."], [1, 0]
67
+ x, y = preprocessor(x, y)
68
+
69
+ # Tokenize and combine a batch of labeled sentence pairs.
70
+ first = ["The quick brown fox jumped.", "Call me Ishmael."]
71
+ second = ["The fox tripped.", "Oh look, a whale."]
72
+ labels = [1, 0]
73
+ x, y = (first, second), labels
74
+ x, y = preprocessor(x, y)
75
+
76
+ # Use a `tf.data.Dataset` (requires `import tensorflow as tf`).
77
+ ds = tf.data.Dataset.from_tensor_slices(((first, second), labels))
78
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
79
+ ```
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ tokenizer,
85
+ sequence_length=512,
86
+ truncate="round_robin",
87
+ **kwargs,
88
+ ):
89
+ super().__init__(**kwargs)
90
+ self.tokenizer = tokenizer
91
+ self.packer = None  # Created lazily in `build()`.
92
+ self.sequence_length = sequence_length  # Goes through the property setter.
93
+ self.truncate = truncate
94
+
95
+ def build(self, input_shape):
96
+ super().build(input_shape)
97
+ # Defer packer creation to `build()` so that we can be sure tokenizer
98
+ # assets have loaded when restoring a saved model.
99
+ self.packer = MultiSegmentPacker(
100
+ start_value=self.tokenizer.start_token_id,
101
+ end_value=self.tokenizer.end_token_id,
102
+ pad_value=self.tokenizer.pad_token_id,
103
+ truncate=self.truncate,
104
+ sequence_length=self.sequence_length,
105
+ )
106
+
107
+ @preprocessing_function
108
+ def call(self, x, y=None, sample_weight=None):
109
+ x = x if isinstance(x, tuple) else (x,)  # A lone input is one segment.
110
+ x = tuple(self.tokenizer(segment) for segment in x)
111
+ token_ids, segment_ids = self.packer(x)
112
+ x = {
113
+ "token_ids": token_ids,
114
+ "padding_mask": token_ids != self.tokenizer.pad_token_id,  # False at pads.
115
+ "segment_ids": segment_ids,
116
+ }
117
+ return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
118
+
119
+ def get_config(self):
120
+ config = super().get_config()
121
+ config.update(
122
+ {
123
+ "sequence_length": self.sequence_length,
124
+ "truncate": self.truncate,
125
+ }
126
+ )
127
+ return config
128
+
129
+ @property
130
+ def sequence_length(self):
131
+ """The padded length of model input sequences."""
132
+ return self._sequence_length
133
+
134
+ @sequence_length.setter
135
+ def sequence_length(self, value):
136
+ self._sequence_length = value
137
+ if self.packer is not None:  # Keep an already-built packer in sync.
138
+ self.packer.sequence_length = value
@@ -14,7 +14,6 @@
14
14
 
15
15
  from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
16
16
  from keras_hub.src.models.whisper.whisper_presets import backbone_presets
17
- from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
18
17
  from keras_hub.src.utils.preset_utils import register_presets
19
18
 
20
- register_presets(backbone_presets, (WhisperBackbone, WhisperTokenizer))
19
+ register_presets(backbone_presets, WhisperBackbone)
@@ -15,24 +15,19 @@
15
15
 
16
16
  import numpy as np
17
17
 
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
20
+ from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
21
+
18
22
  try:
19
23
  import tensorflow as tf
20
24
  except ImportError:
21
- raise ImportError(
22
- "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
23
- "The TensorFlow package is required for data preprocessing with any backend."
24
- )
25
-
26
- from keras_hub.src.api_export import keras_hub_export
27
- from keras_hub.src.layers.preprocessing.preprocessing_layer import (
28
- PreprocessingLayer,
29
- )
25
+ tf = None
30
26
 
31
27
 
32
- @keras_hub_export("keras_hub.models.WhisperAudioFeatureExtractor")
33
- class WhisperAudioFeatureExtractor(PreprocessingLayer):
34
- """
35
- Whisper audio feature extractor layer.
28
+ @keras_hub_export("keras_hub.layers.WhisperAudioConverter")
29
+ class WhisperAudioConverter(AudioConverter):
30
+ """Whisper audio converter layer.
36
31
 
37
32
  This layer takes in a batch of audio tensors, and computes the log-mel
38
33
  spectrogram features for each audio tensor.
@@ -55,22 +50,25 @@ class WhisperAudioFeatureExtractor(PreprocessingLayer):
55
50
  `max_audio_length * sampling_rate`. Defaults to `30`.
56
51
 
57
52
  Examples:
58
-
59
53
  ```python
60
54
  audio_tensor = tf.ones((8000,), dtype="float32")
61
55
 
62
56
  # Compute the log-mel spectrogram.
63
- whisper_audio_feature_extractor = keras_hub.models.WhisperAudioFeatureExtractor()
64
- whisper_audio_feature_extractor(audio_tensor)
57
+ audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
58
+ "whisper_base_en",
59
+ )
60
+ audio_converter(audio_tensor)
65
61
 
66
62
  # Compute the log-mel spectrogram for a batch of audio tensors.
67
63
  audio_tensor_1 = tf.ones((8000,), dtype="float32")
68
- audio_tensor_2 = tf.ones((10000,), dtype="float32"
64
+ audio_tensor_2 = tf.ones((10000,), dtype="float32")
69
65
  audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0)
70
- whisper_audio_feature_extractor(audio_tensor)
66
+ audio_converter(audio_tensor)
71
67
  ```
72
68
  """
73
69
 
70
+ backbone_cls = WhisperBackbone
71
+
74
72
  def __init__(
75
73
  self,
76
74
  num_mels=80,
@@ -97,6 +95,10 @@ class WhisperAudioFeatureExtractor(PreprocessingLayer):
97
95
  # `(num_fft_bins // 2 + 1, num_mels).`
98
96
  self.mel_filters = self._get_mel_filters()
99
97
 
98
+ def audio_shape(self):
99
+ """Returns the preprocessed size of a single audio sample."""
100
+ return (self.max_audio_length, self.num_mels)
101
+
100
102
  def _get_mel_filters(self):
101
103
  """
102
104
  Adapted from Hugging Face
@@ -24,7 +24,6 @@ from keras_hub.src.layers.modeling.token_and_position_embedding import (
24
24
  from keras_hub.src.models.backbone import Backbone
25
25
  from keras_hub.src.models.whisper.whisper_decoder import WhisperDecoder
26
26
  from keras_hub.src.models.whisper.whisper_encoder import WhisperEncoder
27
- from keras_hub.src.utils.tensor_utils import assert_tf_backend
28
27
 
29
28
 
30
29
  def whisper_kernel_initializer(stddev=0.02):
@@ -117,8 +116,6 @@ class WhisperBackbone(Backbone):
117
116
  dtype=None,
118
117
  **kwargs,
119
118
  ):
120
- assert_tf_backend(self.__class__.__name__)
121
-
122
119
  # === Layers ===
123
120
  self.encoder_conv_layer_1 = keras.layers.Conv1D(
124
121
  filters=hidden_dim,
@@ -25,7 +25,7 @@ backbone_presets = {
25
25
  "path": "whisper",
26
26
  "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
27
27
  },
28
- "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/2",
28
+ "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/3",
29
29
  },
30
30
  "whisper_base_en": {
31
31
  "metadata": {
@@ -38,7 +38,7 @@ backbone_presets = {
38
38
  "path": "whisper",
39
39
  "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
40
40
  },
41
- "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/2",
41
+ "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/3",
42
42
  },
43
43
  "whisper_small_en": {
44
44
  "metadata": {
@@ -51,7 +51,7 @@ backbone_presets = {
51
51
  "path": "whisper",
52
52
  "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
53
53
  },
54
- "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/2",
54
+ "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/3",
55
55
  },
56
56
  "whisper_medium_en": {
57
57
  "metadata": {
@@ -64,7 +64,7 @@ backbone_presets = {
64
64
  "path": "whisper",
65
65
  "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
66
66
  },
67
- "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/2",
67
+ "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/3",
68
68
  },
69
69
  "whisper_tiny_multi": {
70
70
  "metadata": {
@@ -77,7 +77,7 @@ backbone_presets = {
77
77
  "path": "whisper",
78
78
  "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
79
79
  },
80
- "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/2",
80
+ "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/3",
81
81
  },
82
82
  "whisper_base_multi": {
83
83
  "metadata": {
@@ -90,7 +90,7 @@ backbone_presets = {
90
90
  "path": "whisper",
91
91
  "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
92
92
  },
93
- "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/2",
93
+ "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/3",
94
94
  },
95
95
  "whisper_small_multi": {
96
96
  "metadata": {
@@ -103,7 +103,7 @@ backbone_presets = {
103
103
  "path": "whisper",
104
104
  "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
105
105
  },
106
- "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/2",
106
+ "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/3",
107
107
  },
108
108
  "whisper_medium_multi": {
109
109
  "metadata": {
@@ -116,7 +116,7 @@ backbone_presets = {
116
116
  "path": "whisper",
117
117
  "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
118
118
  },
119
- "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/2",
119
+ "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/3",
120
120
  },
121
121
  "whisper_large_multi": {
122
122
  "metadata": {
@@ -129,7 +129,7 @@ backbone_presets = {
129
129
  "path": "whisper",
130
130
  "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
131
131
  },
132
- "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/2",
132
+ "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/3",
133
133
  },
134
134
  "whisper_large_multi_v2": {
135
135
  "metadata": {
@@ -143,6 +143,6 @@ backbone_presets = {
143
143
  "path": "whisper",
144
144
  "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
145
145
  },
146
- "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/2",
146
+ "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/3",
147
147
  },
148
148
  }
@@ -15,6 +15,7 @@
15
15
  import json
16
16
 
17
17
  from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
18
19
  from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
19
20
 
20
21
 
@@ -25,7 +26,12 @@ def _load_dict(dict_or_path):
25
26
  return dict_or_path
26
27
 
27
28
 
28
- @keras_hub_export("keras_hub.models.WhisperTokenizer")
29
+ @keras_hub_export(
30
+ [
31
+ "keras_hub.tokenizers.WhisperTokenizer",
32
+ "keras_hub.models.WhisperTokenizer",
33
+ ]
34
+ )
29
35
  class WhisperTokenizer(BytePairTokenizer):
30
36
  """Whisper text tokenizer using Byte-Pair Encoding subword segmentation.
31
37
 
@@ -47,6 +53,8 @@ class WhisperTokenizer(BytePairTokenizer):
47
53
  tokenizer.
48
54
  """
49
55
 
56
+ backbone_cls = WhisperBackbone
57
+
50
58
  def __init__(
51
59
  self,
52
60
  vocabulary=None,
@@ -94,20 +102,22 @@ class WhisperTokenizer(BytePairTokenizer):
94
102
  self.translate_token_id = special_tokens[self.translate_token]
95
103
  self.transcribe_token_id = special_tokens[self.transcribe_token]
96
104
 
97
- self.special_tokens = special_tokens
105
+ self._special_token_dict = special_tokens
98
106
  self.language_tokens = language_tokens
99
-
100
- # TODO: Add language tokens to `unsplittable_tokens` once we figure
101
- # out the performance issue with a large list.
102
- unsplittable_tokens = list(special_tokens.keys())
103
-
104
107
  super().__init__(
105
108
  vocabulary=vocabulary,
106
109
  merges=merges,
107
- unsplittable_tokens=unsplittable_tokens,
108
110
  **kwargs,
109
111
  )
110
112
 
113
+ @property
114
+ def special_tokens(self):
115
+ return list(self._special_token_dict.keys())
116
+
117
+ @property
118
+ def special_token_ids(self):
119
+ return list(self._special_token_dict.values())
120
+
111
121
  def save_assets(self, dir_path):
112
122
  # TODO: whisper is currently mutating its vocabulary before passing
113
123
  # it to the super class, so we need to restore the unmutated vocabulary
@@ -140,7 +150,7 @@ class WhisperTokenizer(BytePairTokenizer):
140
150
  self.translate_token,
141
151
  self.transcribe_token,
142
152
  ]:
143
- vocabulary[token] = self.special_tokens[token]
153
+ vocabulary[token] = self._special_token_dict[token]
144
154
  else:
145
155
  self._initial_vocabulary = None
146
156
 
@@ -148,15 +158,9 @@ class WhisperTokenizer(BytePairTokenizer):
148
158
 
149
159
  def get_config(self):
150
160
  config = super().get_config()
151
-
152
- # In the constructor, we pass the list of special tokens to the
153
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
154
- # delete it from the config here.
155
- del config["unsplittable_tokens"]
156
-
157
161
  config.update(
158
162
  {
159
- "special_tokens": self.special_tokens,
163
+ "special_tokens": self._special_token_dict,
160
164
  "language_tokens": self.language_tokens,
161
165
  }
162
166
  )
@@ -18,9 +18,6 @@ from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
18
18
  from keras_hub.src.models.xlm_roberta.xlm_roberta_presets import (
19
19
  backbone_presets,
20
20
  )
21
- from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import (
22
- XLMRobertaTokenizer,
23
- )
24
21
  from keras_hub.src.utils.preset_utils import register_presets
25
22
 
26
- register_presets(backbone_presets, (XLMRobertaBackbone, XLMRobertaTokenizer))
23
+ register_presets(backbone_presets, XLMRobertaBackbone)
@@ -13,19 +13,23 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import keras
16
- from absl import logging
17
16
 
18
17
  from keras_hub.src.api_export import keras_hub_export
19
- from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import (
20
- MaskedLMMaskGenerator,
18
+ from keras_hub.src.layers.preprocessing.multi_segment_packer import (
19
+ MultiSegmentPacker,
21
20
  )
22
- from keras_hub.src.models.xlm_roberta.xlm_roberta_preprocessor import (
23
- XLMRobertaPreprocessor,
21
+ from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor
22
+ from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
23
+ XLMRobertaBackbone,
24
24
  )
25
+ from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import (
26
+ XLMRobertaTokenizer,
27
+ )
28
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
25
29
 
26
30
 
27
31
  @keras_hub_export("keras_hub.models.XLMRobertaMaskedLMPreprocessor")
28
- class XLMRobertaMaskedLMPreprocessor(XLMRobertaPreprocessor):
32
+ class XLMRobertaMaskedLMPreprocessor(MaskedLMPreprocessor):
29
33
  """XLM-RoBERTa preprocessing for the masked language modeling task.
30
34
 
31
35
  This preprocessing layer will prepare inputs for a masked language modeling
@@ -120,76 +124,26 @@ class XLMRobertaMaskedLMPreprocessor(XLMRobertaPreprocessor):
120
124
  ```
121
125
  """
122
126
 
123
- def __init__(
124
- self,
125
- tokenizer,
126
- sequence_length=512,
127
- truncate="round_robin",
128
- mask_selection_rate=0.15,
129
- mask_selection_length=96,
130
- mask_token_rate=0.8,
131
- random_token_rate=0.1,
132
- **kwargs,
133
- ):
134
- super().__init__(
135
- tokenizer,
136
- sequence_length=sequence_length,
137
- truncate=truncate,
138
- **kwargs,
139
- )
140
- self.mask_selection_rate = mask_selection_rate
141
- self.mask_selection_length = mask_selection_length
142
- self.mask_token_rate = mask_token_rate
143
- self.random_token_rate = random_token_rate
144
- self.masker = None
127
+ backbone_cls = XLMRobertaBackbone
128
+ tokenizer_cls = XLMRobertaTokenizer
145
129
 
146
130
  def build(self, input_shape):
147
131
  super().build(input_shape)
148
- # Defer masker creation to `build()` so that we can be sure tokenizer
149
- # assets have loaded when restoring a saved model.
150
- self.masker = MaskedLMMaskGenerator(
151
- mask_selection_rate=self.mask_selection_rate,
152
- mask_selection_length=self.mask_selection_length,
153
- mask_token_rate=self.mask_token_rate,
154
- random_token_rate=self.random_token_rate,
155
- vocabulary_size=self.tokenizer.vocabulary_size(),
156
- mask_token_id=self.tokenizer.mask_token_id,
157
- unselectable_token_ids=[
158
- self.tokenizer.start_token_id,
159
- self.tokenizer.end_token_id,
160
- self.tokenizer.pad_token_id,
161
- ],
162
- )
163
-
164
- def get_config(self):
165
- config = super().get_config()
166
- config.update(
167
- {
168
- "mask_selection_rate": self.mask_selection_rate,
169
- "mask_selection_length": self.mask_selection_length,
170
- "mask_token_rate": self.mask_token_rate,
171
- "random_token_rate": self.random_token_rate,
172
- }
132
+ # Roberta is doubles up the sep token, so we override build.
133
+ self.packer = MultiSegmentPacker(
134
+ start_value=self.tokenizer.start_token_id,
135
+ end_value=self.tokenizer.end_token_id,
136
+ sep_value=[self.tokenizer.end_token_id] * 2,
137
+ pad_value=self.tokenizer.pad_token_id,
138
+ truncate=self.truncate,
139
+ sequence_length=self.sequence_length,
173
140
  )
174
- return config
141
+ self.built = True
175
142
 
143
+ @preprocessing_function
176
144
  def call(self, x, y=None, sample_weight=None):
177
- if y is not None or sample_weight is not None:
178
- logging.warning(
179
- f"{self.__class__.__name__} generates `y` and `sample_weight` "
180
- "based on your input data, but your data already contains `y` "
181
- "or `sample_weight`. Your `y` and `sample_weight` will be "
182
- "ignored."
183
- )
184
-
185
- x = super().call(x)
186
- token_ids, padding_mask = x["token_ids"], x["padding_mask"]
187
- masker_outputs = self.masker(token_ids)
188
- x = {
189
- "token_ids": masker_outputs["token_ids"],
190
- "padding_mask": padding_mask,
191
- "mask_positions": masker_outputs["mask_positions"],
192
- }
193
- y = masker_outputs["mask_ids"]
194
- sample_weight = masker_outputs["mask_weights"]
145
+ output = super().call(x, y=y, sample_weight=sample_weight)
146
+ x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output)
147
+ # Backbone has no segment ID input.
148
+ del x["segment_ids"]
195
149
  return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
@@ -16,20 +16,25 @@
16
16
  import keras
17
17
 
18
18
  from keras_hub.src.api_export import keras_hub_export
19
- from keras_hub.src.models.classifier import Classifier
20
19
  from keras_hub.src.models.roberta.roberta_backbone import (
21
20
  roberta_kernel_initializer,
22
21
  )
22
+ from keras_hub.src.models.text_classifier import TextClassifier
23
23
  from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
24
24
  XLMRobertaBackbone,
25
25
  )
26
- from keras_hub.src.models.xlm_roberta.xlm_roberta_preprocessor import (
27
- XLMRobertaPreprocessor,
26
+ from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import (
27
+ XLMRobertaTextClassifierPreprocessor,
28
28
  )
29
29
 
30
30
 
31
- @keras_hub_export("keras_hub.models.XLMRobertaClassifier")
32
- class XLMRobertaClassifier(Classifier):
31
+ @keras_hub_export(
32
+ [
33
+ "keras_hub.models.XLMRobertaTextClassifier",
34
+ "keras_hub.models.XLMRobertaClassifier",
35
+ ]
36
+ )
37
+ class XLMRobertaTextClassifier(TextClassifier):
33
38
  """An end-to-end XLM-RoBERTa model for classification tasks.
34
39
 
35
40
  This model attaches a classification head to a
@@ -50,7 +55,7 @@ class XLMRobertaClassifier(Classifier):
50
55
  Args:
51
56
  backbone: A `keras_hub.models.XLMRobertaBackbone` instance.
52
57
  num_classes: int. Number of classes to predict.
53
- preprocessor: A `keras_hub.models.XLMRobertaPreprocessor` or `None`. If
58
+ preprocessor: A `keras_hub.models.XLMRobertaTextClassifierPreprocessor` or `None`. If
54
59
  `None`, this model will not apply preprocessing, and inputs should
55
60
  be preprocessed before calling the model.
56
61
  activation: Optional `str` or callable. The activation function to use
@@ -68,7 +73,7 @@ class XLMRobertaClassifier(Classifier):
68
73
  labels = [0, 3]
69
74
 
70
75
  # Pretrained classifier.
71
- classifier = keras_hub.models.XLMRobertaClassifier.from_preset(
76
+ classifier = keras_hub.models.XLMRobertaTextClassifier.from_preset(
72
77
  "xlm_roberta_base_multi",
73
78
  num_classes=4,
74
79
  )
@@ -96,7 +101,7 @@ class XLMRobertaClassifier(Classifier):
96
101
  labels = [0, 3]
97
102
 
98
103
  # Pretrained classifier without preprocessing.
99
- classifier = keras_hub.models.XLMRobertaClassifier.from_preset(
104
+ classifier = keras_hub.models.XLMRobertaTextClassifier.from_preset(
100
105
  "xlm_roberta_base_multi",
101
106
  num_classes=4,
102
107
  preprocessor=None,
@@ -128,7 +133,7 @@ class XLMRobertaClassifier(Classifier):
128
133
  tokenizer = keras_hub.models.XLMRobertaTokenizer(
129
134
  proto=proto
130
135
  )
131
- preprocessor = keras_hub.models.XLMRobertaPreprocessor(
136
+ preprocessor = keras_hub.models.XLMRobertaTextClassifierPreprocessor(
132
137
  tokenizer,
133
138
  sequence_length=128,
134
139
  )
@@ -140,7 +145,7 @@ class XLMRobertaClassifier(Classifier):
140
145
  intermediate_dim=512,
141
146
  max_sequence_length=128,
142
147
  )
143
- classifier = keras_hub.models.XLMRobertaClassifier(
148
+ classifier = keras_hub.models.XLMRobertaTextClassifier(
144
149
  backbone=backbone,
145
150
  preprocessor=preprocessor,
146
151
  num_classes=4,
@@ -150,7 +155,7 @@ class XLMRobertaClassifier(Classifier):
150
155
  """
151
156
 
152
157
  backbone_cls = XLMRobertaBackbone
153
- preprocessor_cls = XLMRobertaPreprocessor
158
+ preprocessor_cls = XLMRobertaTextClassifierPreprocessor
154
159
 
155
160
  def __init__(
156
161
  self,