keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,126 @@
+ """ViT model preset configurations."""
+
+ # Metadata for loading pretrained model weights.
+ backbone_presets = {
+     "vit_base_patch16_224_imagenet": {
+         "metadata": {
+             "description": (
+                 "ViT-B16 model pre-trained on the ImageNet 1k dataset with "
+                 "image resolution of 224x224 "
+             ),
+             "params": 85798656,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/2",
+     },
+     "vit_base_patch16_384_imagenet": {
+         "metadata": {
+             "description": (
+                 "ViT-B16 model pre-trained on the ImageNet 1k dataset with "
+                 "image resolution of 384x384 "
+             ),
+             "params": 86090496,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/2",
+     },
+     "vit_large_patch16_224_imagenet": {
+         "metadata": {
+             "description": (
+                 "ViT-L16 model pre-trained on the ImageNet 1k dataset with "
+                 "image resolution of 224x224 "
+             ),
+             "params": 303301632,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/2",
+     },
+     "vit_large_patch16_384_imagenet": {
+         "metadata": {
+             "description": (
+                 "ViT-L16 model pre-trained on the ImageNet 1k dataset with "
+                 "image resolution of 384x384 "
+             ),
+             "params": 303690752,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/2",
+     },
+     "vit_base_patch32_384_imagenet": {
+         "metadata": {
+             "description": (
+                 "ViT-B32 model pre-trained on the ImageNet 1k dataset with "
+                 "image resolution of 384x384 "
+             ),
+             "params": 87528192,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_384_imagenet/1",
+     },
+     "vit_large_patch32_384_imagenet": {
+         "metadata": {
+             "description": (
+                 "ViT-L32 model pre-trained on the ImageNet 1k dataset with "
+                 "image resolution of 384x384 "
+             ),
+             "params": 305607680,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_384_imagenet/1",
+     },
+     "vit_base_patch16_224_imagenet21k": {
+         "metadata": {
+             "description": (
+                 "ViT-B16 backbone pre-trained on the ImageNet 21k dataset with "
+                 "image resolution of 224x224 "
+             ),
+             "params": 85798656,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet21k/1",
+     },
+     "vit_base_patch32_224_imagenet21k": {
+         "metadata": {
+             "description": (
+                 "ViT-B32 backbone pre-trained on the ImageNet 21k dataset with "
+                 "image resolution of 224x224 "
+             ),
+             "params": 87455232,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_224_imagenet21k/1",
+     },
+     "vit_huge_patch14_224_imagenet21k": {
+         "metadata": {
+             "description": (
+                 "ViT-H14 backbone pre-trained on the ImageNet 21k dataset with "
+                 "image resolution of 224x224 "
+             ),
+             "params": 630764800,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_huge_patch14_224_imagenet21k/1",
+     },
+     "vit_large_patch16_224_imagenet21k": {
+         "metadata": {
+             "description": (
+                 "ViT-L16 backbone pre-trained on the ImageNet 21k dataset with "
+                 "image resolution of 224x224 "
+             ),
+             "params": 303301632,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet21k/1",
+     },
+     "vit_large_patch32_224_imagenet21k": {
+         "metadata": {
+             "description": (
+                 "ViT-L32 backbone pre-trained on the ImageNet 21k dataset with "
+                 "image resolution of 224x224 "
+             ),
+             "params": 305510400,
+             "path": "vit",
+         },
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_224_imagenet21k/1",
+     },
+ }
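Presets registered in a table like this are addressable by name through the library's `from_preset` constructors. A minimal usage sketch, assuming the standard keras_hub preset API and access to the Kaggle handles above:

```python
import keras_hub

# The short preset name resolves to the "kaggle_handle" registered above.
backbone = keras_hub.models.Backbone.from_preset(
    "vit_base_patch16_224_imagenet"
)

# The fully spelled-out handle works as well.
backbone = keras_hub.models.Backbone.from_preset(
    "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/2"
)
```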
@@ -31,7 +31,7 @@ class ViTDetBackbone(Backbone):
          global_attention_layer_indices (list): Indexes for blocks using
              global attention.
          image_shape (tuple[int], optional): The size of the input image in
-             `(H, W, C)` format. Defaults to `(1024, 1024, 3)`.
+             `(H, W, C)` format. Defaults to `(None, None, 3)`.
          patch_size (int, optional): the patch size to be supplied to the
              Patching layer to turn input images into a flattened sequence of
              patches. Defaults to `16`.
@@ -79,7 +79,7 @@ class ViTDetBackbone(Backbone):
          intermediate_dim,
          num_heads,
          global_attention_layer_indices,
-         image_shape=(1024, 1024, 3),
+         image_shape=(None, None, 3),
          patch_size=16,
          num_output_channels=256,
          use_bias=True,
@@ -87,7 +87,7 @@ class ViTDetBackbone(Backbone):
          use_rel_pos=True,
          window_size=14,
          layer_norm_epsilon=1e-6,
-         **kwargs
+         **kwargs,
      ):
          # === Functional model ===
          img_input = keras.layers.Input(shape=image_shape, name="images")
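The new `(None, None, 3)` default leaves height and width symbolic, so a single built model can accept images of varying size. A minimal pure-Keras sketch of what a `None` spatial dimension buys (toy layers for illustration, not the ViTDet architecture):

```python
import keras

# Only the channel count is fixed; height and width stay flexible.
inputs = keras.layers.Input(shape=(None, None, 3), name="images")
outputs = keras.layers.Conv2D(8, 3, padding="same")(inputs)
model = keras.Model(inputs, outputs)

print(model(keras.ops.ones((1, 224, 224, 3))).shape)  # (1, 224, 224, 8)
print(model(keras.ops.ones((1, 512, 384, 3))).shape)  # (1, 512, 384, 8)
```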
@@ -179,7 +179,9 @@ class ViTDetBackbone(Backbone):
                  "use_abs_pos": self.use_abs_pos,
                  "use_rel_pos": self.use_rel_pos,
                  "window_size": self.window_size,
-                 "global_attention_layer_indices": self.global_attention_layer_indices,
+                 "global_attention_layer_indices": (
+                     self.global_attention_layer_indices
+                 ),
                  "layer_norm_epsilon": self.layer_norm_epsilon,
              }
          )
@@ -117,7 +117,7 @@ class AddRelativePositionalEmbedding(keras.layers.Layer):
      """Calculate decomposed Relative Positional Embeddings

      The code has been adapted based on
-     https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa: E501
+     https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

      Args:
          attention_map (tensor): Attention map.
@@ -193,7 +193,7 @@ class MultiHeadAttentionWithRelativePE(keras.layers.Layer):
          use_bias=True,
          use_rel_pos=False,
          input_size=None,
-         **kwargs
+         **kwargs,
      ):
          super().__init__(**kwargs)
          self.num_heads = num_heads
@@ -378,7 +378,7 @@ class WindowedTransformerEncoder(keras.layers.Layer):
          input_size=None,
          activation="gelu",
          layer_norm_epsilon=1e-6,
-         **kwargs
+         **kwargs,
      ):
          super().__init__(**kwargs)
          self.project_dim = project_dim
@@ -39,7 +39,7 @@ class WhisperAudioConverter(AudioConverter):
      audio_tensor = tf.ones((8000,), dtype="float32")

      # Compute the log-mel spectrogram.
-     audio_converter = keras_hub.models.WhisperAudioConverter.from_preset(
+     audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
          "whisper_base_en",
      )
      audio_converter(audio_tensor)
@@ -172,9 +172,7 @@ class WhisperAudioConverter(AudioConverter):
          )

          def tf_log10(x):
-             """
-             Computes log base 10 of input tensor using TensorFlow's natural log operator.
-             """
+             """Computes log base 10 of input tensor using TensorFlow."""
              numerator = tf.math.log(x)
              denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
              return numerator / denominator
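The helper above relies on the change-of-base identity `log10(x) = ln(x) / ln(10)`, since `tf.math` has no built-in base-10 log. A standalone check of the same identity:

```python
import tensorflow as tf

def tf_log10(x):
    # Change of base: log10(x) = ln(x) / ln(10).
    numerator = tf.math.log(x)
    denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
    return numerator / denominator

print(tf_log10(tf.constant([1.0, 10.0, 100.0])))  # ~[0. 1. 2.]
```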
@@ -30,9 +30,10 @@ class WhisperBackbone(Backbone):
      It includes the embedding lookups and transformer layers, but not the head
      for predicting the next token.

-     The default constructor gives a fully customizable, randomly initialized Whisper
-     model with any number of layers, heads, and embedding dimensions. To load
-     preset architectures and weights, use the `from_preset()` constructor.
+     The default constructor gives a fully customizable, randomly initialized
+     Whisper model with any number of layers, heads, and embedding dimensions.
+     To load preset architectures and weights, use the `from_preset()`
+     constructor.

      Disclaimer: Pre-trained models are provided on an "as is" basis, without
      warranties or conditions of any kind. The underlying model is provided by a
@@ -53,8 +54,8 @@ class WhisperBackbone(Backbone):
          max_encoder_sequence_length: int. The maximum sequence length that the
              audio encoder can consume. Since the second convolutional layer in
              the encoder reduces the sequence length by half (stride of 2), we
-             use `max_encoder_sequence_length // 2` as the sequence length for the
-             positional embedding layer.
+             use `max_encoder_sequence_length // 2` as the sequence length for
+             the positional embedding layer.
          max_decoder_sequence_length: int. The maximum sequence length that the
              text decoder can consume.
          dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
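The `// 2` in this docstring follows directly from the stride-2 convolution: each encoder output frame covers two input frames, so the positional embedding needs half as many positions. With a hypothetical input length (the actual value comes from the preset, not from this hunk):

```python
max_encoder_sequence_length = 3000  # hypothetical example value
num_positions = max_encoder_sequence_length // 2
print(num_positions)  # 1500 positional-embedding slots
```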
@@ -14,11 +14,9 @@ class WhisperDecoder(TransformerDecoder):
      """Whisper decoder.

      Inherits from `keras_hub.layers.TransformerDecoder`, and overrides the
-     `build` method to use the
-     `keras_hub.models.whisper.whisper_multi_head_attention.WhisperMultiHeadAttention`
-     layer instead of `keras.layers.MultiHeadAttention` and
-     `keras_hub.models.whisper.whisper_cached_multi_head_attention.WhisperCachedMultiHeadAttention`
-     instead of `keras_hub.layers.cached_multi_head_attention.CachedMultiHeadAttention`.
+     `build` method to use the `WhisperMultiHeadAttention`
+     layer instead of `MultiHeadAttention` and `WhisperCachedMultiHeadAttention`
+     instead of `CachedMultiHeadAttention`.
      """

      def build(
@@ -7,11 +7,9 @@ backbone_presets = {
              "English speech data."
          ),
          "params": 37184256,
-         "official_name": "Whisper",
          "path": "whisper",
-         "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
      },
-     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/3",
+     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/4",
  },
  "whisper_base_en": {
      "metadata": {
@@ -20,11 +18,9 @@ backbone_presets = {
              "English speech data."
          ),
          "params": 124439808,
-         "official_name": "Whisper",
          "path": "whisper",
-         "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
      },
-     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/3",
+     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/4",
  },
  "whisper_small_en": {
      "metadata": {
@@ -33,11 +29,9 @@ backbone_presets = {
              "English speech data."
          ),
          "params": 241734144,
-         "official_name": "Whisper",
          "path": "whisper",
-         "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
      },
-     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/3",
+     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/4",
  },
  "whisper_medium_en": {
      "metadata": {
@@ -46,11 +40,9 @@ backbone_presets = {
              "English speech data."
          ),
          "params": 763856896,
-         "official_name": "Whisper",
          "path": "whisper",
-         "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
      },
-     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/3",
+     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/4",
  },
  "whisper_tiny_multi": {
      "metadata": {
@@ -59,11 +51,9 @@ backbone_presets = {
              "multilingual speech data."
          ),
          "params": 37760640,
-         "official_name": "Whisper",
          "path": "whisper",
-         "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
      },
-     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/3",
+     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/4",
  },
  "whisper_base_multi": {
      "metadata": {
@@ -72,11 +62,9 @@ backbone_presets = {
              "multilingual speech data."
          ),
          "params": 72593920,
-         "official_name": "Whisper",
          "path": "whisper",
-         "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
      },
-     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/3",
+     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/4",
  },
  "whisper_small_multi": {
      "metadata": {
@@ -85,11 +73,9 @@ backbone_presets = {
              "multilingual speech data."
          ),
          "params": 241734912,
-         "official_name": "Whisper",
          "path": "whisper",
-         "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
      },
-     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/3",
+     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/4",
  },
  "whisper_medium_multi": {
      "metadata": {
@@ -98,11 +84,9 @@ backbone_presets = {
              "multilingual speech data."
          ),
          "params": 763857920,
-         "official_name": "Whisper",
          "path": "whisper",
-         "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
      },
-     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/3",
+     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/4",
  },
  "whisper_large_multi": {
      "metadata": {
@@ -111,11 +95,9 @@ backbone_presets = {
              "multilingual speech data."
          ),
          "params": 1543304960,
-         "official_name": "Whisper",
          "path": "whisper",
-         "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
      },
-     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/3",
+     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/4",
  },
  "whisper_large_multi_v2": {
      "metadata": {
@@ -125,10 +107,8 @@ backbone_presets = {
              "of `whisper_large_multi`."
          ),
          "params": 1543304960,
-         "official_name": "Whisper",
          "path": "whisper",
-         "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
      },
-     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/3",
+     "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/4",
  },
}
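The handle bumps (`/3` to `/4`) republish the presets without the dropped `official_name` and `model_card` keys; loading by name is unchanged. A sketch assuming the standard preset API:

```python
import keras_hub

# The short name resolves to the new "/4" handle registered above.
backbone = keras_hub.models.WhisperBackbone.from_preset("whisper_tiny_en")
```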
@@ -9,7 +9,7 @@ from keras_hub.src.models.roberta.roberta_backbone import (
  from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
      XLMRobertaBackbone,
  )
- from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import (
+ from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import (  # noqa: E501
      XLMRobertaMaskedLMPreprocessor,
  )

@@ -20,8 +20,8 @@ class XLMRobertaMaskedLMPreprocessor(MaskedLMPreprocessor):

      This preprocessing layer will prepare inputs for a masked language modeling
      task. It is primarily intended for use with the
-     `keras_hub.models.XLMRobertaMaskedLM` task model. Preprocessing will occur in
-     multiple steps.
+     `keras_hub.models.XLMRobertaMaskedLM` task model. Preprocessing will occur
+     in multiple steps.

      1. Tokenize any number of input segments using the `tokenizer`.
      2. Pack the inputs together with the appropriate `"<s>"`, `"</s>"` and
@@ -8,11 +8,9 @@ backbone_presets = {
              "Trained on CommonCrawl in 100 languages."
          ),
          "params": 277450752,
-         "official_name": "XLM-RoBERTa",
          "path": "xlm_roberta",
-         "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/xlmr/README.md",
      },
-     "kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_base_multi/2",
+     "kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_base_multi/3",
  },
  "xlm_roberta_large_multi": {
      "metadata": {
@@ -21,10 +19,8 @@ backbone_presets = {
              "Trained on CommonCrawl in 100 languages."
          ),
          "params": 558837760,
-         "official_name": "XLM-RoBERTa",
          "path": "xlm_roberta",
-         "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/xlmr/README.md",
      },
-     "kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_large_multi/2",
+     "kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_large_multi/3",
  },
}
@@ -8,7 +8,7 @@ from keras_hub.src.models.text_classifier import TextClassifier
  from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
      XLMRobertaBackbone,
  )
- from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import (
+ from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import (  # noqa: E501
      XLMRobertaTextClassifierPreprocessor,
  )

@@ -40,9 +40,9 @@ class XLMRobertaTextClassifier(TextClassifier):
      Args:
          backbone: A `keras_hub.models.XLMRobertaBackbone` instance.
          num_classes: int. Number of classes to predict.
-         preprocessor: A `keras_hub.models.XLMRobertaTextClassifierPreprocessor` or `None`. If
-             `None`, this model will not apply preprocessing, and inputs should
-             be preprocessed before calling the model.
+         preprocessor: A `keras_hub.models.XLMRobertaTextClassifierPreprocessor`
+             or `None`. If `None`, this model will not apply preprocessing, and
+             inputs should be preprocessed before calling the model.
          activation: Optional `str` or callable. The activation function to use
              on the model outputs. Set `activation="softmax"` to return output
              probabilities. Defaults to `None`.
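As the `preprocessor` argument above describes, the attached preprocessor decides whether the task model consumes raw strings or pre-tokenized tensors. A minimal sketch of the raw-string path, assuming the documented `from_preset` workflow:

```python
import keras_hub

# With the default preprocessor attached, predict on raw strings.
classifier = keras_hub.models.XLMRobertaTextClassifier.from_preset(
    "xlm_roberta_base_multi",
    num_classes=2,
)
classifier.predict(["The quick brown fox jumped.", "I forgot my homework."])
```

With `preprocessor=None`, the same model instead expects the tokenized input dictionary that the preprocessor would have produced.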
@@ -177,7 +177,8 @@ class XLMRobertaTokenizer(SentencePieceTokenizer):
          # Shift the tokens IDs left by one.
          tokens = tf.subtract(tokens, 1)

-         # Correct `unk_token_id`, `end_token_id`, `start_token_id`, respectively.
+         # Correct `unk_token_id`, `end_token_id`, `start_token_id`,
+         # respectively.
          # Note: The `pad_token_id` is taken as 0 (`unk_token_id`) since the
          # proto does not contain `pad_token_id`. This mapping of the pad token
          # is done automatically by the above subtraction.
@@ -64,27 +64,28 @@ def _rel_shift(x, klen=-1):
  class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
      """Two-stream relative self-attention for XLNet.

-     In XLNet, each token has two associated vectors at each self-attention layer,
-     the content stream (h) and the query stream (g). The content stream is the
-     self-attention stream as in Transformer XL and represents the context and
-     content (the token itself). The query stream only has access to contextual
-     information and the position, but not the content.
+     In XLNet, each token has two associated vectors at each self-attention
+     layer, the content stream (h) and the query stream (g). The content stream
+     is the self-attention stream as in Transformer XL and represents the context
+     and content (the token itself). The query stream only has access to
+     contextual information and the position, but not the content.

-     This layer shares the same build signature as `keras.layers.MultiHeadAttention`
-     but has different input/output projections.
+     This layer shares the same build signature as
+     `keras.layers.MultiHeadAttention` but has different input/output
+     projections.

      We use the notations `B`, `T`, `S`, `M`, `L`, `E`, `P`, `dim`, `num_heads`
-     below, where
-     `B` is the batch dimension, `T` is the target sequence length,
+     below, where `B` is the batch dimension, `T` is the target sequence length,
      `S` in the source sequence length, `M` is the length of the state or memory,
      `L` is the length of relative positional encoding, `E` is the last dimension
-     of query input, `P` is the number of predictions, `dim` is the dimensionality
-     of the encoder layers. and `num_heads` is the number of attention heads.
+     of query input, `P` is the number of predictions, `dim` is the
+     dimensionality of the encoder layers. and `num_heads` is the number of
+     attention heads.

      Args:
          content_stream: `Tensor` of shape `[B, T, dim]`.
-         content_attention_bias: Bias `Tensor` for content based attention of shape
-             `[num_heads, dim]`.
+         content_attention_bias: Bias `Tensor` for content based attention of
+             shape `[num_heads, dim]`.
          positional_attention_bias: Bias `Tensor` for position based attention of
              shape `[num_heads, dim]`.
          query_stream: `Tensor` of shape `[B, P, dim]`.
@@ -96,8 +97,8 @@ class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
          segment_encoding: Optional `Tensor` representing the segmentation
              encoding as used in XLNet of shape `[2, num_heads, dim]`.
          segment_attention_bias: Optional trainable bias parameter added to the
-             query had when calculating the segment-based attention score used
-             in XLNet of shape `[num_heads, dim]`.
+             query had when calculating the segment-based attention score used in
+             XLNet of shape `[num_heads, dim]`.
          state: Optional `Tensor` of shape `[B, M, E]`.
              If passed, this is also attended over as in Transformer XL.
          content_attention_mask: a boolean mask of shape `[B, T, S]` that
@@ -336,11 +337,11 @@ class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
          dimension of query input.

      Args:
-         content_stream: The content representation, commonly referred to as h.
-             This serves a similar role to the standard hidden states in
+         content_stream: The content representation, commonly referred to as
+             h. This serves a similar role to the standard hidden states in
              Transformer-XL.
-         content_attention_bias: A trainable bias parameter added to the query
-             head when calculating the content-based attention score.
+         content_attention_bias: A trainable bias parameter added to the
+             query head when calculating the content-based attention score.
          positional_attention_bias: A trainable bias parameter added to the
              query head when calculating the position-based attention score.
          query_stream: The query representation, commonly referred to as g.
@@ -49,8 +49,8 @@ class XLNetBackbone(Backbone):
              `[batch_size, sequence_length]`.
          segment_ids: Segment token indices to indicate first and second portions
              of the inputs of shape `[batch_size, sequence_length]`.
-         padding_mask: Mask to avoid performing attention on padding token indices
-             of shape `[batch_size, sequence_length]`.
+         padding_mask: Mask to avoid performing attention on padding token
+             indices of shape `[batch_size, sequence_length]`.

      Example:
      ```python
@@ -3,8 +3,7 @@ from keras import ops


  class ContentAndQueryEmbedding(keras.layers.Layer):
-     """
-     Content and Query Embedding.
+     """Content and Query Embedding.

      This class creates Content and Query Embeddings for XLNet model
      which is later used in XLNet Encoder.
@@ -20,9 +19,8 @@ class ContentAndQueryEmbedding(keras.layers.Layer):
          **kwargs: other keyword arguments.

      References:
-      - [XLNet: Generalized Autoregressive Pretraining for Language Understanding]
-      (https://arxiv.org/abs/1906.08237)
-     """
+      - [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237)
+     """ # noqa: E501

      def __init__(
          self, vocabulary_size, hidden_dim, dropout, name=None, **kwargs
@@ -11,17 +11,16 @@ def xlnet_kernel_initializer(stddev=0.02):


  class XLNetEncoder(keras.layers.Layer):
-     """
-     XLNet Encoder.
+     """XLNet Encoder.

      This class follows the architecture of the transformer encoder layer in the
      paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
      can instantiate multiple instances of this class to stack up an encoder.

      Contrary to the single hidden state used in the paper mentioned above, this
-     Encoder uses two hidden states, Content State and Query State. Thus calculates
-     Two Stream Relative Attention using both of the hidden states. To know more
-     please check the reference.
+     Encoder uses two hidden states, Content State and Query State. Thus
+     calculates Two Stream Relative Attention using both of the hidden states.
+     To know more please check the reference.

      Args:
          num_heads: int, the number of heads in the
@@ -44,9 +43,8 @@ class XLNetEncoder(keras.layers.Layer):
          **kwargs: other keyword arguments.

      References:
-      - [XLNet: Generalized Autoregressive Pretraining for Language Understanding]
-      (https://arxiv.org/abs/1906.08237)
-     """
+      - [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237)
+     """ # noqa: E501

      def __init__(
          self,
@@ -60,7 +58,7 @@ class XLNetEncoder(keras.layers.Layer):
          kernel_initializer_range=0.02,
          bias_initializer="zeros",
          name=None,
-         **kwargs
+         **kwargs,
      ):
          super().__init__(name=name, **kwargs)
          self.num_heads = num_heads
@@ -150,9 +150,8 @@ class ContrastiveSampler(Sampler):
              # The final score of each candidate token is weighted sum of
              # probability and similarity against previous tokens.
              accumulated_scores = (
-                 (1 - self.alpha) * next_token_probabilities
-                 - self.alpha * max_similarity_scores
-             )
+                 1 - self.alpha
+             ) * next_token_probabilities - self.alpha * max_similarity_scores
              # Unflatten variables to shape [batch_size, self.k, ...] for
              # gather purpose.
              unflat_score = unflatten_beams(accumulated_scores)
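This hunk only re-parenthesizes the contrastive-search score, `(1 - alpha) * p - alpha * sim`; the math is unchanged. A small numeric check with made-up values:

```python
import numpy as np

alpha = 0.6
next_token_probabilities = np.array([0.50, 0.30, 0.20])
max_similarity_scores = np.array([0.90, 0.20, 0.10])

# Weighted sum of model confidence and a degeneration (similarity) penalty.
accumulated_scores = (
    1 - alpha
) * next_token_probabilities - alpha * max_similarity_scores
print(accumulated_scores)           # [-0.34  0.    0.02]
print(accumulated_scores.argmax())  # 2: the least-repetitive candidate wins
```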
@@ -95,7 +95,8 @@ class Sampler:
          def cond(prompt, cache, index):
              if stop_token_ids is None:
                  return True
-             # Stop if all sequences have produced a *new* id from stop_token_ids.
+             # Stop if all sequences have produced a *new* id from
+             # stop_token_ids.
              end_tokens = any_equal(prompt, stop_token_ids, ~mask)
              prompt_done = ops.any(end_tokens, axis=-1)
              return ops.logical_not(ops.all(prompt_done))
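The reworded comment captures the stop test: a stop id only ends a sequence where it was generated, i.e. outside the original prompt positions. A numpy sketch of the same predicate, assuming `any_equal(prompt, stop_token_ids, ~mask)` behaves like `isin(...) & ~mask`:

```python
import numpy as np

prompt = np.array([[5, 7, 2, 2],    # row 0 generated the stop id 2
                   [5, 2, 9, 9]])   # row 1's only 2 sits in the prompt
mask = np.array([[True, True, False, False],  # True = original prompt token
                 [True, True, False, False]])
stop_token_ids = (2,)

end_tokens = np.isin(prompt, stop_token_ids) & ~mask  # *new* stop ids only
prompt_done = end_tokens.any(axis=-1)  # [ True False]
print(not prompt_done.all())           # True: row 1 unfinished, keep looping
```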