keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/api/layers/__init__.py
@@ -14,6 +14,7 @@ from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_hub.src.layers.modeling.reversible_embedding import (
     ReversibleEmbedding,
 )
+from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.layers.modeling.sine_position_encoding import (
     SinePositionEncoding,
@@ -33,22 +34,39 @@ from keras_hub.src.layers.preprocessing.multi_segment_packer import (
 )
 from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
 from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
-from keras_hub.src.layers.preprocessing.resizing_image_converter import (
-    ResizingImageConverter,
-)
 from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.basnet.basnet_image_converter import (
+    BASNetImageConverter,
+)
+from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
+    DeepLabV3ImageConverter,
+)
 from keras_hub.src.models.densenet.densenet_image_converter import (
     DenseNetImageConverter,
 )
+from keras_hub.src.models.efficientnet.efficientnet_image_converter import (
+    EfficientNetImageConverter,
+)
+from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter
 from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
     PaliGemmaImageConverter,
 )
 from keras_hub.src.models.resnet.resnet_image_converter import (
     ResNetImageConverter,
 )
+from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
+from keras_hub.src.models.retinanet.retinanet_image_converter import (
+    RetinaNetImageConverter,
+)
 from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
 from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
 from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
+from keras_hub.src.models.segformer.segformer_image_converter import (
+    SegFormerImageConverter,
+)
+from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
+from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
 from keras_hub.src.models.whisper.whisper_audio_converter import (
     WhisperAudioConverter,
 )
keras_hub/api/models/__init__.py
@@ -29,6 +29,9 @@ from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import (
     BartSeq2SeqLMPreprocessor,
 )
 from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
+from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter
+from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
+from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM
 from keras_hub.src.models.bert.bert_masked_lm_preprocessor import (
@@ -53,8 +56,11 @@ from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import (
 from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer
 from keras_hub.src.models.causal_lm import CausalLM
 from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
 from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor
+from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
 from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
+from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder
 from keras_hub.src.models.csp_darknet.csp_darknet_backbone import (
     CSPDarkNetBackbone,
 )
@@ -85,6 +91,15 @@ from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor imp
 from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import (
     DebertaV3Tokenizer,
 )
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
+    DeepLabV3ImageSegmenterPreprocessor,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
+    DeepLabV3ImageSegmenter,
+)
 from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
 from keras_hub.src.models.densenet.densenet_image_classifier import (
     DenseNetImageClassifier,
@@ -119,6 +134,12 @@ from keras_hub.src.models.distil_bert.distil_bert_tokenizer import (
 from keras_hub.src.models.efficientnet.efficientnet_backbone import (
     EfficientNetBackbone,
 )
+from keras_hub.src.models.efficientnet.efficientnet_image_classifier import (
+    EfficientNetImageClassifier,
+)
+from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import (
+    EfficientNetImageClassifierPreprocessor,
+)
 from keras_hub.src.models.electra.electra_backbone import ElectraBackbone
 from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer
 from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone
@@ -144,6 +165,11 @@ from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import (
 )
 from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
+from keras_hub.src.models.flux.flux_model import FluxBackbone
+from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage
+from keras_hub.src.models.flux.flux_text_to_image_preprocessor import (
+    FluxTextToImagePreprocessor,
+)
 from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
 from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
@@ -167,22 +193,28 @@ from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.models.image_classifier_preprocessor import (
     ImageClassifierPreprocessor,
 )
+from keras_hub.src.models.image_object_detector import ImageObjectDetector
+from keras_hub.src.models.image_object_detector_preprocessor import (
+    ImageObjectDetectorPreprocessor,
+)
 from keras_hub.src.models.image_segmenter import ImageSegmenter
 from keras_hub.src.models.image_segmenter_preprocessor import (
     ImageSegmenterPreprocessor,
 )
-from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
-from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
-from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
-    Llama3CausalLMPreprocessor,
-)
-from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
+from keras_hub.src.models.image_to_image import ImageToImage
+from keras_hub.src.models.inpaint import Inpaint
 from keras_hub.src.models.llama.llama_backbone import LlamaBackbone
 from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM
 from keras_hub.src.models.llama.llama_causal_lm_preprocessor import (
     LlamaCausalLMPreprocessor,
 )
 from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
+from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
+from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
+    Llama3CausalLMPreprocessor,
+)
+from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
 from keras_hub.src.models.masked_lm import MaskedLM
 from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor
 from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone
@@ -191,11 +223,10 @@ from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import (
     MistralCausalLMPreprocessor,
 )
 from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer
-from keras_hub.src.models.mix_transformer.mix_transformer_backbone import (
-    MiTBackbone,
-)
-from keras_hub.src.models.mix_transformer.mix_transformer_classifier import (
-    MiTImageClassifier,
+from keras_hub.src.models.mit.mit_backbone import MiTBackbone
+from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier
+from keras_hub.src.models.mit.mit_image_classifier_preprocessor import (
+    MiTImageClassifierPreprocessor,
 )
 from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
 from keras_hub.src.models.mobilenet.mobilenet_image_classifier import (
@@ -233,6 +264,13 @@ from keras_hub.src.models.resnet.resnet_image_classifier import (
 from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
     ResNetImageClassifierPreprocessor,
 )
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_object_detector import (
+    RetinaNetObjectDetector,
+)
+from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import (
+    RetinaNetObjectDetectorPreprocessor,
+)
 from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
 from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM
 from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import (
@@ -254,13 +292,26 @@ from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone
 from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
 from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
-    SAMImageSegmenterPreprocessor as SamImageSegmenterPreprocessor,
+    SAMImageSegmenterPreprocessor,
+)
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+from keras_hub.src.models.segformer.segformer_image_segmenter import (
+    SegFormerImageSegmenter,
+)
+from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (
+    SegFormerImageSegmenterPreprocessor,
 )
 from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
 from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
     StableDiffusion3Backbone,
 )
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import (
+    StableDiffusion3ImageToImage,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import (
+    StableDiffusion3Inpaint,
+)
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import (
     StableDiffusion3TextToImage,
 )
@@ -279,6 +330,14 @@ from keras_hub.src.models.text_classifier_preprocessor import (
 from keras_hub.src.models.text_to_image import TextToImage
 from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
 from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
+from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
+    VGGImageClassifierPreprocessor,
+)
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier
+from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
+    ViTImageClassifierPreprocessor,
+)
 from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
 from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
 from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
keras_hub/api/tokenizers/__init__.py
@@ -21,8 +21,8 @@ from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
 from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
-from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
 from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
 from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer
 from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer
 from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
keras_hub/src/bounding_box/__init__.py
@@ -0,0 +1,2 @@
+# TODO: Once all bounding boxes are moved to keras repostory remove the
+# bounding box folder.
keras_hub/src/bounding_box/converters.py
@@ -20,29 +20,74 @@ class RequiresImagesException(Exception):
 ALL_AXES = 4
 
 
-def _encode_box_to_deltas(
+def encode_box_to_deltas(
     anchors,
     boxes,
-    anchor_format: str,
-    box_format: str,
+    anchor_format,
+    box_format,
+    encoding_format="center_yxhw",
     variance=None,
     image_shape=None,
 ):
-    """Converts bounding_boxes from `center_yxhw` to delta format."""
+    """Encodes bounding boxes relative to anchors as deltas.
+
+    This function calculates the deltas that represent the difference between
+    bounding boxes and provided anchors. Deltas encode the offsets and scaling
+    factors to apply to anchors to obtain the target boxes.
+
+    Boxes and anchors are first converted to the specified `encoding_format`
+    (defaulting to `center_yxhw`) for consistent delta representation.
+
+    Args:
+        anchors: `Tensors`. Anchor boxes with shape of `(N, 4)` where N is the
+            number of anchors.
+        boxes: `Tensors` Bounding boxes to encode. Boxes can be of shape
+            `(B, N, 4)` or `(N, 4)`.
+        anchor_format: str. The format of the input `anchors`
+            (e.g., "xyxy", "xywh", etc.).
+        box_format: str. The format of the input `boxes`
+            (e.g., "xyxy", "xywh", etc.).
+        encoding_format: str. The intermediate format to which boxes and anchors
+            are converted before delta calculation. Defaults to "center_yxhw".
+        variance: `List[float]`. A 4-element array/tensor representing variance
+            factors to scale the box deltas. If provided, the calculated deltas
+            are divided by the variance. Defaults to None.
+        image_shape: `Tuple[int]`. The shape of the image (height, width, 3).
+            When using relative bounding box format for `box_format` the
+            `image_shape` is used for normalization.
+    Returns:
+        Encoded box deltas. The return type matches the `encode_format`.
+
+    Raises:
+        ValueError: If `variance` is not None and its length is not 4.
+        ValueError: If `encoding_format` is not `"center_xywh"` or
+            `"center_yxhw"`.
+
+    """
     if variance is not None:
         variance = ops.convert_to_tensor(variance, "float32")
         var_len = variance.shape[-1]
 
         if var_len != 4:
             raise ValueError(f"`variance` must be length 4, got {variance}")
+
+    if encoding_format not in ["center_xywh", "center_yxhw"]:
+        raise ValueError(
+            "`encoding_format` should be one of 'center_xywh' or "
+            f"'center_yxhw', got {encoding_format}"
+        )
+
     encoded_anchors = convert_format(
         anchors,
         source=anchor_format,
-        target="center_yxhw",
+        target=encoding_format,
         image_shape=image_shape,
     )
     boxes = convert_format(
-        boxes, source=box_format, target="center_yxhw", image_shape=image_shape
+        boxes,
+        source=box_format,
+        target=encoding_format,
+        image_shape=image_shape,
     )
     anchor_dimensions = ops.maximum(
         encoded_anchors[..., 2:], keras.backend.epsilon()
@@ -61,15 +106,54 @@ def _encode_box_to_deltas(
     return boxes_delta
 
 
-def _decode_deltas_to_boxes(
+def decode_deltas_to_boxes(
     anchors,
     boxes_delta,
-    anchor_format: str,
-    box_format: str,
+    anchor_format,
+    box_format,
+    encoded_format="center_yxhw",
     variance=None,
     image_shape=None,
 ):
-    """Converts bounding_boxes from delta format to `center_yxhw`."""
+    """Converts bounding boxes from delta format to the specified `box_format`.
+
+    This function decodes bounding box deltas relative to anchors to obtain the
+    final bounding box coordinates. The boxes are encoded in a specific
+    `encoded_format` (center_yxhw by default) during the decoding process.
+    This allows flexibility in how the deltas are applied to the anchors.
+
+    Args:
+        anchors: Can be `Tensors` or `Dict[Tensors]` where keys are level
+            indices and values are corresponding anchor boxes.
+            The shape of the array/tensor should be `(N, 4)` where N is the
+            number of anchors.
+        boxes_delta: Can be `Tensors` or `Dict[Tensors]` Bounding box deltas
+            must have the same type and structure as `anchors`. The
+            shape of the array/tensor can be `(N, 4)` or `(B, N, 4)` where N is
+            the number of boxes.
+        anchor_format: str. The format of the input `anchors`.
+            (e.g., `"xyxy"`, `"xywh"`, etc.)
+        box_format: str. The desired format for the output boxes.
+            (e.g., `"xyxy"`, `"xywh"`, etc.)
+        encoded_format: str. Raw output format from regression head. Defaults
+            to `"center_yxhw"`.
+        variance: `List[floats]`. A 4-element array/tensor representing
+            variance factors to scale the box deltas. If provided, the deltas
+            are multiplied by the variance before being applied to the anchors.
+            Defaults to None.
+        image_shape: The shape of the image (height, width). This is needed
+            if normalization to image size is required when converting between
+            formats. Defaults to None.
+
+    Returns:
+        Decoded box coordinates. The return type matches the `box_format`.
+
+    Raises:
+        ValueError: If `variance` is not None and its length is not 4.
+        ValueError: If `encoded_format` is not `"center_xywh"` or
+            `"center_yxhw"`.
+
+    """
     if variance is not None:
         variance = ops.convert_to_tensor(variance, "float32")
         var_len = variance.shape[-1]
@@ -77,11 +161,17 @@ def _decode_deltas_to_boxes(
         if var_len != 4:
             raise ValueError(f"`variance` must be length 4, got {variance}")
 
+    if encoded_format not in ["center_xywh", "center_yxhw"]:
+        raise ValueError(
+            f"`encoded_format` should be 'center_xywh' or 'center_yxhw', "
+            f"but got '{encoded_format}'."
+        )
+
     def decode_single_level(anchor, box_delta):
         encoded_anchor = convert_format(
             anchor,
             source=anchor_format,
-            target="center_yxhw",
+            target=encoded_format,
             image_shape=image_shape,
         )
         if variance is not None:
@@ -97,7 +187,7 @@ def _decode_deltas_to_boxes(
         )
         box = convert_format(
             box,
-            source="center_yxhw",
+            source=encoded_format,
             target=box_format,
             image_shape=image_shape,
         )
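To make the renamed public helpers concrete, here is a minimal round-trip sketch (the values and formats are illustrative assumptions, not taken from the package's own tests):

import keras
import numpy as np

from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes
from keras_hub.src.bounding_box.converters import encode_box_to_deltas

# Two anchors and two target boxes, both given in "xyxy" format.
anchors = np.array([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]], "float32")
boxes = np.array([[1.0, 1.0, 9.0, 9.0], [6.0, 4.0, 14.0, 16.0]], "float32")

# Encode boxes as center offsets and log-scales relative to the anchors.
deltas = encode_box_to_deltas(
    anchors,
    boxes,
    anchor_format="xyxy",
    box_format="xyxy",
    encoding_format="center_yxhw",
)

# Decoding the deltas against the same anchors recovers the original boxes.
decoded = decode_deltas_to_boxes(
    anchors,
    deltas,
    anchor_format="xyxy",
    box_format="xyxy",
    encoded_format="center_yxhw",
)
print(keras.ops.convert_to_numpy(decoded))  # approximately equal to `boxes`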
keras_hub/src/layers/modeling/f_net_encoder.py
@@ -66,7 +66,7 @@ class FNetEncoder(keras.layers.Layer):
         layer_norm_epsilon=1e-5,
         kernel_initializer="glorot_uniform",
         bias_initializer="zeros",
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.intermediate_dim = intermediate_dim
keras_hub/src/layers/modeling/masked_lm_head.py
@@ -34,7 +34,8 @@ class MaskedLMHead(keras.layers.Layer):
         token_embedding: Optional. A `keras_hub.layers.ReversibleEmbedding`
             instance. If passed, the layer will be used to project from the
             `hidden_dim` of the model to the output `vocabulary_size`.
-        intermediate_activation: The activation function of intermediate dense layer.
+        intermediate_activation: The activation function of intermediate dense
+            layer.
         activation: The activation function for the outputs of the layer.
             Usually either `None` (return logits), or `"softmax"`
             (return probabilities).
keras_hub/src/layers/modeling/reversible_embedding.py
@@ -1,9 +1,7 @@
 import keras
 from keras import ops
-from packaging.version import parse
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.utils.keras_utils import assert_quantization_support
 
 
 @keras_hub_export("keras_hub.layers.ReversibleEmbedding")
@@ -145,10 +143,6 @@ class ReversibleEmbedding(keras.layers.Embedding):
         if not self.built:
             return
         super().save_own_variables(store)
-        # Before Keras 3.2, the reverse weight is saved in the super() call.
-        # After Keras 3.2, the reverse weight must be saved manually.
-        if parse(keras.version()) < parse("3.2.0"):
-            return
         target_variables = []
         if not self.tie_weights:
             # Store the reverse embedding weights as the last weights.
@@ -239,9 +233,7 @@ class ReversibleEmbedding(keras.layers.Embedding):
 
     def quantize(self, mode, type_check=True):
         import gc
-        import inspect
 
-        assert_quantization_support()
         if type_check and type(self) is not ReversibleEmbedding:
             raise NotImplementedError(
                 f"Layer {self.__class__.__name__} does not have a `quantize()` "
@@ -250,14 +242,9 @@ class ReversibleEmbedding(keras.layers.Embedding):
         self._check_quantize_args(mode, self.compute_dtype)
 
         def abs_max_quantize(inputs, axis):
-            sig = inspect.signature(keras.quantizers.abs_max_quantize)
-            if "to_numpy" in sig.parameters:
-                return keras.quantizers.abs_max_quantize(
-                    inputs, axis=axis, to_numpy=True
-                )
-            else:
-                # `keras<=3.4.1` doesn't support `to_numpy`
-                return keras.quantizers.abs_max_quantize(inputs, axis=axis)
+            return keras.quantizers.abs_max_quantize(
+                inputs, axis=axis, to_numpy=True
+            )
 
         self._tracker.unlock()
         if mode == "int8":
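With the version shims removed, `quantize()` now unconditionally passes `to_numpy=True`, which per the deleted comment requires a Keras newer than 3.4.1. A hedged usage sketch with illustrative sizes (the build-then-quantize flow is an assumption, not shown in this diff):

import numpy as np

from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)

layer = ReversibleEmbedding(input_dim=100, output_dim=16)
layer(np.array([[1, 2, 3]]))  # first call builds the weights
layer.quantize("int8")  # assumes keras > 3.4.1 after this change

hidden = layer(np.array([[1, 2, 3]]))  # (1, 3, 16) embeddings
logits = layer(hidden, reverse=True)  # (1, 3, 100) reverse projection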
keras_hub/src/layers/modeling/rms_normalization.py
@@ -0,0 +1,36 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+
+
+@keras_hub_export("keras_hub.layers.RMSNormalization")
+class RMSNormalization(keras.layers.Layer):
+    """Root Mean Square (RMS) Normalization layer.
+
+    This layer normalizes the input tensor based on its RMS value and applies
+    a learned scaling factor.
+
+    Args:
+        input_dim: int. The dimensionality of the input tensor.
+    """
+
+    def __init__(self, input_dim):
+        super().__init__()
+        self.scale = self.add_weight(
+            name="scale", shape=(input_dim,), initializer="ones"
+        )
+
+    def call(self, x):
+        """Applies RMS normalization to the input tensor.
+
+        Args:
+            x: Input tensor of shape (batch_size, input_dim).
+
+        Returns:
+            The RMS-normalized tensor of the same shape (batch_size, input_dim),
+            scaled by the learned `scale` parameter.
+        """
+        x = ops.cast(x, float)
+        rrms = ops.rsqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + 1e-6)
+        return (x * rrms) * self.scale
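The new layer is also exported as `keras_hub.layers.RMSNormalization` (see the api/layers diff above). A minimal usage sketch, with an assumed feature size of 8:

import numpy as np

import keras_hub

layer = keras_hub.layers.RMSNormalization(input_dim=8)
x = np.random.rand(2, 8).astype("float32")
y = layer(x)
# Same shape as `x`; with the initial scale of ones, each row of `y`
# has a root mean square of roughly 1.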
keras_hub/src/layers/modeling/rotary_embedding.py
@@ -11,7 +11,8 @@ class RotaryEmbedding(keras.layers.Layer):
     This layer encodes absolute positional information with a rotation
     matrix. It calculates the rotary encoding with a mix of sine and
     cosine functions with geometrically increasing wavelengths.
-    Defined and formulated in [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
+    Defined and formulated in
+    [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
     The input must be a tensor with shape a sequence dimension and a feature
     dimension. Typically, this will either an input with shape
     `(batch_size, sequence_length, feature_length)` or
@@ -65,7 +66,7 @@ class RotaryEmbedding(keras.layers.Layer):
         scaling_factor=1.0,
         sequence_axis=1,
         feature_axis=-1,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.max_wavelength = max_wavelength
keras_hub/src/layers/modeling/token_and_position_embedding.py
@@ -58,7 +58,7 @@ class TokenAndPositionEmbedding(keras.layers.Layer):
         tie_weights=True,
         embeddings_initializer="uniform",
         mask_zero=False,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         if vocabulary_size is None:
keras_hub/src/layers/modeling/transformer_decoder.py
@@ -5,12 +5,13 @@ from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.cached_multi_head_attention import (
     CachedMultiHeadAttention,
 )
-from keras_hub.src.utils.keras_utils import clone_initializer
-
-from keras_hub.src.layers.modeling.transformer_layer_utils import (  # isort:skip
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
     compute_causal_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
     merge_padding_and_attention_mask,
 )
+from keras_hub.src.utils.keras_utils import clone_initializer
 
 
 @keras_hub_export("keras_hub.layers.TransformerDecoder")
@@ -265,13 +266,13 @@
             `[batch_size, decoder_sequence_length]`.
         decoder_attention_mask: a boolean Tensor. Customized decoder
             sequence mask, must be of shape
-            `[batch_size, decoder_sequence_length, decoder_sequence_length]`.
+            `[batch_size, decoder_sequence_length, decoder_sequence_length]`
         encoder_padding_mask: a boolean Tensor, the padding mask of encoder
             sequence, must be of shape
             `[batch_size, encoder_sequence_length]`.
         encoder_attention_mask: a boolean Tensor. Customized encoder
             sequence mask, must be of shape
-            `[batch_size, encoder_sequence_length, encoder_sequence_length]`.
+            `[batch_size, encoder_sequence_length, encoder_sequence_length]`
         self_attention_cache: a dense float Tensor. The cache of key/values
             pairs in the self-attention layer. Has shape
             `[batch_size, 2, max_seq_len, num_heads, key_dims]`.
@@ -435,7 +436,8 @@
         input_length = output_length = ops.shape(decoder_sequence)[1]
         # We need to handle a rectangular causal mask when doing cached
         # decoding. For generative inference, `decoder_sequence` will
-        # generally be length 1, and `cache` will be the full generation length.
+        # generally be length 1, and `cache` will be the full generation
+        # length.
         if self_attention_cache is not None:
             input_length = ops.shape(self_attention_cache)[2]
keras_hub/src/layers/modeling/transformer_encoder.py
@@ -170,7 +170,12 @@ class TransformerEncoder(keras.layers.Layer):
         self.built = True
 
     def call(
-        self, inputs, padding_mask=None, attention_mask=None, training=None
+        self,
+        inputs,
+        padding_mask=None,
+        attention_mask=None,
+        training=None,
+        return_attention_scores=False,
     ):
         """Forward pass of the TransformerEncoder.
 
@@ -185,6 +190,9 @@
                 [batch_size, sequence_length, sequence_length].
             training: a boolean indicating whether the layer should behave in
                 training mode or in inference mode.
+            return_attention_scores: a boolean indicating whether the output
+                should be `(attention_output, attention_scores)` if `True` or
+                `attention_output` if `False`. Defaults to `False`.
 
         Returns:
             A Tensor of the same shape as the `inputs`.
@@ -200,12 +208,23 @@
         residual = x
         if self.normalize_first:
             x = self._self_attention_layer_norm(x)
-        x = self._self_attention_layer(
-            query=x,
-            value=x,
-            attention_mask=self_attention_mask,
-            training=training,
-        )
+
+        if return_attention_scores:
+            x, attention_scores = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                return_attention_scores=return_attention_scores,
+                training=training,
+            )
+        else:
+            x = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                training=training,
+            )
+
         x = self._self_attention_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
@@ -222,6 +241,9 @@
         if not self.normalize_first:
             x = self._feedforward_layer_norm(x)
 
+        if return_attention_scores:
+            return x, attention_scores
+
         return x
 
     def get_config(self):
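The flag is forwarded to the underlying multi-head attention layer, so the second return value is the usual attention-score tensor. A usage sketch with assumed dimensions:

import numpy as np

import keras_hub

encoder = keras_hub.layers.TransformerEncoder(intermediate_dim=64, num_heads=4)
x = np.random.rand(2, 10, 32).astype("float32")

outputs = encoder(x)  # default behavior is unchanged: (2, 10, 32)
outputs, scores = encoder(x, return_attention_scores=True)
# scores: (2, 4, 10, 10), i.e. (batch, num_heads, query_len, key_len).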
keras_hub/src/layers/preprocessing/audio_converter.py
@@ -2,11 +2,10 @@ from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.preprocessing_layer import (
     PreprocessingLayer,
 )
-from keras_hub.src.utils.preset_utils import AUDIO_CONVERTER_CONFIG_FILE
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_preset_loader
-from keras_hub.src.utils.preset_utils import save_serialized_object
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 
 
@@ -101,8 +100,5 @@ class AudioConverter(PreprocessingLayer):
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        save_serialized_object(
-            self,
-            preset_dir,
-            config_file=AUDIO_CONVERTER_CONFIG_FILE,
-        )
+        saver = get_preset_saver(preset_dir)
+        saver.save_audio_converter(self)
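The public entry point is unchanged; `save_to_preset` now simply delegates to the new preset saver. A round-trip sketch through a local directory (the preset name is an assumption for illustration, not taken from this diff):

import keras_hub

converter = keras_hub.layers.WhisperAudioConverter.from_preset(
    "whisper_tiny_en"  # assumed preset name
)
converter.save_to_preset("./my_audio_preset")

# The base class resolves the correct subclass when reloading.
reloaded = keras_hub.layers.AudioConverter.from_preset("./my_audio_preset")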