keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py}

@@ -12,21 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import keras
-
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.multi_segment_packer import (
-    MultiSegmentPacker,
-)
+from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
-from keras_hub.src.models.preprocessor import Preprocessor
-from keras_hub.src.utils.keras_utils import (
-    convert_inputs_to_list_of_tensor_segments,
+from keras_hub.src.models.text_classifier_preprocessor import (
+    TextClassifierPreprocessor,
 )
 
 
-@keras_hub_export("keras_hub.models.BertPreprocessor")
-class BertPreprocessor(Preprocessor):
+@keras_hub_export(
+    [
+        "keras_hub.models.BertTextClassifierPreprocessor",
+        "keras_hub.models.BertPreprocessor",
+    ]
+)
+class BertTextClassifierPreprocessor(TextClassifierPreprocessor):
     """A BERT preprocessing layer which tokenizes and packs inputs.
 
     This preprocessing layer will do three things:
@@ -67,7 +67,7 @@ class BertPreprocessor(Preprocessor):
 
     Directly calling the layer on data.
     ```python
-    preprocessor = keras_hub.models.BertPreprocessor.from_preset(
+    preprocessor = keras_hub.models.TextClassifierPreprocessor.from_preset(
         "bert_base_en_uncased"
     )
 
@@ -87,13 +87,13 @@ class BertPreprocessor(Preprocessor):
     vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
     vocab += ["The", "quick", "brown", "fox", "jumped", "."]
     tokenizer = keras_hub.models.BertTokenizer(vocabulary=vocab)
-    preprocessor = keras_hub.models.BertPreprocessor(tokenizer)
+    preprocessor = keras_hub.models.BertTextClassifierPreprocessor(tokenizer)
    preprocessor("The quick brown fox jumped.")
     ```
 
     Mapping with `tf.data.Dataset`.
     ```python
-    preprocessor = keras_hub.models.BertPreprocessor.from_preset(
+    preprocessor = keras_hub.models.TextClassifierPreprocessor.from_preset(
         "bert_base_en_uncased"
     )
 
@@ -124,61 +124,5 @@ class BertPreprocessor(Preprocessor):
     ```
     """
 
+    backbone_cls = BertBackbone
     tokenizer_cls = BertTokenizer
-
-    def __init__(
-        self,
-        tokenizer,
-        sequence_length=512,
-        truncate="round_robin",
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.tokenizer = tokenizer
-        self.packer = None
-        self.sequence_length = sequence_length
-        self.truncate = truncate
-
-    def build(self, input_shape):
-        # Defer packer creation to `build()` so that we can be sure tokenizer
-        # assets have loaded when restoring a saved model.
-        self.packer = MultiSegmentPacker(
-            start_value=self.tokenizer.cls_token_id,
-            end_value=self.tokenizer.sep_token_id,
-            pad_value=self.tokenizer.pad_token_id,
-            truncate=self.truncate,
-            sequence_length=self.sequence_length,
-        )
-        self.built = True
-
-    def call(self, x, y=None, sample_weight=None):
-        x = convert_inputs_to_list_of_tensor_segments(x)
-        x = [self.tokenizer(segment) for segment in x]
-        token_ids, segment_ids = self.packer(x)
-        x = {
-            "token_ids": token_ids,
-            "segment_ids": segment_ids,
-            "padding_mask": token_ids != self.tokenizer.pad_token_id,
-        }
-        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "sequence_length": self.sequence_length,
-                "truncate": self.truncate,
-            }
-        )
-        return config
-
-    @property
-    def sequence_length(self):
-        """The padded length of model input sequences."""
-        return self._sequence_length
-
-    @sequence_length.setter
-    def sequence_length(self, value):
-        self._sequence_length = value
-        if self.packer is not None:
-            self.packer.sequence_length = value
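In practice the rename is source-compatible: the export list above registers the new class under both the new and the old symbol. A minimal usage sketch, with the preset name taken from the docstring above and the output keys per the removed `call()` implementation:

```python
import keras_hub

# The old path still works; both names point at the same class.
preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(
    "bert_base_en_uncased"
)

# Tokenize and pack a single string into the dict the BERT backbone expects:
# "token_ids", "segment_ids", and "padding_mask".
features = preprocessor("The quick brown fox jumped.")
```
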
keras_hub/src/models/bert/bert_tokenizer.py

@@ -13,10 +13,16 @@
 # limitations under the License.
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer
 
 
-@keras_hub_export("keras_hub.models.BertTokenizer")
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.BertTokenizer",
+        "keras_hub.models.BertTokenizer",
+    ]
+)
 class BertTokenizer(WordPieceTokenizer):
     """A BERT tokenizer using WordPiece subword segmentation.
 
@@ -26,9 +32,6 @@ class BertTokenizer(WordPieceTokenizer):
     models and provides a `from_preset()` method to automatically download
     a matching vocabulary for a BERT preset.
 
-    This tokenizer does not provide truncation or padding of inputs. It can be
-    combined with a `keras_hub.models.BertPreprocessor` layer for input packing.
-
     If input is a batch of strings (rank > 0), the layer will output a
     `tf.RaggedTensor` where the last dimension of the output is ragged.
 
@@ -68,45 +71,24 @@ class BertTokenizer(WordPieceTokenizer):
     ```
     """
 
+    backbone_cls = BertBackbone
+
     def __init__(
         self,
         vocabulary=None,
         lowercase=False,
-        special_tokens_in_strings=False,
         **kwargs,
     ):
-        self.cls_token = "[CLS]"
-        self.sep_token = "[SEP]"
-        self.pad_token = "[PAD]"
-        self.mask_token = "[MASK]"
+        self._add_special_token("[CLS]", "cls_token")
+        self._add_special_token("[SEP]", "sep_token")
+        self._add_special_token("[PAD]", "pad_token")
+        self._add_special_token("[MASK]", "mask_token")
+        # Also add `tokenizer.start_token` and `tokenizer.end_token` for
+        # compatibility with other tokenizers.
+        self._add_special_token("[CLS]", "start_token")
+        self._add_special_token("[SEP]", "end_token")
         super().__init__(
             vocabulary=vocabulary,
             lowercase=lowercase,
-            special_tokens=[
-                self.cls_token,
-                self.sep_token,
-                self.pad_token,
-                self.mask_token,
-            ],
-            special_tokens_in_strings=special_tokens_in_strings,
             **kwargs,
         )
-
-    def set_vocabulary(self, vocabulary):
-        super().set_vocabulary(vocabulary)
-
-        if vocabulary is not None:
-            self.cls_token_id = self.token_to_id(self.cls_token)
-            self.sep_token_id = self.token_to_id(self.sep_token)
-            self.pad_token_id = self.token_to_id(self.pad_token)
-            self.mask_token_id = self.token_to_id(self.mask_token)
-        else:
-            self.cls_token_id = None
-            self.sep_token_id = None
-            self.pad_token_id = None
-            self.mask_token_id = None
-
-    def get_config(self):
-        config = super().get_config()
-        del config["special_tokens"]  # Not configurable; set in __init__.
-        return config
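The deleted `set_vocabulary` override is subsumed by `_add_special_token`, which, judging by its usage here and by `special_token_ids` in the shared `generate_postprocess` later in this diff, registers each token so the base tokenizer maintains paired `<name>`/`<name>_id` attributes. A small sketch under that assumption, reusing the toy vocabulary from the preprocessor docstring:

```python
import keras_hub

vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
vocab += ["The", "quick", "brown", "fox", "jumped", "."]
# New export path; `keras_hub.models.BertTokenizer` remains as an alias.
tokenizer = keras_hub.tokenizers.BertTokenizer(vocabulary=vocab)

# Paired string/id attributes for each registered special token.
assert tokenizer.cls_token == "[CLS]"
assert tokenizer.cls_token_id == tokenizer.token_to_id("[CLS]")
# Generic aliases registered above for cross-tokenizer compatibility.
assert tokenizer.start_token_id == tokenizer.cls_token_id
```
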
keras_hub/src/models/bloom/__init__.py

@@ -14,7 +14,6 @@
 
 from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone
 from keras_hub.src.models.bloom.bloom_presets import backbone_presets
-from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer
 from keras_hub.src.utils.preset_utils import register_presets
 
-register_presets(backbone_presets, (BloomBackbone, BloomTokenizer))
+register_presets(backbone_presets, BloomBackbone)
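Presets are now registered against the backbone class alone; the tokenizer no longer needs a separate registration because, as the tokenizer diff below shows, it declares a `backbone_cls` tying it to the same presets. Assuming a registered BLOOM preset name, loading should still work from either side:

```python
import keras_hub

# Preset name is illustrative.
backbone = keras_hub.models.BloomBackbone.from_preset("bloom_560m_multi")
tokenizer = keras_hub.tokenizers.BloomTokenizer.from_preset("bloom_560m_multi")
```
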
keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py

@@ -12,19 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import keras
-from absl import logging
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.models.bloom.bloom_preprocessor import BloomPreprocessor
-from keras_hub.src.utils.keras_utils import (
-    convert_inputs_to_list_of_tensor_segments,
-)
-from keras_hub.src.utils.tensor_utils import strip_to_ragged
+from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone
+from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer
+from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
 
 
 @keras_hub_export("keras_hub.models.BloomCausalLMPreprocessor")
-class BloomCausalLMPreprocessor(BloomPreprocessor):
+class BloomCausalLMPreprocessor(CausalLMPreprocessor):
     """BLOOM Causal LM preprocessor.
 
     This preprocessing layer is meant for use with
@@ -91,86 +87,5 @@ class BloomCausalLMPreprocessor(BloomPreprocessor):
     ```
     """
 
-    def call(
-        self,
-        x,
-        y=None,
-        sample_weight=None,
-        sequence_length=None,
-    ):
-        if y is not None or sample_weight is not None:
-            logging.warning(
-                "`BloomCausalLMPreprocessor` generates `y` and `sample_weight` "
-                "based on your input data, but your data already contains `y` "
-                "or `sample_weight`. Your `y` and `sample_weight` will be "
-                "ignored."
-            )
-        sequence_length = sequence_length or self.sequence_length
-
-        x = convert_inputs_to_list_of_tensor_segments(x)[0]
-        x = self.tokenizer(x)
-        # Pad with one extra token to account for the truncation below.
-        token_ids, padding_mask = self.packer(
-            x,
-            sequence_length=sequence_length + 1,
-            add_start_value=self.add_start_token,
-            add_end_value=self.add_end_token,
-        )
-        # The last token does not have a next token, so we truncate it out.
-        x = {
-            "token_ids": token_ids[..., :-1],
-            "padding_mask": padding_mask[..., :-1],
-        }
-        # Target `y` will be the next token.
-        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
-        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
-
-    def generate_preprocess(
-        self,
-        x,
-        sequence_length=None,
-    ):
-        """Convert strings to integer token input for generation.
-
-        Similar to calling the layer for training, this method takes in strings
-        or tensor strings, tokenizes and packs the input, and computes a padding
-        mask masking all inputs not filled in with a padded value.
-
-        Unlike calling the layer for training, this method does not compute
-        labels and will never append a `tokenizer.end_token_id` to the end of
-        the sequence (as generation is expected to continue at the end of the
-        inputted prompt).
-        """
-        if not self.built:
-            self.build(None)
-
-        x = convert_inputs_to_list_of_tensor_segments(x)[0]
-        x = self.tokenizer(x)
-        token_ids, padding_mask = self.packer(
-            x, sequence_length=sequence_length, add_end_value=False
-        )
-        return {
-            "token_ids": token_ids,
-            "padding_mask": padding_mask,
-        }
-
-    def generate_postprocess(
-        self,
-        x,
-    ):
-        """Convert integer token output to strings for generation.
-
-        This method reverses `generate_preprocess()`, by first removing all
-        padding and start/end tokens, and then converting the integer sequence
-        back to a string.
-        """
-        if not self.built:
-            self.build(None)
-
-        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
-        ids_to_strip = (
-            self.tokenizer.start_token_id,
-            self.tokenizer.end_token_id,
-        )
-        token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
-        return self.tokenizer.detokenize(token_ids)
+    backbone_cls = BloomBackbone
+    tokenizer_cls = BloomTokenizer
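After this change the subclass carries no logic of its own: `call()`, `generate_preprocess()`, and `generate_postprocess()` are all inherited from the shared `CausalLMPreprocessor` base added later in this diff, and the two class attributes pin the model family. A sketch of the generation round trip that previously lived in this file (preset name is illustrative):

```python
import keras_hub

preprocessor = keras_hub.models.BloomCausalLMPreprocessor.from_preset(
    "bloom_560m_multi"  # Illustrative preset name.
)

# Inherited: tokenize and pad a prompt batch into token ids and a mask...
batch = preprocessor.generate_preprocess(["What is Keras?"])
# ...and strip special tokens, then detokenize ids back to strings.
strings = preprocessor.generate_postprocess(batch)
```
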
keras_hub/src/models/bloom/bloom_tokenizer.py

@@ -14,10 +14,16 @@
 
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone
 from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
 
 
-@keras_hub_export("keras_hub.models.BloomTokenizer")
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.BloomTokenizer",
+        "keras_hub.models.BloomTokenizer",
+    ]
+)
 class BloomTokenizer(BytePairTokenizer):
     """A BLOOM tokenizer using Byte-Pair Encoding subword segmentation.
 
@@ -27,8 +33,6 @@ class BloomTokenizer(BytePairTokenizer):
     models and provides a `from_preset()` method to automatically download
     a matching vocabulary for a BLOOM preset.
 
-    This tokenizer does not provide truncation or padding of inputs.
-
     If input is a batch of strings (rank > 0), the layer will output a
     `tf.RaggedTensor` where the last dimension of the output is ragged.
 
@@ -65,52 +69,19 @@ class BloomTokenizer(BytePairTokenizer):
     ```
     """
 
+    backbone_cls = BloomBackbone
+
     def __init__(
         self,
         vocabulary=None,
         merges=None,
         **kwargs,
     ):
-        self.start_token = "<s>"
-        self.end_token = "</s>"
-        self.pad_token = "<pad>"
-
+        self._add_special_token("<s>", "start_token")
+        self._add_special_token("</s>", "end_token")
+        self._add_special_token("<pad>", "pad_token")
         super().__init__(
             vocabulary=vocabulary,
             merges=merges,
-            unsplittable_tokens=[
-                self.start_token,
-                self.end_token,
-                self.pad_token,
-            ],
             **kwargs,
         )
-
-    def set_vocabulary_and_merges(self, vocabulary, merges):
-        super().set_vocabulary_and_merges(vocabulary, merges)
-
-        if vocabulary is not None:
-            # Check for necessary special tokens.
-            for token in [self.start_token, self.end_token, self.pad_token]:
-                if token not in self.get_vocabulary():
-                    raise ValueError(
-                        f"Cannot find token `'{token}'` in the provided "
-                        f"`vocabulary`. Please provide `'{token}'` in "
-                        "your `vocabulary` or use a pretrained `vocabulary` name."
-                    )
-
-            self.start_token_id = self.token_to_id(self.start_token)
-            self.end_token_id = self.token_to_id(self.end_token)
-            self.pad_token_id = self.token_to_id(self.pad_token)
-        else:
-            self.start_token_id = None
-            self.end_token_id = None
-            self.pad_token_id = None
-
-    def get_config(self):
-        config = super().get_config()
-        # In the constructor, we pass the list of special tokens to the
-        # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
-        # delete it from the config here.
-        del config["unsplittable_tokens"]
-        return config
keras_hub/src/models/causal_lm.py

@@ -22,7 +22,6 @@ from keras import tree
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.task import Task
 from keras_hub.src.samplers.serialization import get as get_sampler
-from keras_hub.src.utils.tensor_utils import tensor_to_list
 
 try:
     import tensorflow as tf
@@ -73,8 +72,6 @@ class CausalLM(Task):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # Default compilation.
-        self.compile()
 
     def compile(
         self,
@@ -234,21 +231,18 @@ class CausalLM(Task):
        necessary, and returns a iterable "dataset like" object (either an
        actual `tf.data.Dataset` or a list with a single batch element).
        """
-        input_is_scalar = False
+        if tf and isinstance(inputs, tf.data.Dataset):
+            return inputs.as_numpy_iterator(), False
 
-        if isinstance(inputs, tf.data.Dataset):
-            return inputs, input_is_scalar
+        if self.preprocessor is None:
+            return [inputs], False
 
        def normalize(x):
-            x_is_scalar = False
-            if isinstance(x, str) or isinstance(x, list):
-                x = tf.convert_to_tensor(x)
-
-            if isinstance(x, tf.Tensor) and x.shape.rank == 0:
-                x_is_scalar = True
-                x = x[tf.newaxis]
-
-            return x, x_is_scalar
+            if isinstance(x, str):
+                return [x], True
+            if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0:
+                return x[tf.newaxis], True
+            return x, False
 
        if isinstance(inputs, dict):
            for key in inputs:
@@ -256,8 +250,6 @@ class CausalLM(Task):
        else:
            inputs, input_is_scalar = normalize(inputs)
 
-        # We avoid converting to a dataset purely for speed, for a single batch
-        # of input, creating a dataset would add significant overhead.
        return [inputs], input_is_scalar
 
    def _normalize_generate_outputs(
@@ -280,10 +272,6 @@ class CausalLM(Task):
                for e in batch:
                    outputs.append(e)
            return outputs[0] if input_is_scalar else outputs
-        if isinstance(x[0], tf.Tensor) and x[0].dtype == tf.string:
-            outputs = tf.concat(x, axis=0)
-            outputs = tf.squeeze(outputs, 0) if input_is_scalar else outputs
-            return tensor_to_list(outputs)
        outputs = ops.concatenate(x, axis=0)
        outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs
        return ops.convert_to_numpy(outputs)
@@ -368,15 +356,8 @@ class CausalLM(Task):
        inputs, input_is_scalar = self._normalize_generate_inputs(inputs)
 
        if self.preprocessor is not None:
-            if isinstance(inputs, tf.data.Dataset):
-                inputs = inputs.map(preprocess, tf.data.AUTOTUNE)
-                inputs = inputs.prefetch(tf.data.AUTOTUNE)
-            else:
-                # Fast path for non-dataset, single-batch input.
-                inputs = [preprocess(x) for x in inputs]
-
+            inputs = [preprocess(x) for x in inputs]
        outputs = [generate(x) for x in inputs]
-
        if self.preprocessor is not None:
            outputs = [postprocess(x) for x in outputs]
 
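The net effect of the rewritten `_normalize_generate_inputs`: a `tf.data.Dataset` is drained via `as_numpy_iterator()`, a bare Python string is wrapped into a one-element batch (and unwrapped again on output), and lists of strings are no longer converted to tensors but passed through as a single batch. A usage sketch of the resulting `generate()` behavior (preset name is illustrative):

```python
import keras_hub

causal_lm = keras_hub.models.CausalLM.from_preset("gpt2_base_en")

# Scalar in, scalar out: a single string prompt returns a single string.
text = causal_lm.generate("What is Keras?")

# A list of prompts is treated as one batch and returns a list of strings.
texts = causal_lm.generate(["What is Keras?", "What is TensorFlow?"])
```
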
keras_hub/src/models/causal_lm_preprocessor.py (new file)

@@ -0,0 +1,195 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+from keras_hub.src.utils.tensor_utils import strip_to_ragged
+
+
+@keras_hub_export("keras_hub.models.CausalLMPreprocessor")
+class CausalLMPreprocessor(Preprocessor):
+    """Base class for causal language modeling preprocessing layers.
+
+    `CausalLMPreprocessor` tasks wrap a `keras_hub.tokenizers.Tokenizer` to
+    create a preprocessing layer for causal language modeling tasks. It is
+    intended to be paired with a `keras_hub.models.CausalLM` task.
+
+    All `CausalLMPreprocessor` layers take a single input. This can be a
+    single string or a batch of strings. See examples below. These inputs
+    will be tokenized and padded/truncated to a fixed sequence length.
+
+    This layer will always output a `(x, y, sample_weight)` tuple, where `x`
+    is a dictionary with the tokenized inputs, `y` contains the tokens from `x`
+    offset by 1, and `sample_weight` marks where `y` contains padded
+    values. The exact contents of `x` will vary depending on the model being
+    used.
+
+    A `CausalLMPreprocessor` contains two extra methods, `generate_preprocess`
+    and `generate_postprocess`, for use with generation. See examples below.
+
+    All `CausalLMPreprocessor` tasks include a `from_preset()` constructor
+    which can be used to load a pre-trained config and vocabularies. You can
+    call the `from_preset()` constructor directly on this base class, in which
+    case the correct class for your model will be automatically instantiated.
+
+    Examples:
+    ```python
+    preprocessor = keras_hub.models.CausalLMPreprocessor.from_preset(
+        "gpt2_base_en",
+        sequence_length=256,  # Optional.
+    )
+
+    # Tokenize, mask and pack a single sentence.
+    x = "The quick brown fox jumped."
+    x, y, sample_weight = preprocessor(x)
+
+    # Tokenize and pad/truncate a batch of sentences.
+    x = ["The quick brown fox jumped.", "Call me Ishmael."]
+    x, y, sample_weight = preprocessor(x)
+
+    # With a `tf.data.Dataset`.
+    ds = tf.data.Dataset.from_tensor_slices(x)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Generate preprocess and postprocess.
+    x = preprocessor.generate_preprocess(x)  # Tokenized numeric inputs.
+    x = preprocessor.generate_postprocess(x)  # Detokenized string outputs.
+    ```
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length=1024,
+        add_start_token=True,
+        add_end_token=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.tokenizer = tokenizer
+        self.packer = None
+        self.sequence_length = sequence_length
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
+        self.packer = StartEndPacker(
+            start_value=self.tokenizer.start_token_id,
+            end_value=self.tokenizer.end_token_id,
+            pad_value=self.tokenizer.pad_token_id,
+            sequence_length=self.sequence_length,
+            return_padding_mask=True,
+        )
+        self.built = True
+
+    @preprocessing_function
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        sequence_length=None,
+    ):
+        sequence_length = sequence_length or self.sequence_length
+        x = self.tokenizer(x)
+        # Pad with one extra token to account for the truncation below.
+        token_ids, padding_mask = self.packer(
+            x,
+            sequence_length=sequence_length + 1,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
+        )
+        # The last token does not have a next token, so we truncate it out.
+        x = {
+            "token_ids": token_ids[..., :-1],
+            "padding_mask": padding_mask[..., :-1],
+        }
+        # Target `y` will be the next token.
+        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+    @preprocessing_function
+    def generate_preprocess(
+        self,
+        x,
+        sequence_length=None,
+    ):
+        """Convert strings to integer token input for generation.
+
+        Similar to calling the layer for training, this method takes in strings
+        or tensor strings, tokenizes and packs the input, and computes a padding
+        mask masking all inputs not filled in with a padded value.
+
+        Unlike calling the layer for training, this method does not compute
+        labels and will never append a `tokenizer.end_token_id` to the end of
+        the sequence (as generation is expected to continue at the end of the
+        inputted prompt).
+        """
+        if not self.built:
+            self.build(None)
+
+        x = self.tokenizer(x)
+        token_ids, padding_mask = self.packer(
+            x, sequence_length=sequence_length, add_end_value=False
+        )
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    @preprocessing_function
+    def generate_postprocess(
+        self,
+        x,
+    ):
+        """Convert integer token output to strings for generation.
+
+        This method reverses `generate_preprocess()`, by first removing all
+        padding and start/end tokens, and then converting the integer sequence
+        back to a string.
+        """
+        if not self.built:
+            self.build(None)
+
+        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+        ids_to_strip = self.tokenizer.special_token_ids
+        token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
+        return self.tokenizer.detokenize(token_ids)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "sequence_length": self.sequence_length,
+                "add_start_token": self.add_start_token,
+                "add_end_token": self.add_end_token,
+            }
+        )
+        return config
+
+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self._sequence_length
+
+    @sequence_length.setter
+    def sequence_length(self, value):
+        self._sequence_length = value
+        if self.packer is not None:
+            self.packer.sequence_length = value
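Taken together with the BLOOM changes above, the migration pattern for each model family in this release is the same: delete the per-model preprocessor logic and declare the backbone/tokenizer pairing on a subclass of this base. Since labels are derived inside `call()`, the layer's output is easy to check (preset name is illustrative):

```python
import keras_hub

preprocessor = keras_hub.models.CausalLMPreprocessor.from_preset("gpt2_base_en")

x, y, sample_weight = preprocessor("The quick brown fox jumped.")
# `x["token_ids"]` holds the packed sequence minus its final position; `y` is
# the same sequence shifted left by one, so each label is the next token, and
# `sample_weight` is the correspondingly shifted padding mask.
```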