keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,6 @@ from keras_hub.src.layers.modeling.reversible_embedding import (
15
15
  )
16
16
  from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
17
17
  from keras_hub.src.tokenizers.tokenizer import Tokenizer
18
- from keras_hub.src.utils.keras_utils import has_quantization_support
19
18
  from keras_hub.src.utils.tensor_utils import is_float_dtype
20
19
 
21
20
 
@@ -313,6 +312,14 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
313
312
 
314
313
  for policy in ["mixed_float16", "mixed_bfloat16", "bfloat16"]:
315
314
  policy = keras.mixed_precision.Policy(policy)
315
+ # Ensure the correct `dtype` is set for sublayers or submodels in
316
+ # `init_kwargs`.
317
+ original_init_kwargs = init_kwargs.copy()
318
+ for k, v in init_kwargs.items():
319
+ if isinstance(v, keras.Layer):
320
+ config = v.get_config()
321
+ config["dtype"] = policy
322
+ init_kwargs[k] = v.__class__.from_config(config)
316
323
  layer = cls(**{**init_kwargs, "dtype": policy})
317
324
  if isinstance(layer, keras.Model):
318
325
  output_data = layer(input_data)
@@ -343,8 +350,15 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
343
350
  continue
344
351
  self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
345
352
  self.assertEqual(policy.variable_dtype, sublayer.variable_dtype)
353
+ # Restore `init_kwargs`.
354
+ init_kwargs = original_init_kwargs
346
355
 
347
356
  def run_quantization_test(self, instance, cls, init_kwargs, input_data):
357
+ # TODO: revert the following if. This works around a torch
358
+ # quantization failure in `MultiHeadAttention` with Keras 3.7.
359
+ if keras.config.backend() == "torch":
360
+ return
361
+
348
362
  def _get_supported_layers(mode):
349
363
  supported_layers = [keras.layers.Dense, keras.layers.EinsumDense]
350
364
  if mode == "int8":
@@ -361,6 +375,14 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
361
375
  policy_map[layer.path] = keras.dtype_policies.get(
362
376
  f"{mode}_from_float32"
363
377
  )
378
+ # Ensure the correct `dtype` is set for sublayers or submodels in
379
+ # `init_kwargs`.
380
+ original_init_kwargs = init_kwargs.copy()
381
+ for k, v in init_kwargs.items():
382
+ if isinstance(v, keras.Layer):
383
+ config = v.get_config()
384
+ config["dtype"] = policy_map
385
+ init_kwargs[k] = v.__class__.from_config(config)
364
386
  # Instantiate the layer.
365
387
  model = cls(**{**init_kwargs, "dtype": policy_map})
366
388
  # Call layer eagerly.
@@ -382,12 +404,16 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
382
404
  # Check weights loading.
383
405
  weights = model.get_weights()
384
406
  revived_model.set_weights(weights)
407
+ # Restore `init_kwargs`.
408
+ init_kwargs = original_init_kwargs
385
409
 
386
410
  def run_model_saving_test(
387
411
  self,
388
412
  cls,
389
413
  init_kwargs,
390
414
  input_data,
415
+ atol=0.000001,
416
+ rtol=0.000001,
391
417
  ):
392
418
  """Save and load a model from disk and assert output is unchanged."""
393
419
  model = cls(**init_kwargs)
@@ -401,7 +427,7 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
401
427
 
402
428
  # Check that output matches.
403
429
  restored_output = restored_model(input_data)
404
- self.assertAllClose(model_output, restored_output)
430
+ self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol)
405
431
 
406
432
  def run_backbone_test(
407
433
  self,
@@ -431,8 +457,8 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
431
457
 
432
458
  # Check variable length sequences.
433
459
  if variable_length_data is None:
434
- # If no variable length data passed, assume the second axis of all
435
- # inputs is our sequence axis and create it ourselves.
460
+ # If no variable length data passed, assume the second axis of
461
+ # all inputs is our sequence axis and create it ourselves.
436
462
  variable_length_data = [
437
463
  tree.map_structure(
438
464
  lambda x: x[:, :seq_length, ...], input_data
@@ -453,14 +479,14 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
453
479
  # Check name maps to classname.
454
480
  name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", cls.__name__)
455
481
  name = re.sub("([a-z])([A-Z])", r"\1_\2", name).lower()
456
- self.assertRegexpMatches(backbone.name, name)
482
+ self.assertRegex(backbone.name, name)
457
483
 
458
484
  # Check mixed precision.
459
485
  if run_mixed_precision_check:
460
486
  self.run_precision_test(cls, init_kwargs, input_data)
461
487
 
462
488
  # Check quantization.
463
- if run_quantization_check and has_quantization_support():
489
+ if run_quantization_check:
464
490
  self.run_quantization_test(backbone, cls, init_kwargs, input_data)
465
491
 
466
492
  def run_vision_backbone_test(
@@ -567,6 +593,15 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
567
593
  ds = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size)
568
594
  x, y, sw = keras.utils.unpack_x_y_sample_weight(train_data)
569
595
 
596
+ # Test: the tree struct output by the
597
+ # preprocessor must match what model expects.
598
+ preprocessed_data = preprocessor(*train_data)[0]
599
+ tree.assert_same_structure(
600
+ preprocessed_data,
601
+ task._inputs_struct,
602
+ check_types=False,
603
+ )
604
+
570
605
  # Test predict.
571
606
  output = task.predict(x)
572
607
  if expected_output_shape is not None:
@@ -43,7 +43,11 @@ SPLIT_PATTERN_1 = (
43
43
  SPLIT_PATTERN_1 = SPLIT_PATTERN_1.replace(
44
44
  "{special_spaces}", SPECIAL_WHITESPACES
45
45
  )
46
- SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
46
+
47
+ # The pattern " \t\r\f\v" is the same as \s "all spaces" but without the \n.
48
+ # Multiple \n\n\n in sequence must not be split for Llama3.
49
+ # SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
50
+ SPLIT_PATTERN_2 = rf"""[ \t\r\f\v६{SPECIAL_WHITESPACES}]$"""
47
51
 
48
52
 
49
53
  def create_alts_for_unsplittable_tokens(unsplittable_tokens):
@@ -196,8 +200,8 @@ class BytePairTokenizer(tokenizer.Tokenizer):
196
200
  """Bype-pair encoding tokenizer layer.
197
201
 
198
202
  This BPE tokenizer provides the same functionality as the official GPT-2
199
- tokenizer. Given the same `vocabulary` which maps tokens to ids, and `merges`
200
- which describes BPE merge rules, it should provide the same output
203
+ tokenizer. Given the same `vocabulary` which maps tokens to ids, and
204
+ `merges` which describes BPE merge rules, it should provide the same output
201
205
  as OpenAI implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
202
206
  Different from OpenAI, this implementation is graph-compatible, so you can
203
207
  use it within a `tf.data` pipeline.
@@ -1,13 +1,5 @@
1
1
  import numpy as np
2
2
 
3
- try:
4
- import tensorflow as tf
5
- except ImportError:
6
- raise ImportError(
7
- "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
8
- "The TensorFlow package is required for data preprocessing with any backend."
9
- )
10
-
11
3
  from keras_hub.src.api_export import keras_hub_export
12
4
  from keras_hub.src.tokenizers import tokenizer
13
5
  from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
@@ -15,8 +7,10 @@ from keras_hub.src.utils.tensor_utils import is_int_dtype
15
7
  from keras_hub.src.utils.tensor_utils import preprocessing_function
16
8
 
17
9
  try:
10
+ import tensorflow as tf
18
11
  import tensorflow_text as tf_text
19
12
  except ImportError:
13
+ tf = None
20
14
  tf_text = None
21
15
 
22
16
 
@@ -156,8 +150,7 @@ class ByteTokenizer(tokenizer.Tokenizer):
156
150
  ):
157
151
  if not is_int_dtype(dtype):
158
152
  raise ValueError(
159
- "Output dtype must be an integer type. "
160
- f"Received: dtype={dtype}"
153
+ f"Output dtype must be an integer type. Received: dtype={dtype}"
161
154
  )
162
155
 
163
156
  # Check normalization_form.
@@ -4,14 +4,6 @@ import os
4
4
 
5
5
  import keras
6
6
 
7
- try:
8
- import tensorflow as tf
9
- except ImportError:
10
- raise ImportError(
11
- "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
12
- "The TensorFlow package is required for data preprocessing with any backend."
13
- )
14
-
15
7
  from keras_hub.src.api_export import keras_hub_export
16
8
  from keras_hub.src.tokenizers import tokenizer
17
9
  from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
@@ -21,11 +13,12 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
21
13
  from keras_hub.src.utils.tensor_utils import tensor_to_list
22
14
 
23
15
  try:
16
+ import tensorflow as tf
24
17
  import tensorflow_text as tf_text
25
18
  except ImportError:
19
+ tf = None
26
20
  tf_text = None
27
21
 
28
-
29
22
  VOCAB_FILENAME = "vocabulary.spm"
30
23
 
31
24
 
@@ -1,17 +1,13 @@
1
1
  import io
2
2
 
3
- try:
4
- import tensorflow as tf
5
- except ImportError:
6
- raise ImportError(
7
- "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
8
- "The TensorFlow package is required for data preprocessing with any backend."
9
- )
3
+ from keras_hub.src.utils.tensor_utils import assert_tf_libs_installed
10
4
 
11
5
  try:
12
6
  import sentencepiece as spm
7
+ import tensorflow as tf
13
8
  except ImportError:
14
9
  spm = None
10
+ tf = None
15
11
 
16
12
  from keras_hub.src.api_export import keras_hub_export
17
13
 
@@ -52,7 +48,8 @@ def compute_sentence_piece_proto(
52
48
 
53
49
  Basic Usage (from Dataset).
54
50
  >>> inputs = tf.data.Dataset.from_tensor_slices(["Drifting Along"])
55
- >>> proto = keras_hub.tokenizers.compute_sentence_piece_proto(inputs, vocabulary_size=15)
51
+ >>> proto = keras_hub.tokenizers.compute_sentence_piece_proto(
52
+ ... inputs, vocabulary_size=15)
56
53
  >>> tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(proto=proto)
57
54
  >>> outputs = inputs.map(tokenizer)
58
55
  >>> for output in outputs:
@@ -82,6 +79,7 @@ def compute_sentence_piece_proto(
82
79
  tf.Tensor([ 4 8 12 5 9 14 5 6 13 4 7 10 11 6 13],
83
80
  shape=(15,), dtype=int32)
84
81
  """
82
+ assert_tf_libs_installed("compute_sentence_piece_proto")
85
83
 
86
84
  if spm is None:
87
85
  raise ImportError(
@@ -92,7 +90,8 @@ def compute_sentence_piece_proto(
92
90
 
93
91
  if not isinstance(data, (list, tuple, tf.data.Dataset)):
94
92
  raise ValueError(
95
- "The `data` argument must be either `tf.data.Dataset` or `tuple` or `list`. "
93
+ "The `data` argument must be either `tf.data.Dataset` or "
94
+ "`tuple` or `list`. "
96
95
  f"Received: type(data)={type(data)}."
97
96
  )
98
97
 
@@ -105,8 +104,7 @@ def compute_sentence_piece_proto(
105
104
  model_writer = (
106
105
  open(proto_output_file, "wb") if proto_output_file else io.BytesIO()
107
106
  )
108
- is_dataset = isinstance(data, tf.data.Dataset)
109
- if is_dataset:
107
+ if tf is not None and isinstance(data, tf.data.Dataset):
110
108
  spm.SentencePieceTrainer.train(
111
109
  sentence_iterator=data.as_numpy_iterator(),
112
110
  model_writer=model_writer,
@@ -10,7 +10,7 @@ from keras_hub.src.utils.preset_utils import builtin_presets
10
10
  from keras_hub.src.utils.preset_utils import find_subclass
11
11
  from keras_hub.src.utils.preset_utils import get_file
12
12
  from keras_hub.src.utils.preset_utils import get_preset_loader
13
- from keras_hub.src.utils.preset_utils import save_serialized_object
13
+ from keras_hub.src.utils.preset_utils import get_preset_saver
14
14
  from keras_hub.src.utils.python_utils import classproperty
15
15
  from keras_hub.src.utils.tensor_utils import preprocessing_function
16
16
 
@@ -66,7 +66,7 @@ class Tokenizer(PreprocessingLayer):
66
66
  backbone_cls = None
67
67
 
68
68
  def __init__(self, *args, **kwargs):
69
- self.config_name = kwargs.pop("config_name", TOKENIZER_CONFIG_FILE)
69
+ self.config_file = kwargs.pop("config_file", TOKENIZER_CONFIG_FILE)
70
70
  super().__init__(*args, **kwargs)
71
71
  self.file_assets = None
72
72
 
@@ -178,7 +178,7 @@ class Tokenizer(PreprocessingLayer):
178
178
  config = super().get_config()
179
179
  config.update(
180
180
  {
181
- "config_name": self.config_name,
181
+ "config_file": self.config_file,
182
182
  }
183
183
  )
184
184
  return config
@@ -189,11 +189,8 @@ class Tokenizer(PreprocessingLayer):
189
189
  Args:
190
190
  preset_dir: The path to the local model preset directory.
191
191
  """
192
- save_serialized_object(self, preset_dir, config_file=self.config_name)
193
- subdir = self.config_name.split(".")[0]
194
- asset_dir = os.path.join(preset_dir, ASSET_DIR, subdir)
195
- os.makedirs(asset_dir, exist_ok=True)
196
- self.save_assets(asset_dir)
192
+ saver = get_preset_saver(preset_dir)
193
+ saver.save_tokenizer(self)
197
194
 
198
195
  @preprocessing_function
199
196
  def call(self, inputs, *args, training=None, **kwargs):
@@ -202,11 +199,11 @@ class Tokenizer(PreprocessingLayer):
202
199
  def load_preset_assets(self, preset):
203
200
  asset_path = None
204
201
  for asset in self.file_assets:
205
- subdir = self.config_name.split(".")[0]
202
+ subdir = self.config_file.split(".")[0]
206
203
  preset_path = os.path.join(ASSET_DIR, subdir, asset)
207
204
  asset_path = get_file(preset, preset_path)
208
- tokenizer_config_name = os.path.dirname(asset_path)
209
- self.load_assets(tokenizer_config_name)
205
+ tokenizer_config_file = os.path.dirname(asset_path)
206
+ self.load_assets(tokenizer_config_file)
210
207
 
211
208
  @classproperty
212
209
  def presets(cls):
@@ -217,7 +214,7 @@ class Tokenizer(PreprocessingLayer):
217
214
  def from_preset(
218
215
  cls,
219
216
  preset,
220
- config_name=TOKENIZER_CONFIG_FILE,
217
+ config_file=TOKENIZER_CONFIG_FILE,
221
218
  **kwargs,
222
219
  ):
223
220
  """Instantiate a `keras_hub.models.Tokenizer` from a model preset.
@@ -263,4 +260,4 @@ class Tokenizer(PreprocessingLayer):
263
260
  backbone_cls = loader.check_backbone_class()
264
261
  if cls.backbone_cls != backbone_cls:
265
262
  cls = find_subclass(preset, cls, backbone_cls)
266
- return loader.load_tokenizer(cls, config_name, **kwargs)
263
+ return loader.load_tokenizer(cls, config_file, **kwargs)
@@ -203,8 +203,7 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer):
203
203
  ) -> None:
204
204
  if not is_int_dtype(dtype):
205
205
  raise ValueError(
206
- "Output dtype must be an integer type. "
207
- f"Received: dtype={dtype}"
206
+ f"Output dtype must be an integer type. Received: dtype={dtype}"
208
207
  )
209
208
 
210
209
  # Check normalization_form.
@@ -226,8 +225,9 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer):
226
225
  if normalization_form:
227
226
  if input_encoding != "UTF-8":
228
227
  raise ValueError(
229
- """Normalization Forms are Only Supported for Input Encoding
230
- UTF-8"""
228
+ "Normalization Forms are Only Supported for Input "
229
+ "Encoding UTF-8"
230
+ ""
231
231
  )
232
232
 
233
233
  super().__init__(dtype=dtype, **kwargs)
@@ -259,8 +259,9 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer):
259
259
  return config
260
260
 
261
261
  def vocabulary_size(self):
262
- """Get the size of the tokenizer vocabulary. None implies no vocabulary
263
- size was provided"""
262
+ """Get the size of the tokenizer vocabulary.
263
+
264
+ None implies no vocabulary size was provided"""
264
265
  return self._vocabulary_size
265
266
 
266
267
  def get_vocabulary(self):
@@ -334,6 +335,7 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer):
334
335
  id = ord(token)
335
336
  if id >= self.vocabulary_size():
336
337
  raise ValueError(
337
- f"Token {token} is not supported by `UnicodeCodepointTokenizer`."
338
+ f"Token {token} is not supported by "
339
+ "`UnicodeCodepointTokenizer`."
338
340
  )
339
341
  return id
@@ -1,5 +1,6 @@
1
1
  from keras_hub.src.api_export import keras_hub_export
2
2
  from keras_hub.src.tokenizers.word_piece_tokenizer import pretokenize
3
+ from keras_hub.src.utils.tensor_utils import assert_tf_libs_installed
3
4
 
4
5
  try:
5
6
  import tensorflow as tf
@@ -55,7 +56,8 @@ def compute_word_piece_vocabulary(
55
56
  suffix_indicator: str. The characters prepended to a
56
57
  WordPiece to indicate that it is a suffix to another subword.
57
58
  E.g. `"##ing"`. Defaults to `"##"`.
58
- reserved_tokens: list of strings. A list of tokens that must be included in the vocabulary.
59
+ reserved_tokens: list of strings. A list of tokens that must be included
60
+ in the vocabulary.
59
61
 
60
62
  Returns:
61
63
  Returns a list of vocabulary terms.
@@ -67,7 +69,10 @@ def compute_word_piece_vocabulary(
67
69
  >>> vocab = compute_word_piece_vocabulary(inputs, 13)
68
70
  >>> vocab
69
71
  ['[PAD]', '[CLS]', '[SEP]', '[UNK]', '[MASK]', 'a', 'b', 'm', 'p', 'r', 's', 't', '##at']
70
- >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(vocabulary=vocab, oov_token="[UNK]")
72
+ >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
73
+ ... vocabulary=vocab,
74
+ ... oov_token="[UNK]",
75
+ ... )
71
76
  >>> outputs = inputs.map(tokenizer.tokenize)
72
77
  >>> for x in outputs:
73
78
  ... print(x)
@@ -112,7 +117,9 @@ def compute_word_piece_vocabulary(
112
117
  tokenizer = keras_hub.tokenizers.WordPieceTokenizer(vocabulary=vocab)
113
118
  inputs.map(tokenizer.tokenize)
114
119
  ```
115
- """
120
+ """ # noqa: E501
121
+ assert_tf_libs_installed("compute_word_piece_vocabulary")
122
+
116
123
  # Read data files.
117
124
  if not isinstance(data, (list, tf.data.Dataset)):
118
125
  raise ValueError(
@@ -2,7 +2,6 @@ import sys
2
2
 
3
3
  import keras
4
4
  from absl import logging
5
- from packaging.version import parse
6
5
 
7
6
  try:
8
7
  import tensorflow as tf
@@ -36,23 +35,13 @@ def print_msg(message, line_break=True):
36
35
  logging.info(message)
37
36
 
38
37
 
38
+ # Register twice for backwards compat.
39
39
  @keras.saving.register_keras_serializable(package="keras_hub")
40
+ @keras.saving.register_keras_serializable(package="keras_nlp")
40
41
  def gelu_approximate(x):
41
42
  return keras.activations.gelu(x, approximate=True)
42
43
 
43
44
 
44
- def has_quantization_support():
45
- return False if parse(keras.version()) < parse("3.4.0") else True
46
-
47
-
48
- def assert_quantization_support():
49
- if not has_quantization_support():
50
- raise ValueError(
51
- "Quantization API requires Keras >= 3.4.0 to function "
52
- f"correctly. Received: '{keras.version()}'"
53
- )
54
-
55
-
56
45
  def standardize_data_format(data_format):
57
46
  if data_format is None:
58
47
  return keras.config.image_data_format()
@@ -232,7 +232,7 @@ class PipelineModel(keras.Model):
232
232
  ):
233
233
  data = self.preprocess_samples(x, y, sample_weight)
234
234
  x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
235
- x = ops.convert_to_tensor(x)
235
+ x = tree.map_structure(ops.convert_to_tensor, x)
236
236
  if y is not None:
237
237
  y = ops.convert_to_tensor(y)
238
238
  if sample_weight is not None:
@@ -253,7 +253,7 @@ class PipelineModel(keras.Model):
253
253
  ):
254
254
  data = self.preprocess_samples(x, y, sample_weight)
255
255
  x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
256
- x = ops.convert_to_tensor(x)
256
+ x = tree.map_structure(ops.convert_to_tensor, x)
257
257
  if y is not None:
258
258
  y = ops.convert_to_tensor(y)
259
259
  if sample_weight is not None:
@@ -272,7 +272,7 @@ class PipelineModel(keras.Model):
272
272
  ):
273
273
  data = self.preprocess_samples(x)
274
274
  x, _, _ = keras.utils.unpack_x_y_sample_weight(data)
275
- x = ops.convert_to_tensor(x)
275
+ x = tree.map_structure(ops.convert_to_tensor, x)
276
276
  return super().predict_on_batch(
277
277
  x=x,
278
278
  **kwargs,