keras-hub-nightly 0.15.0.dev20240823171555-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/bert/bert_presets.py
@@ -0,0 +1,147 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """BERT model preset configurations."""
+
+ backbone_presets = {
+     "bert_tiny_en_uncased": {
+         "metadata": {
+             "description": (
+                 "2-layer BERT model where all input is lowercased. "
+                 "Trained on English Wikipedia + BooksCorpus."
+             ),
+             "params": 4385920,
+             "official_name": "BERT",
+             "path": "bert",
+             "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+         },
+         "kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased/2",
+     },
+     "bert_small_en_uncased": {
+         "metadata": {
+             "description": (
+                 "4-layer BERT model where all input is lowercased. "
+                 "Trained on English Wikipedia + BooksCorpus."
+             ),
+             "params": 28763648,
+             "official_name": "BERT",
+             "path": "bert",
+             "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+         },
+         "kaggle_handle": "kaggle://keras/bert/keras/bert_small_en_uncased/2",
+     },
+     "bert_medium_en_uncased": {
+         "metadata": {
+             "description": (
+                 "8-layer BERT model where all input is lowercased. "
+                 "Trained on English Wikipedia + BooksCorpus."
+             ),
+             "params": 41373184,
+             "official_name": "BERT",
+             "path": "bert",
+             "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+         },
+         "kaggle_handle": "kaggle://keras/bert/keras/bert_medium_en_uncased/2",
+     },
+     "bert_base_en_uncased": {
+         "metadata": {
+             "description": (
+                 "12-layer BERT model where all input is lowercased. "
+                 "Trained on English Wikipedia + BooksCorpus."
+             ),
+             "params": 109482240,
+             "official_name": "BERT",
+             "path": "bert",
+             "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+         },
+         "kaggle_handle": "kaggle://keras/bert/keras/bert_base_en_uncased/2",
+     },
+     "bert_base_en": {
+         "metadata": {
+             "description": (
+                 "12-layer BERT model where case is maintained. "
+                 "Trained on English Wikipedia + BooksCorpus."
+             ),
+             "params": 108310272,
+             "official_name": "BERT",
+             "path": "bert",
+             "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+         },
+         "kaggle_handle": "kaggle://keras/bert/keras/bert_base_en/2",
+     },
+     "bert_base_zh": {
+         "metadata": {
+             "description": (
+                 "12-layer BERT model. Trained on Chinese Wikipedia."
+             ),
+             "params": 102267648,
+             "official_name": "BERT",
+             "path": "bert",
+             "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+         },
+         "kaggle_handle": "kaggle://keras/bert/keras/bert_base_zh/2",
+     },
+     "bert_base_multi": {
+         "metadata": {
+             "description": (
+                 "12-layer BERT model where case is maintained. Trained on Wikipedias of 104 languages."
+             ),
+             "params": 177853440,
+             "official_name": "BERT",
+             "path": "bert",
+             "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+         },
+         "kaggle_handle": "kaggle://keras/bert/keras/bert_base_multi/2",
+     },
+     "bert_large_en_uncased": {
+         "metadata": {
+             "description": (
+                 "24-layer BERT model where all input is lowercased. "
+                 "Trained on English Wikipedia + BooksCorpus."
+             ),
+             "params": 335141888,
+             "official_name": "BERT",
+             "path": "bert",
+             "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+         },
+         "kaggle_handle": "kaggle://keras/bert/keras/bert_large_en_uncased/2",
+     },
+     "bert_large_en": {
+         "metadata": {
+             "description": (
+                 "24-layer BERT model where case is maintained. "
+                 "Trained on English Wikipedia + BooksCorpus."
+             ),
+             "params": 333579264,
+             "official_name": "BERT",
+             "path": "bert",
+             "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+         },
+         "kaggle_handle": "kaggle://keras/bert/keras/bert_large_en/2",
+     },
+ }
+
+ classifier_presets = {
+     "bert_tiny_en_uncased_sst2": {
+         "metadata": {
+             "description": (
+                 "The bert_tiny_en_uncased backbone model fine-tuned on the SST-2 sentiment analysis dataset."
+             ),
+             "params": 4385920,
+             "official_name": "BERT",
+             "path": "bert",
+             "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+         },
+         "kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/4",
+     }
+ }
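For orientation (not part of the diff itself): a preset key from `backbone_presets` or `classifier_presets` is the string handed to `from_preset()`. A minimal sketch, assuming the public `keras_hub.models` API exported elsewhere in this wheel:

```python
import keras_hub

# Resolves the key against the registered presets and downloads the
# weights from the "kaggle_handle" shown in the table above.
backbone = keras_hub.models.BertBackbone.from_preset("bert_tiny_en_uncased")

# Task presets load the same way through the task class.
classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased_sst2"
)
```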
keras_hub/src/models/bert/bert_tokenizer.py
@@ -0,0 +1,112 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer
+
+
+ @keras_hub_export("keras_hub.models.BertTokenizer")
+ class BertTokenizer(WordPieceTokenizer):
+     """A BERT tokenizer using WordPiece subword segmentation.
+
+     This tokenizer class will tokenize raw strings into integer sequences and
+     is based on `keras_hub.tokenizers.WordPieceTokenizer`. Unlike the
+     underlying tokenizer, it checks for all special tokens needed by BERT
+     models and provides a `from_preset()` method to automatically download
+     a matching vocabulary for a BERT preset.
+
+     This tokenizer does not provide truncation or padding of inputs. It can be
+     combined with a `keras_hub.models.BertPreprocessor` layer for input packing.
+
+     If input is a batch of strings (rank > 0), the layer will output a
+     `tf.RaggedTensor` where the last dimension of the output is ragged.
+
+     If input is a scalar string (rank == 0), the layer will output a dense
+     `tf.Tensor` with static shape `[None]`.
+
+     Args:
+         vocabulary: A list of strings or a string filename path. If
+             passing a list, each element of the list should be a single word
+             piece token string. If passing a filename, the file should be a
+             plain text file containing a single word piece token per line.
+         lowercase: If `True`, the input text will first be lowercased before
+             tokenization.
+         special_tokens_in_strings: bool. Whether the tokenizer should expect
+             special tokens in input strings, which should be tokenized and
+             mapped correctly to their ids. Defaults to `False`.
+
+     Examples:
+     ```python
+     # Unbatched input.
+     tokenizer = keras_hub.models.BertTokenizer.from_preset(
+         "bert_base_en_uncased",
+     )
+     tokenizer("The quick brown fox jumped.")
+
+     # Batched input.
+     tokenizer(["The quick brown fox jumped.", "The fox slept."])
+
+     # Detokenization.
+     tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+
+     # Custom vocabulary.
+     vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+     vocab += ["The", "quick", "brown", "fox", "jumped", "."]
+     tokenizer = keras_hub.models.BertTokenizer(vocabulary=vocab)
+     tokenizer("The quick brown fox jumped.")
+     ```
+     """
+
+     def __init__(
+         self,
+         vocabulary=None,
+         lowercase=False,
+         special_tokens_in_strings=False,
+         **kwargs,
+     ):
+         self.cls_token = "[CLS]"
+         self.sep_token = "[SEP]"
+         self.pad_token = "[PAD]"
+         self.mask_token = "[MASK]"
+         super().__init__(
+             vocabulary=vocabulary,
+             lowercase=lowercase,
+             special_tokens=[
+                 self.cls_token,
+                 self.sep_token,
+                 self.pad_token,
+                 self.mask_token,
+             ],
+             special_tokens_in_strings=special_tokens_in_strings,
+             **kwargs,
+         )
+
+     def set_vocabulary(self, vocabulary):
+         super().set_vocabulary(vocabulary)
+
+         if vocabulary is not None:
+             self.cls_token_id = self.token_to_id(self.cls_token)
+             self.sep_token_id = self.token_to_id(self.sep_token)
+             self.pad_token_id = self.token_to_id(self.pad_token)
+             self.mask_token_id = self.token_to_id(self.mask_token)
+         else:
+             self.cls_token_id = None
+             self.sep_token_id = None
+             self.pad_token_id = None
+             self.mask_token_id = None
+
+     def get_config(self):
+         config = super().get_config()
+         del config["special_tokens"]  # Not configurable; set in __init__.
+         return config
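To make the `set_vocabulary()` bookkeeping above concrete, a small sketch with a toy vocabulary (ids follow list order; nothing here is preset-specific):

```python
import keras_hub

# Special tokens first, so their ids are stable and easy to check.
vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "the", "fox", "."]
tokenizer = keras_hub.models.BertTokenizer(vocabulary=vocab)

# set_vocabulary() has resolved the special token ids against the list.
assert tokenizer.cls_token_id == 1
assert tokenizer.sep_token_id == 2
assert tokenizer.pad_token_id == 3
```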
keras_hub/src/models/bloom/__init__.py
@@ -0,0 +1,20 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone
+ from keras_hub.src.models.bloom.bloom_presets import backbone_presets
+ from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer
+ from keras_hub.src.utils.preset_utils import register_presets
+
+ register_presets(backbone_presets, (BloomBackbone, BloomTokenizer))
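The `register_presets` call is what makes the BLOOM preset names resolvable from the model classes. A sketch of the effect, assuming the public API and the `bloom_560m_multi` preset named in the backbone docstring below:

```python
import keras_hub

# Importing keras_hub triggers the registration above, so the preset
# names are listed on the class and usable with from_preset().
print(keras_hub.models.BloomBackbone.presets.keys())
backbone = keras_hub.models.BloomBackbone.from_preset("bloom_560m_multi")
```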
keras_hub/src/models/bloom/bloom_attention.py
@@ -0,0 +1,186 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import math
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.alibi_bias import AlibiBias
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ class BloomAttention(keras.layers.Layer):
+     def __init__(
+         self,
+         num_heads,
+         dropout=0.0,
+         kernel_initializer="glorot_uniform",
+         bias_initializer="zeros",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+         self.bias_initializer = keras.initializers.get(bias_initializer)
+
+     def build(self, inputs_shape):
+         batch_size, seq_length, hidden_dim = inputs_shape
+
+         self.head_dim = hidden_dim // self.num_heads
+
+         # Layer-wise attention scaling.
+         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+
+         self._query_dense = keras.layers.EinsumDense(
+             equation="btm,mnh->btnh",
+             output_shape=(None, self.num_heads, self.head_dim),
+             bias_axes="nh",
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="query_dense",
+         )
+         self._query_dense.build(inputs_shape)
+
+         self._key_dense = keras.layers.EinsumDense(
+             equation="bsm,mnh->bsnh",
+             output_shape=(None, self.num_heads, self.head_dim),
+             bias_axes="nh",
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="key_dense",
+         )
+         self._key_dense.build(inputs_shape)
+
+         self._value_dense = keras.layers.EinsumDense(
+             equation="bsm,mnh->bsnh",
+             output_shape=(None, self.num_heads, self.head_dim),
+             bias_axes="nh",
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="value_dense",
+         )
+         self._value_dense.build(inputs_shape)
+
+         self._alibi_layer = AlibiBias(
+             dtype=self.dtype_policy,
+         )
+
+         self._output_dense = keras.layers.Dense(
+             hidden_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="output_dense",
+         )
+         self._output_dense.build(inputs_shape)
+
+         self._dropout_layer = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="dropout",
+         )
+         self._softmax = keras.layers.Softmax(
+             dtype="float32",
+             name="softmax",
+         )
+
+         self.built = True
+
+     def call(
+         self,
+         hidden_states,
+         attention_mask=None,
+         cache=None,
+         cache_update_index=None,
+     ):
+         batch_size, seq_length, hidden_dim = ops.shape(hidden_states)
+
+         query = self._query_dense(hidden_states)
+         key = self._key_dense(hidden_states)
+         value = self._value_dense(hidden_states)
+
+         if cache is not None:
+             key_cache = cache[:, 0, ...]
+             value_cache = cache[:, 1, ...]
+             if cache_update_index is None:
+                 key = key_cache
+                 value = value_cache
+             else:
+                 start = [0, cache_update_index, 0, 0]
+                 key = ops.slice_update(key_cache, start, key)
+                 value = ops.slice_update(value_cache, start, value)
+                 cache = ops.stack((key, value), axis=1)
+         else:
+             if cache_update_index is not None:
+                 raise ValueError(
+                     "`cache_update_index` should not be set if `cache` is "
+                     f"`None`. Received: cache={cache}, "
+                     f"cache_update_index={cache_update_index}"
+                 )
+
+         # query (batch_size, num_heads, query_length, head_dim)
+         query = ops.transpose(query, [0, 2, 1, 3])
+         # value (batch_size, num_heads, kv_length, head_dim)
+         value = ops.transpose(value, [0, 2, 1, 3])
+         # key (batch_size, num_heads, head_dim, kv_length)
+         key = ops.transpose(key, [0, 2, 3, 1])
+
+         attention_scores = (
+             ops.matmul(query, key) * self.inv_norm_factor
+         )  # [batch_size, num_heads, query_length, kv_length]
+         attention_scores = self._alibi_layer(attention_scores)
+         attention_scores = self._softmax(
+             attention_scores, ops.expand_dims(attention_mask, 1)
+         )
+         attention_scores = self._dropout_layer(attention_scores)
+
+         attention_output = ops.matmul(
+             attention_scores, value
+         )  # [batch_size, num_heads, query_length, head_dim]
+
+         attention_output = ops.transpose(
+             attention_output, [0, 2, 1, 3]
+         )  # [batch_size, query_length, num_heads, head_dim]
+         attention_output = ops.reshape(
+             attention_output,
+             [batch_size, seq_length, self.num_heads * self.head_dim],
+         )  # [batch_size, query_length, hidden_dim]
+
+         attention_output = self._output_dense(attention_output)
+         attention_output = self._dropout_layer(attention_output)
+
+         if cache is not None:
+             return attention_output, cache
+
+         return attention_output
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "dropout": self.dropout,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self.kernel_initializer
+                 ),
+                 "bias_initializer": keras.initializers.serialize(
+                     self.bias_initializer
+                 ),
+             }
+         )
+         return config
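The cache handling above packs keys and values into one tensor along axis 1 and writes new positions in place with `ops.slice_update`. A standalone sketch of that layout (shapes are illustrative, not taken from the diff):

```python
from keras import ops

# Illustrative sizes: 2 sequences, 8 cached positions, 4 heads of width 16.
batch, max_len, num_heads, head_dim = 2, 8, 4, 16

# cache[:, 0] holds keys, cache[:, 1] holds values, as in BloomAttention.
cache = ops.zeros((batch, 2, max_len, num_heads, head_dim))
key_cache = cache[:, 0, ...]  # [batch, max_len, num_heads, head_dim]

# One new key for decode step 3, written in place at that position.
new_key = ops.ones((batch, 1, num_heads, head_dim))
key_cache = ops.slice_update(key_cache, [0, 3, 0, 0], new_key)
```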
keras_hub/src/models/bloom/bloom_backbone.py
@@ -0,0 +1,173 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.modeling.reversible_embedding import (
+     ReversibleEmbedding,
+ )
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.bloom.bloom_decoder import BloomDecoder
+
+
+ def _bloom_kernel_initializer(stddev=0.02):
+     return keras.initializers.RandomNormal(stddev=stddev)
+
+
+ @keras_hub_export("keras_hub.models.BloomBackbone")
+ class BloomBackbone(Backbone):
+     """A BLOOM decoder network.
+
+     This network implements a Transformer-based decoder network, the
+     BigScience Large Open-science Open-access Multilingual Language Model
+     (BLOOM), as described in
+     ["BLOOM: A 176B-Parameter Open-Access Multilingual Language Model"](https://arxiv.org/pdf/2211.05100.pdf).
+
+     The default constructor gives a fully customizable, randomly initialized
+     BLOOM model with any number of layers, heads, and embedding dimensions. To
+     load preset architectures and weights, use the `from_preset()` constructor.
+
+     Disclaimer: Pre-trained models are provided on an "as is" basis, without
+     warranties or conditions of any kind. The underlying model is provided by a
+     third party and subject to a separate license, available [here](https://huggingface.co/spaces/bigscience/license).
+
+     Args:
+         vocabulary_size: int. The size of the token vocabulary.
+         num_layers: int. The number of transformer layers.
+         num_heads: int. The number of attention heads for each transformer.
+             The hidden size must be divisible by the number of attention heads.
+         hidden_dim: int. The dimensionality of the embeddings and hidden states.
+         intermediate_dim: int. The output dimension of the first Dense layer in
+             the MLP network of each transformer.
+         dropout: float. Dropout probability for the Transformer decoder.
+         layer_norm_epsilon: float. Epsilon for the layer normalization layers in
+             the transformer decoder.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for model computations and weights. Note that some computations,
+             such as softmax and layer normalization, will always be done at
+             float32 precision regardless of dtype.
+
+     Example:
+     ```python
+     input_data = {
+         "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+         "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+     }
+
+     # Pretrained BLOOM decoder.
+     model = keras_hub.models.BloomBackbone.from_preset("bloom_560m_multi")
+     model(input_data)
+
+     # Randomly initialized BLOOM decoder with a custom config.
+     model = keras_hub.models.BloomBackbone(
+         vocabulary_size=10,
+         num_layers=2,
+         num_heads=2,
+         hidden_dim=32,
+         intermediate_dim=32*4,
+         dropout=0.0,
+         layer_norm_epsilon=1e-5,
+     )
+     model(input_data)
+     ```
+     """
+
+     def __init__(
+         self,
+         vocabulary_size,
+         num_layers,
+         num_heads,
+         hidden_dim,
+         intermediate_dim,
+         dropout=0.0,
+         layer_norm_epsilon=1e-5,
+         dtype=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.token_embedding = ReversibleEmbedding(
+             input_dim=vocabulary_size,
+             output_dim=hidden_dim,
+             embeddings_initializer=_bloom_kernel_initializer(stddev=0.02),
+             dtype=dtype,
+             name="token_embedding",
+         )
+         self.embeddings_layer_norm = keras.layers.LayerNormalization(
+             epsilon=layer_norm_epsilon,
+             dtype=dtype,
+             name="embedding_layernorm",
+         )
+         self.transformer_layers = []
+         for i in range(num_layers):
+             layer = BloomDecoder(
+                 num_heads=num_heads,
+                 intermediate_dim=intermediate_dim,
+                 dropout=dropout,
+                 layer_norm_epsilon=layer_norm_epsilon,
+                 dtype=dtype,
+                 name=f"transformer_layer_{i}",
+             )
+             self.transformer_layers.append(layer)
+         self.layer_norm = keras.layers.LayerNormalization(
+             epsilon=layer_norm_epsilon,
+             dtype=dtype,
+             name="final_layernorm",
+         )
+
+         # === Functional Model ===
+         token_id_input = keras.Input(
+             shape=(None,), dtype="int32", name="token_ids"
+         )
+         padding_mask_input = keras.Input(
+             shape=(None,), dtype="int32", name="padding_mask"
+         )
+         x = self.token_embedding(token_id_input)
+         x = self.embeddings_layer_norm(x)
+         for transformer_layer in self.transformer_layers:
+             x = transformer_layer(x, decoder_padding_mask=padding_mask_input)
+         sequence_output = self.layer_norm(x)
+         super().__init__(
+             inputs={
+                 "token_ids": token_id_input,
+                 "padding_mask": padding_mask_input,
+             },
+             outputs=sequence_output,
+             dtype=dtype,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.vocabulary_size = vocabulary_size
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.dropout = dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary_size": self.vocabulary_size,
+                 "num_layers": self.num_layers,
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "dropout": self.dropout,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
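To show the `get_config()` round trip above in use, a minimal sketch (sizes are illustrative; it assumes only the constructor and config keys defined in this file):

```python
import keras_hub

model = keras_hub.models.BloomBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=128,
)

# get_config() captures exactly the architecture arguments stored in
# __init__, so from_config() rebuilds an equivalent (untrained) model.
restored = keras_hub.models.BloomBackbone.from_config(model.get_config())
assert restored.hidden_dim == model.hidden_dim
```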