keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,311 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import keras
18
+
19
+ from keras_hub.src.api_export import keras_hub_export
20
+ from keras_hub.src.utils.keras_utils import assert_quantization_support
21
+ from keras_hub.src.utils.preset_utils import CONFIG_FILE
22
+ from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
23
+ from keras_hub.src.utils.preset_utils import check_config_class
24
+ from keras_hub.src.utils.preset_utils import check_format
25
+ from keras_hub.src.utils.preset_utils import get_file
26
+ from keras_hub.src.utils.preset_utils import jax_memory_cleanup
27
+ from keras_hub.src.utils.preset_utils import list_presets
28
+ from keras_hub.src.utils.preset_utils import list_subclasses
29
+ from keras_hub.src.utils.preset_utils import load_serialized_object
30
+ from keras_hub.src.utils.preset_utils import save_metadata
31
+ from keras_hub.src.utils.preset_utils import save_serialized_object
32
+ from keras_hub.src.utils.python_utils import classproperty
33
+ from keras_hub.src.utils.timm.convert import load_timm_backbone
34
+ from keras_hub.src.utils.transformers.convert import load_transformers_backbone
35
+
36
+
37
@keras_hub_export("keras_hub.models.Backbone")
class Backbone(keras.Model):
    """Base class for all `Backbone` models.

    A `Backbone` is the basic architecture for a given NLP model. Unlike a
    `keras_hub.models.Task`, a `Backbone` is not tailored to any specific loss
    function and training setup. A `Backbone` generally outputs the last hidden
    states of an architecture before any output predictions.

    A `Backbone` can be used in one of two ways:

    1. Through a `Task` class, which will wrap and extend a `Backbone` so it
       can be used with high level Keras functions like `fit()`, `predict()` or
       `evaluate()`. `Task` classes are built with a particular training
       objective in mind (e.g. classification or language modeling).
    2. Directly, by extending underlying functional model with additional
       outputs and training setup. This is the most flexible approach, and can
       allow for any outputs, loss, or custom training loop.

    All backbones include a `from_preset()` constructor which can be used to
    load a pre-trained config and weights.

    Example:
    ```python
    # Load a BERT backbone with pre-trained weights.
    backbone = keras_hub.models.Backbone.from_preset(
        "bert_base_en",
    )
    # Load a GPT2 backbone with pre-trained weights at bfloat16 precision.
    backbone = keras_hub.models.Backbone.from_preset(
        "gpt2_base_en",
        dtype="bfloat16",
        trainable=False,
    )
    ```
    """

    def __init__(self, *args, dtype=None, **kwargs):
        """Initialize the backbone.

        Args:
            *args: Positional arguments forwarded to `keras.Model`.
            dtype: Optional. A dtype, dtype policy name, or
                `keras.DTypePolicy` to apply to the whole model after
                construction. If `None`, the default Keras policy is kept.
            **kwargs: Keyword arguments forwarded to `keras.Model`.
        """
        super().__init__(*args, **kwargs)
        # Snapshot the ids of all layers created by the functional graph, so
        # later code can distinguish graph-owned layers from attributes.
        self._functional_layer_ids = set(
            id(layer) for layer in self._flatten_layers()
        )
        # Sentinel read by `__setattr__` below to detect construction phase.
        self._initialized = True
        if dtype is not None:
            try:
                self.dtype_policy = keras.dtype_policies.get(dtype)
            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
            except AttributeError:
                if isinstance(dtype, keras.DTypePolicy):
                    dtype = dtype.name
                self.dtype_policy = keras.DTypePolicy(dtype)

    def __setattr__(self, name, value):
        # Work around setattr issues for Keras 2 and Keras 3 torch backend.
        # Since all our state is covered by functional model we can route
        # around custom setattr calls.
        is_property = isinstance(getattr(type(self), name, None), property)
        # True until `__init__` sets `_initialized` (note: name kept as-is).
        is_unitialized = not hasattr(self, "_initialized")
        # Only the torch backend needs the raw-object setattr bypass.
        simple_setattr = keras.config.backend() == "torch"
        if simple_setattr and (is_property or is_unitialized):
            return object.__setattr__(self, name, value)
        return super().__setattr__(name, value)

    @property
    def token_embedding(self):
        """A `keras.layers.Embedding` instance for embedding token ids.

        This layer embeds integer token ids to the hidden dim of the model.
        Returns `None` if no embedding has been assigned.
        """
        return getattr(self, "_token_embedding", None)

    @token_embedding.setter
    def token_embedding(self, value):
        # Stored under a private name so the getter can default to `None`.
        self._token_embedding = value

    def quantize(self, mode, **kwargs):
        """Quantize the model, after checking the Keras version supports it.

        Args:
            mode: The quantization mode, forwarded to `keras.Model.quantize`.
            **kwargs: Additional arguments forwarded to the base class.
        """
        # Raises if the installed Keras version lacks quantization support.
        assert_quantization_support()
        return super().quantize(mode, **kwargs)

    def get_config(self):
        """Return a config dict that can recreate this backbone subclass.

        Returns:
            A dict with `name`, `trainable`, and (when any layer is
            quantized) a `dtype` entry holding a `DTypePolicyMap`.
        """
        # Don't chain to super here. `get_config()` for functional models is
        # a nested layer config and cannot be passed to Backbone constructors.
        config = {
            "name": self.name,
            "trainable": self.trainable,
        }

        # Add quantization support by utilizing `DTypePolicyMap`
        try:
            if isinstance(
                self.dtype_policy, keras.dtype_policies.DTypePolicyMap
            ):
                config.update({"dtype": self.dtype_policy})
            else:
                # Collect per-layer policies for any quantized sub-layer so
                # quantization state round-trips through serialization.
                policy_map = keras.dtype_policies.DTypePolicyMap()
                for layer in self._flatten_layers():
                    if layer.quantization_mode is not None:
                        policy_map[layer.path] = layer.dtype_policy
                if len(policy_map) > 0:
                    config.update({"dtype": policy_map})
        # Before Keras 3.2, there is no `keras.dtype_policies.get`.
        except AttributeError:
            pass
        return config

    @classmethod
    def from_config(cls, config):
        # The default `from_config()` for functional models will return a
        # vanilla `keras.Model`. We override it to get a subclass instance back.
        return cls(**config)

    @classproperty
    def presets(cls):
        """List built-in presets for a `Backbone` subclass.

        Aggregates the presets registered on this class and, recursively,
        on all of its subclasses.
        """
        presets = list_presets(cls)
        for subclass in list_subclasses(cls):
            presets.update(subclass.presets)
        return presets

    @classmethod
    def from_preset(
        cls,
        preset,
        load_weights=True,
        **kwargs,
    ):
        """Instantiate a `keras_hub.models.Backbone` from a model preset.

        A preset is a directory of configs, weights and other file assets used
        to save and load a pre-trained model. The `preset` can be passed as
        one of:

        1. a built in preset identifier like `'bert_base_en'`
        2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
        3. a Hugging Face handle like `'hf://user/bert_base_en'`
        4. a path to a local preset directory like `'./bert_base_en'`

        This constructor can be called in one of two ways. Either from the base
        class like `keras_hub.models.Backbone.from_preset()`, or from
        a model class like `keras_hub.models.GemmaBackbone.from_preset()`.
        If calling from the base class, the subclass of the returning object
        will be inferred from the config in the preset directory.

        For any `Backbone` subclass, you can run `cls.presets.keys()` to list
        all built-in presets available on the class.

        Args:
            preset: string. A built in preset identifier, a Kaggle Models
                handle, a Hugging Face handle, or a path to a local directory.
            load_weights: bool. If `True`, the weights will be loaded into the
                model architecture. If `False`, the weights will be randomly
                initialized.

        Examples:
        ```python
        # Load a Gemma backbone with pre-trained weights.
        model = keras_hub.models.Backbone.from_preset(
            "gemma_2b_en",
        )

        # Load a Bert backbone with a pre-trained config and random weights.
        model = keras_hub.models.Backbone.from_preset(
            "bert_base_en",
            load_weights=False,
        )
        ```
        """
        # Non-Keras preset formats are delegated to dedicated converters.
        format = check_format(preset)

        if format == "transformers":
            return load_transformers_backbone(cls, preset, load_weights)
        elif format == "timm":
            return load_timm_backbone(cls, preset, load_weights, **kwargs)

        preset_cls = check_config_class(preset)
        if not issubclass(preset_cls, cls):
            # NOTE(review): message repeats the article ("not a a subclass");
            # kept byte-identical here since it is runtime output.
            raise ValueError(
                f"Preset has type `{preset_cls.__name__}` which is not a "
                f"a subclass of calling class `{cls.__name__}`. Call "
                f"`from_preset` directly on `{preset_cls.__name__}` instead."
            )

        backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
        if load_weights:
            # Free stale device buffers on JAX before loading new weights.
            jax_memory_cleanup(backbone)
            backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))

        return backbone

    def save_to_preset(self, preset_dir):
        """Save backbone to a preset directory.

        Args:
            preset_dir: The path to the local model preset directory.
        """
        save_serialized_object(self, preset_dir, config_file=CONFIG_FILE)
        self.save_weights(os.path.join(preset_dir, MODEL_WEIGHTS_FILE))
        save_metadata(self, preset_dir)

    def enable_lora(self, rank):
        """Enable Lora on the backbone.

        Calling this method will freeze all weights on the backbone,
        while enabling Lora on the query & value `EinsumDense` layers
        of the attention layers.

        Args:
            rank: int. The rank of the LoRA factorization.
        """
        target_names = ["query_dense", "value_dense", "query", "value"]
        self.trainable = True
        self._lora_enabled_layers = []
        self._lora_rank = rank
        # Freeze everything first; matched layers are re-enabled below.
        for layer in self._flatten_layers(include_self=False):
            layer.trainable = False
        # Only layers that own weights participate in the stable indexing
        # scheme used by `save_lora_weights`/`load_lora_weights`.
        all_layers = self._flatten_layers(include_self=False)
        all_layers = [lyr for lyr in all_layers if lyr.weights]
        for i, layer in enumerate(all_layers):
            for name in target_names:
                if layer.name == name:
                    # Only layers exposing `enable_lora` (e.g. EinsumDense)
                    # are actually enabled and recorded.
                    if hasattr(layer, "enable_lora"):
                        layer.trainable = True
                        layer.enable_lora(rank)
                        self._lora_enabled_layers.append(i)

    def save_lora_weights(self, filepath):
        """Save only the LoRA factor weights to an `.lora.h5` file.

        Args:
            filepath: Path ending in `.lora.h5` to write the weights to.

        Raises:
            ValueError: If `enable_lora()` was never called, or if the
                filename does not end in `.lora.h5`.
        """
        if not getattr(self, "_lora_enabled_layers", []):
            raise ValueError(
                "There are no lora-enabled layers in this model. "
                "Make sure to call `.enable_lora(rank)` first."
            )
        if not str(filepath).endswith(".lora.h5"):
            raise ValueError(
                "The filename must end in `.lora.h5`. "
                f"Received: filepath={filepath}"
            )

        # NOTE(review): `keras.src.saving` is a private Keras namespace; this
        # can break across Keras releases — confirm against pinned version.
        store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="w")
        lora_store = store.make("lora")
        lora_store["rank"] = self._lora_rank
        # We cannot identify layers by name since names are non-unique,
        # so we identify them by index in the topologically sorted list
        # of layers that have weights.
        all_layers = self._flatten_layers(include_self=False)
        all_layers = [lyr for lyr in all_layers if lyr.weights]
        for layer_index in self._lora_enabled_layers:
            # We only lora the einsumdense layers,
            # so the factored weights are always named `kernel`
            layer = all_layers[layer_index]
            inner_store = store.make(f"lora/{layer_index}")
            inner_store["lora_kernel_a"] = layer.lora_kernel_a
            inner_store["lora_kernel_b"] = layer.lora_kernel_b
        store.close()

    def load_lora_weights(self, filepath):
        """Load LoRA factor weights saved by `save_lora_weights`.

        If LoRA is not yet enabled on this model, it is enabled with the
        rank stored in the file; otherwise the stored rank must match the
        rank previously passed to `enable_lora()`.

        Args:
            filepath: Path to an `.lora.h5` file written by
                `save_lora_weights`.

        Raises:
            ValueError: If the model's LoRA rank does not match the file's.
        """
        store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="r")
        lora_store = store.get("lora")
        rank = int(lora_store["rank"][()])

        if not getattr(self, "_lora_enabled_layers", []):
            self.enable_lora(rank)
        else:
            if self._lora_rank != rank:
                raise ValueError(
                    f"The Lora rank expected by file '{filepath}' "
                    f"is rank={rank}, but the model was called with "
                    f"`.enable_lora(rank={self._lora_rank})`. "
                    "Both ranks must match."
                )
        # Re-derive the same weighted-layer ordering used at save time, so
        # indices in the file line up with layers here.
        all_layers = self._flatten_layers(include_self=False)
        all_layers = [lyr for lyr in all_layers if lyr.weights]
        for layer_index in self._lora_enabled_layers:
            layer = all_layers[layer_index]
            lora_kernel_a = store.get(f"lora/{layer_index}")["lora_kernel_a"]
            lora_kernel_b = store.get(f"lora/{layer_index}")["lora_kernel_b"]
            layer.lora_kernel_a.assign(lora_kernel_a)
            layer.lora_kernel_b.assign(lora_kernel_b)
        store.close()
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
from keras_hub.src.models.bart.bart_backbone import BartBackbone
from keras_hub.src.models.bart.bart_presets import backbone_presets
from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
from keras_hub.src.utils.preset_utils import register_presets

# Register the built-in BART presets against both the backbone and the
# tokenizer classes, so preset identifiers resolve for either class.
register_presets(backbone_presets, (BartBackbone, BartTokenizer))
@@ -0,0 +1,261 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+
17
+ from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
19
+ from keras_hub.src.layers.modeling.reversible_embedding import (
20
+ ReversibleEmbedding,
21
+ )
22
+ from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
23
+ from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
24
+ from keras_hub.src.models.backbone import Backbone
25
+
26
+
27
def bart_kernel_initializer(stddev=0.02):
    """Return the truncated-normal initializer used for BART kernels."""
    initializer = keras.initializers.TruncatedNormal(stddev=stddev)
    return initializer
29
+
30
+
31
@keras_hub_export("keras_hub.models.BartBackbone")
class BartBackbone(Backbone):
    """BART encoder-decoder network.

    This class implements a Transformer-based encoder-decoder model as
    described in
    ["BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension"](https://arxiv.org/abs/1910.13461).

    The default constructor gives a fully customizable, randomly initialized BART
    model with any number of layers, heads, and embedding dimensions. To load
    preset architectures and weights, use the `from_preset` constructor.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by a
    third party and subject to a separate license, available
    [here](https://github.com/facebookresearch/fairseq/).

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer encoder layers and
            transformer decoder layers.
        num_heads: int. The number of attention heads for each transformer.
            The hidden size must be divisible by the number of attention heads.
        hidden_dim: int. The size of the transformer encoding and pooler layers.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        dropout: float. Dropout probability for the Transformer encoder.
        max_sequence_length: int. The maximum sequence length that this encoder
            can consume. If None, `max_sequence_length` uses the value from
            sequence length. This determines the variable shape for positional
            embeddings.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.

    Examples:
    ```python
    input_data = {
        "encoder_token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "encoder_padding_mask": np.array(
            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]
        ),
        "decoder_token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "decoder_padding_mask": np.array(
            [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]
        ),
    }

    # Pretrained BART encoder.
    model = keras_hub.models.BartBackbone.from_preset("bart_base_en")
    model(input_data)

    # Randomly initialized BART encoder-decoder model with a custom config
    model = keras_hub.models.BartBackbone(
        vocabulary_size=50265,
        num_layers=6,
        num_heads=12,
        hidden_dim=768,
        intermediate_dim=3072,
        max_sequence_length=12,
    )
    output = model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        dropout=0.1,
        max_sequence_length=1024,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        # One token embedding is created and reused for both the encoder and
        # the decoder inputs (see the functional wiring below).
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=bart_kernel_initializer(),
            dtype=dtype,
            name="token_embedding",
        )
        # Encoder and decoder each get their own learned position embedding.
        self.encoder_position_embedding = PositionEmbedding(
            initializer=bart_kernel_initializer(),
            sequence_length=max_sequence_length,
            dtype=dtype,
            name="encoder_position_embedding",
        )
        self.encoder_embeddings_add = keras.layers.Add(
            dtype=dtype,
            name="encoder_embeddings_add",
        )
        self.encoder_embeddings_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            dtype=dtype,
            name="encoder_embeddings_layer_norm",
        )
        self.encoder_embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="encoder_embeddings_dropout",
        )
        self.encoder_transformer_layers = []
        for i in range(num_layers):
            layer = TransformerEncoder(
                num_heads=num_heads,
                intermediate_dim=intermediate_dim,
                activation=keras.activations.gelu,
                dropout=dropout,
                layer_norm_epsilon=1e-5,
                kernel_initializer=bart_kernel_initializer(),
                dtype=dtype,
                name=f"transformer_encoder_layer_{i}",
            )
            self.encoder_transformer_layers.append(layer)
        self.decoder_position_embedding = PositionEmbedding(
            initializer=bart_kernel_initializer(),
            sequence_length=max_sequence_length,
            dtype=dtype,
            name="decoder_position_embedding",
        )
        self.decoder_embeddings_add = keras.layers.Add(
            dtype=dtype,
            name="decoder_embeddings_add",
        )
        self.decoder_embeddings_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            dtype=dtype,
            name="decoder_embeddings_layer_norm",
        )
        self.decoder_embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="decoder_embeddings_dropout",
        )
        self.decoder_transformer_layers = []
        for i in range(num_layers):
            layer = TransformerDecoder(
                intermediate_dim=intermediate_dim,
                num_heads=num_heads,
                dropout=dropout,
                activation=keras.activations.gelu,
                layer_norm_epsilon=1e-5,
                kernel_initializer=bart_kernel_initializer(),
                dtype=dtype,
                name=f"transformer_decoder_layer_{i}",
            )
            self.decoder_transformer_layers.append(layer)

        # === Functional Model ===
        encoder_token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="encoder_token_ids"
        )
        encoder_padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="encoder_padding_mask"
        )
        decoder_token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="decoder_token_ids"
        )
        decoder_padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="decoder_padding_mask"
        )
        # Encoder: token + position embeddings, layer norm, dropout, then the
        # stack of self-attention encoder layers.
        tokens = self.token_embedding(encoder_token_id_input)
        positions = self.encoder_position_embedding(tokens)
        x = self.encoder_embeddings_add((tokens, positions))
        x = self.encoder_embeddings_layer_norm(x)
        x = self.encoder_embeddings_dropout(x)
        for transformer_layer in self.encoder_transformer_layers:
            x = transformer_layer(x, padding_mask=encoder_padding_mask_input)
        encoder_output = x
        # Decoder: same embedding pipeline (sharing `token_embedding`), then
        # decoder layers that cross-attend to `encoder_output`.
        tokens = self.token_embedding(decoder_token_id_input)
        positions = self.decoder_position_embedding(tokens)
        x = self.decoder_embeddings_add((tokens, positions))
        x = self.decoder_embeddings_layer_norm(x)
        x = self.decoder_embeddings_dropout(x)
        for transformer_layer in self.decoder_transformer_layers:
            x = transformer_layer(
                decoder_sequence=x,
                encoder_sequence=encoder_output,
                decoder_padding_mask=decoder_padding_mask_input,
                encoder_padding_mask=encoder_padding_mask_input,
            )
        decoder_output = x
        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs={
                "encoder_token_ids": encoder_token_id_input,
                "encoder_padding_mask": encoder_padding_mask_input,
                "decoder_token_ids": decoder_token_id_input,
                "decoder_padding_mask": decoder_padding_mask_input,
            },
            outputs={
                "encoder_sequence_output": encoder_output,
                "decoder_sequence_output": decoder_output,
            },
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        # Stored so `get_config()` can round-trip the constructor arguments.
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length

    def get_config(self):
        # Serialize every constructor argument on top of the base config so
        # the backbone can be re-created via `from_config`.
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "max_sequence_length": self.max_sequence_length,
            }
        )

        return config