keras_hub_nightly-0.15.0.dev20240823171555-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
--- /dev/null
+++ keras_hub/src/layers/modeling/masked_lm_head.py
@@ -0,0 +1,239 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+
+
+@keras_hub_export("keras_hub.layers.MaskedLMHead")
+class MaskedLMHead(keras.layers.Layer):
+    """Masked Language Model (MaskedLM) head.
+
+    This layer takes two inputs:
+
+    - `inputs`: which should be a tensor of encoded tokens with shape
+      `(batch_size, sequence_length, hidden_dim)`.
+    - `mask_positions`: which should be a tensor of integer positions to
+      predict with shape `(batch_size, masks_per_sequence)`.
+
+    The token encodings should usually be the last output of an encoder model,
+    and mask positions should be the integer positions you would like to
+    predict for the MaskedLM task.
+
+    The layer will first gather the token encodings at the mask positions.
+    These gathered tokens will be passed through a dense layer the same size
+    as the encoding dimension, then transformed to predictions the same size
+    as the input vocabulary. This layer will produce a single output with
+    shape `(batch_size, masks_per_sequence, vocabulary_size)`, which can be
+    used to compute a MaskedLM loss function.
+
+    This layer is often paired with `keras_hub.layers.MaskedLMMaskGenerator`,
+    which will help prepare inputs for the MaskedLM task.
+
+    Args:
+        vocabulary_size: The total size of the vocabulary for predictions.
+        token_embedding: Optional. A `keras_hub.layers.ReversibleEmbedding`
+            instance. If passed, the layer will be used to project from the
+            `hidden_dim` of the model to the output `vocabulary_size`.
+        intermediate_activation: The activation function of the intermediate
+            dense layer.
+        activation: The activation function for the outputs of the layer.
+            Usually either `None` (return logits), or `"softmax"`
+            (return probabilities).
+        layer_norm_epsilon: float. The epsilon value in layer
+            normalization components. Defaults to `1e-5`.
+        kernel_initializer: string or `keras.initializers` initializer.
+            The kernel initializer for the dense and multiheaded
+            attention layers. Defaults to `"glorot_uniform"`.
+        bias_initializer: string or `keras.initializers` initializer.
+            The bias initializer for the dense and multiheaded
+            attention layers. Defaults to `"zeros"`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `trainable`, `dtype` etc.
+
+    Example:
+
+    ```python
+    batch_size = 16
+    vocab_size = 100
+    hidden_dim = 32
+    seq_length = 50
+
+    # Generate random inputs.
+    token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
+    # Choose random positions as the masked inputs.
+    mask_positions = np.random.randint(seq_length, size=(batch_size, 5))
+
+    # Embed tokens in a `hidden_dim` feature space.
+    token_embedding = keras_hub.layers.ReversibleEmbedding(
+        vocab_size,
+        hidden_dim,
+    )
+    hidden_states = token_embedding(token_ids)
+
+    preds = keras_hub.layers.MaskedLMHead(
+        vocabulary_size=vocab_size,
+        token_embedding=token_embedding,
+        activation="softmax",
+    )(hidden_states, mask_positions)
+    ```
+
+    References:
+    - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
+    """
+
+    def __init__(
+        self,
+        vocabulary_size=None,
+        token_embedding=None,
+        intermediate_activation="relu",
+        activation=None,
+        layer_norm_epsilon=1e-05,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        **kwargs,
+    ):
+        super().__init__(**kwargs, autocast=False)
+
+        self.vocabulary_size = vocabulary_size
+        self.token_embedding = token_embedding
+        self.intermediate_activation = keras.activations.get(
+            intermediate_activation
+        )
+        self.activation = keras.activations.get(activation)
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
+
+        if vocabulary_size is None and token_embedding is None:
+            raise ValueError(
+                "One of `vocabulary_size` or `token_embedding` must be set. "
+                "Received: `vocabulary_size=None`, `token_embedding=None`"
+            )
+
+        if token_embedding:
+            if vocabulary_size and vocabulary_size != token_embedding.input_dim:
+                raise ValueError(
+                    "`vocabulary_size` should match the input dimension "
+                    "of `token_embedding`. Received: "
+                    f"`vocabulary_size={vocabulary_size}`, "
+                    f"`token_embedding.input_dim={token_embedding.input_dim}`"
+                )
+            self.vocabulary_size = token_embedding.input_dim
+
+    def build(self, inputs_shape, mask_positions_shape=None):
+        if self.token_embedding is not None:
+            feature_size = self.token_embedding.output_dim
+        else:
+            feature_size = inputs_shape[-1]
+
+        self._intermediate_dense = keras.layers.Dense(
+            feature_size,
+            activation=self.intermediate_activation,
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+            dtype=self.dtype_policy,
+            name="intermediate_dense",
+        )
+        self._intermediate_layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="intermediate_layer_norm",
+        )
+        # The gather length does not affect any of our built variables, so
+        # we can pass any value here.
+        gather_length = None
+        shape = (inputs_shape[0], gather_length, inputs_shape[-1])
+        self._intermediate_dense.build(shape)
+        shape = (inputs_shape[0], gather_length, feature_size)
+        self._intermediate_layer_norm.build(shape)
+        if self.token_embedding is None:
+            self._kernel = self.add_weight(
+                name="output_kernel",
+                shape=[feature_size, self.vocabulary_size],
+                initializer=self.kernel_initializer,
+                dtype=self.dtype,
+            )
+        self._bias = self.add_weight(
+            name="output_bias",
+            shape=[self.vocabulary_size],
+            initializer=self.bias_initializer,
+            dtype=self.dtype,
+        )
+        self.built = True
+
+    def call(self, inputs, mask_positions):
+        if keras.config.backend() == "tensorflow":
+            import tensorflow as tf
+
+            # On the tf backend, we need to work around an issue with dynamic
+            # shape broadcasting in take_along_axis.
+            x = tf.gather(inputs, mask_positions, batch_dims=1)
+        else:
+            # Gather the encoded tokens at the masked indices.
+            mask_positions = ops.expand_dims(mask_positions, axis=-1)
+            x = ops.take_along_axis(inputs, mask_positions, axis=1)
+
+        # Apply a trainable linear transformation and a layer norm.
+        x = self._intermediate_dense(x)
+        x = self._intermediate_layer_norm(x)
+
+        # Transform encodings to vocabulary_size predictions.
+        if self.token_embedding:
+            outputs = self.token_embedding(x, reverse=True)
+        else:
+            outputs = ops.matmul(x, self._kernel)
+        outputs = ops.cast(outputs, self.compute_dtype)
+        outputs = outputs + self._bias
+
+        # Apply a final activation.
+        if self.activation is not None:
+            outputs = self.activation(outputs)
+
+        return outputs
+
+    @classmethod
+    def from_config(cls, config):
+        embedding = config.get("token_embedding")
+        if embedding:
+            config["token_embedding"] = keras.layers.deserialize(embedding)
+        return super().from_config(config)
+
+    def get_config(self):
+        config = super().get_config()
+        embedding_config = None
+        if self.token_embedding:
+            embedding_config = keras.layers.serialize(self.token_embedding)
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "token_embedding": embedding_config,
+                "intermediate_activation": keras.activations.serialize(
+                    self.intermediate_activation
+                ),
+                "activation": keras.activations.serialize(self.activation),
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape, mask_positions_shape):
+        return mask_positions_shape + (self.vocabulary_size,)
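The head is normally trained against the original token ids at the masked positions. A minimal sketch of that pairing (the random data, shapes, and the 5-mask count are illustrative, not taken from the package's own tests):

```python
import keras
import numpy as np

import keras_hub

batch_size, vocab_size, hidden_dim, seq_length = 16, 100, 32, 50

# Stand-in for encoder output, plus the masked positions and their targets.
hidden_states = np.random.uniform(
    size=(batch_size, seq_length, hidden_dim)
).astype("float32")
mask_positions = np.random.randint(seq_length, size=(batch_size, 5))
mask_ids = np.random.randint(vocab_size, size=(batch_size, 5))

# With `activation=None` the head returns logits of shape
# `(batch_size, masks_per_sequence, vocabulary_size)`.
head = keras_hub.layers.MaskedLMHead(vocabulary_size=vocab_size)
logits = head(hidden_states, mask_positions)

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(float(loss_fn(mask_ids, logits)))
```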
--- /dev/null
+++ keras_hub/src/layers/modeling/position_embedding.py
@@ -0,0 +1,123 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+
+
+@keras_hub_export("keras_hub.layers.PositionEmbedding")
+class PositionEmbedding(keras.layers.Layer):
+    """A layer which learns a position embedding for input sequences.
+
+    This class assumes that in the input tensor, the last dimension corresponds
+    to the features, and the dimension before the last corresponds to the
+    sequence.
+
+    This layer does not support masking, but can be combined with a
+    `keras.layers.Embedding` for padding mask support.
+
+    Args:
+        sequence_length: The maximum length of the dynamic sequence.
+        initializer: The initializer to use for the embedding weights. Defaults
+            to `"glorot_uniform"`.
+        seq_axis: The axis of the input tensor where we add the embeddings.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `trainable`, `dtype` etc.
+
+    Call arguments:
+        inputs: The tensor inputs to compute an embedding for, with shape
+            `(batch_size, sequence_length, hidden_dim)`. Only the input shape
+            will be used, as the position embedding does not depend on the
+            input sequence content.
+        start_index: An integer or integer tensor. The starting position to
+            compute the position embedding from. This is useful during cached
+            decoding, where each position is predicted separately in a loop.
+
+    Example:
+
+    Called directly on input.
+    >>> layer = keras_hub.layers.PositionEmbedding(sequence_length=10)
+    >>> layer(np.zeros((8, 10, 16)))
+
+    Combine with a token embedding.
+    ```python
+    seq_length = 50
+    vocab_size = 5000
+    embed_dim = 128
+    inputs = keras.Input(shape=(seq_length,))
+    token_embeddings = keras.layers.Embedding(
+        input_dim=vocab_size, output_dim=embed_dim
+    )(inputs)
+    position_embeddings = keras_hub.layers.PositionEmbedding(
+        sequence_length=seq_length
+    )(token_embeddings)
+    outputs = token_embeddings + position_embeddings
+    ```
+
+    Reference:
+    - [Devlin et al., 2019](https://arxiv.org/abs/1810.04805)
+    """
+
+    def __init__(
+        self,
+        sequence_length,
+        initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if sequence_length is None:
+            raise ValueError(
+                "`sequence_length` must be an integer, received `None`."
+            )
+        self.sequence_length = int(sequence_length)
+        self.initializer = keras.initializers.get(initializer)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "sequence_length": self.sequence_length,
+                "initializer": keras.initializers.serialize(self.initializer),
+            }
+        )
+        return config
+
+    def build(self, inputs_shape):
+        feature_size = inputs_shape[-1]
+        self.position_embeddings = self.add_weight(
+            name="embeddings",
+            shape=[self.sequence_length, feature_size],
+            initializer=self.initializer,
+            trainable=True,
+        )
+        self.built = True
+
+    def call(self, inputs, start_index=0):
+        shape = ops.shape(inputs)
+        feature_length = shape[-1]
+        sequence_length = shape[-2]
+        # Trim to match the length of the input sequence, which might be less
+        # than the sequence_length of the layer.
+        position_embeddings = ops.convert_to_tensor(self.position_embeddings)
+        position_embeddings = ops.slice(
+            position_embeddings,
+            (start_index, 0),
+            (sequence_length, feature_length),
+        )
+        return ops.broadcast_to(position_embeddings, shape)
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
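Because only the input shape is used, `start_index` is what makes cached decoding work: at each step you pass a length-one sequence along with the index it occupies. A minimal sketch of the behavior described in the docstring above (shapes are illustrative):

```python
import numpy as np

import keras_hub

layer = keras_hub.layers.PositionEmbedding(sequence_length=10)

# Whole sequence at once: one learned embedding per position.
full = layer(np.zeros((8, 10, 16), dtype="float32"))  # -> (8, 10, 16)

# Cached decoding: embed a single new position at step 4. The layer slices
# row `start_index` out of its `(sequence_length, feature_size)` weight.
step = layer(np.zeros((8, 1, 16), dtype="float32"), start_index=4)  # -> (8, 1, 16)
```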
--- /dev/null
+++ keras_hub/src/layers/modeling/reversible_embedding.py
@@ -0,0 +1,311 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+from keras import ops
+from packaging.version import parse
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.utils.keras_utils import assert_quantization_support
+
+
+@keras_hub_export("keras_hub.layers.ReversibleEmbedding")
+class ReversibleEmbedding(keras.layers.Embedding):
+    """An embedding layer which can project backwards to the input dim.
+
+    This layer is an extension of `keras.layers.Embedding` for language models.
+    This layer can be called "in reverse" with `reverse=True`, in which case
+    the layer will linearly project from `output_dim` back to `input_dim`.
+
+    By default, the reverse projection will use the transpose of the
+    `embeddings` weights to project to `input_dim` (weights are "tied"). If
+    `tie_weights=False`, the model will use a separate, trainable variable for
+    the reverse projection.
+
+    This layer has no bias terms.
+
+    Args:
+        input_dim: Integer. Size of the vocabulary,
+            i.e. maximum integer index + 1.
+        output_dim: Integer. Dimension of the dense embedding.
+        tie_weights: Boolean, whether or not the matrix for embedding and
+            the matrix for the `reverse` projection should share the same
+            weights.
+        embeddings_initializer: Initializer for the `embeddings`
+            matrix (see `keras.initializers`).
+        embeddings_regularizer: Regularizer function applied to
+            the `embeddings` matrix (see `keras.regularizers`).
+        embeddings_constraint: Constraint function applied to
+            the `embeddings` matrix (see `keras.constraints`).
+        mask_zero: Boolean, whether or not the input value 0 is a special
+            "padding" value that should be masked out.
+        reverse_dtype: The dtype for the reverse projection computation.
+            Defaults to the `compute_dtype` of the layer.
+        logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
+            output logits will be scaled by
+            `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
+            range of output logits and can improve training.
+        **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
+            including `name`, `trainable`, `dtype` etc.
+
+    Call arguments:
+        inputs: The tensor inputs to the layer.
+        reverse: Boolean. If `True` the layer will perform a linear projection
+            from `output_dim` to `input_dim`, instead of a normal embedding
+            call. Defaults to `False`.
+
+    Example:
+    ```python
+    batch_size = 16
+    vocab_size = 100
+    hidden_dim = 32
+    seq_length = 50
+
+    # Generate random inputs.
+    token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
+
+    embedding = keras_hub.layers.ReversibleEmbedding(vocab_size, hidden_dim)
+    # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
+    hidden_states = embedding(token_ids)
+    # Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
+    logits = embedding(hidden_states, reverse=True)
+    ```
+
+    References:
+    - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
+    - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        tie_weights=True,
+        embeddings_initializer="uniform",
+        embeddings_regularizer=None,
+        embeddings_constraint=None,
+        mask_zero=False,
+        reverse_dtype=None,
+        logit_soft_cap=None,
+        **kwargs,
+    ):
+        super().__init__(
+            input_dim,
+            output_dim,
+            embeddings_initializer=embeddings_initializer,
+            embeddings_regularizer=embeddings_regularizer,
+            embeddings_constraint=embeddings_constraint,
+            mask_zero=mask_zero,
+            **kwargs,
+        )
+        self.tie_weights = tie_weights
+        self.reverse_dtype = reverse_dtype
+        self.logit_soft_cap = logit_soft_cap
+
+    def build(self, inputs_shape=None):
+        super().build(inputs_shape)
+        if (
+            not self.tie_weights
+            and getattr(self, "quantization_mode", None) != "int8"
+        ):
+            self.reverse_embeddings = self.add_weight(
+                name="reverse_embeddings",
+                shape=(self.output_dim, self.input_dim),
+                initializer=self.embeddings_initializer,
+                dtype=self.dtype,
+            )
+
+    def call(self, inputs, reverse=False):
+        if reverse:
+            if self.tie_weights:
+                kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
+            else:
+                kernel = self.reverse_embeddings
+            if self.reverse_dtype is not None:
+                inputs = ops.cast(inputs, self.reverse_dtype)
+                kernel = ops.cast(kernel, self.reverse_dtype)
+            logits = ops.matmul(inputs, kernel)
+            # Optionally soft-cap logits.
+            if self.logit_soft_cap is not None:
+                soft_cap = self.logit_soft_cap
+                logits = ops.tanh(logits / soft_cap) * soft_cap
+            return logits
+
+        return super().call(inputs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "tie_weights": self.tie_weights,
+                "reverse_dtype": self.reverse_dtype,
+                "logit_soft_cap": self.logit_soft_cap,
+            }
+        )
+        return config
+
+    def save_own_variables(self, store):
+        if not self.built:
+            return
+        super().save_own_variables(store)
+        # Before Keras 3.2, the reverse weight is saved in the super() call.
+        # After Keras 3.2, the reverse weight must be saved manually.
+        if parse(keras.version()) < parse("3.2.0"):
+            return
+        target_variables = []
+        if not self.tie_weights:
+            # Store the reverse embedding weights as the last weights.
+            target_variables.append(self.reverse_embeddings)
+            if getattr(self, "quantization_mode", None) == "int8":
+                target_variables.append(self.reverse_embeddings_scale)
+            for i, variable in enumerate(target_variables, start=len(store)):
+                store[str(i)] = variable
+
+    def load_own_variables(self, store):
+        if not self.built:
+            self.build()
+        super().load_own_variables(store)
+        if not self.tie_weights:
+            # Last weights in the stores are the reverse embedding weights.
+            target_variables = [self.reverse_embeddings]
+            if getattr(self, "quantization_mode", None) == "int8":
+                target_variables.append(self.reverse_embeddings_scale)
+            for i, variable in enumerate(
+                target_variables, start=len(store) - len(target_variables)
+            ):
+                variable.assign(store[str(i)])
+
+    def compute_output_spec(self, inputs, reverse=False):
+        output_shape = list(inputs.shape)
+        if reverse:
+            output_shape[-1] = self.input_dim
+        else:
+            output_shape += [self.output_dim]
+        return keras.KerasTensor(output_shape, dtype=self.compute_dtype)
+
+    # Quantization-related (int8) methods
+
+    def quantized_call(self, inputs, reverse=False):
+        # TODO (hongyu): This function could be removed once we add `*args` and
+        # `**kwargs` for `Embedding.quantized_call`
+        if self.quantization_mode == "int8":
+            return self._int8_call(inputs, reverse=reverse)
+        else:
+            raise self._quantization_mode_error(self.quantization_mode)
+
+    def _int8_build(
+        self,
+        embeddings_initializer="zeros",
+        embeddings_scale_initializer="ones",
+        reverse_embeddings_initializer="zeros",
+        reverse_embeddings_scale_initializer="ones",
+    ):
+        super()._int8_build(
+            embeddings_initializer, embeddings_scale_initializer
+        )
+        self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
+        if not self.tie_weights:
+            self.reverse_embeddings = self.add_weight(
+                name="reverse_embeddings",
+                shape=(self.output_dim, self.input_dim),
+                initializer=reverse_embeddings_initializer,
+                dtype="int8",
+                trainable=False,
+            )
+            self.reverse_embeddings_scale = self.add_weight(
+                name="reverse_embeddings_scale",
+                shape=(self.input_dim,),
+                initializer=reverse_embeddings_scale_initializer,
+                trainable=False,
+            )
+
+    def _int8_call(self, inputs, reverse=False):
+        if reverse:
+            if self.tie_weights:
+                kernel = ops.transpose(self._embeddings)
+                scale = ops.transpose(self.embeddings_scale)
+            else:
+                kernel = self.reverse_embeddings
+                scale = self.reverse_embeddings_scale
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            logits = ops.matmul(inputs, kernel)
+            # De-scale outputs.
+            logits = ops.cast(logits, self.compute_dtype)
+            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
+            # Optionally soft-cap logits.
+            if self.logit_soft_cap is not None:
+                soft_cap = self.logit_soft_cap
+                logits = ops.tanh(logits / soft_cap) * soft_cap
+            return logits
+
+        return super()._int8_call(inputs)
+
+    def quantize(self, mode, type_check=True):
+        import gc
+        import inspect
+
+        assert_quantization_support()
+        if type_check and type(self) is not ReversibleEmbedding:
+            raise NotImplementedError(
+                f"Layer {self.__class__.__name__} does not have a `quantize()` "
+                "method implemented."
+            )
+        self._check_quantize_args(mode, self.compute_dtype)
+
+        def abs_max_quantize(inputs, axis):
+            sig = inspect.signature(keras.quantizers.abs_max_quantize)
+            if "to_numpy" in sig.parameters:
+                return keras.quantizers.abs_max_quantize(
+                    inputs, axis=axis, to_numpy=True
+                )
+            else:
+                # `keras<=3.4.1` doesn't support `to_numpy`.
+                return keras.quantizers.abs_max_quantize(inputs, axis=axis)
+
+        self._tracker.unlock()
+        if mode == "int8":
+            embeddings, embeddings_scale = abs_max_quantize(
+                self._embeddings, axis=-1
+            )
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+            self._untrack_variable(self._embeddings)
+            del self._embeddings
+            if not self.tie_weights:
+                reverse_embeddings, reverse_embeddings_scale = abs_max_quantize(
+                    self.reverse_embeddings, axis=0
+                )
+                reverse_embeddings_scale = ops.squeeze(
+                    reverse_embeddings_scale, axis=0
+                )
+                self._untrack_variable(self.reverse_embeddings)
+                del self.reverse_embeddings
+            else:
+                reverse_embeddings = None
+                reverse_embeddings_scale = None
+            self._int8_build(
+                lambda shape, dtype: embeddings,
+                lambda shape, dtype: embeddings_scale,
+                lambda shape, dtype: reverse_embeddings,
+                lambda shape, dtype: reverse_embeddings_scale,
+            )
+        else:
+            raise self._quantization_mode_error(mode)
+        self._tracker.lock()
+
+        if self.dtype_policy.quantization_mode is None:
+            policy = keras.dtype_policies.get(
+                f"{mode}_from_{self.dtype_policy.name}"
+            )
+            self.dtype_policy = policy
+        gc.collect()
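To make the weight-tying behavior concrete, here is a minimal sketch (sizes are illustrative, and the `logit_soft_cap` value is chosen arbitrarily):

```python
import numpy as np

import keras_hub

vocab_size, hidden_dim = 100, 32

# Tied weights (the default): the reverse projection is the transpose of
# `embeddings`, so no second weight matrix is created.
embedding = keras_hub.layers.ReversibleEmbedding(
    vocab_size,
    hidden_dim,
    logit_soft_cap=30.0,  # squashes reverse logits into roughly (-30, 30)
)
token_ids = np.random.randint(vocab_size, size=(4, 10))
hidden_states = embedding(token_ids)             # -> (4, 10, 32)
logits = embedding(hidden_states, reverse=True)  # -> (4, 10, 100)

# With `tie_weights=False` the layer builds a separate trainable
# `(output_dim, input_dim)` variable for the reverse projection, roughly
# doubling the layer's parameter count.
untied = keras_hub.layers.ReversibleEmbedding(
    vocab_size, hidden_dim, tie_weights=False
)
```

On a sufficiently recent Keras (per `assert_quantization_support` above), the same layer can then be quantized in place with `embedding.quantize("int8")`, which replaces `embeddings` (and `reverse_embeddings`, if untied) with int8 weights plus per-row scales.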