keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in their public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/utils/timm/convert_efficientnet.py

@@ -0,0 +1,447 @@
+import math
+
+import numpy as np
+
+from keras_hub.src.models.efficientnet.efficientnet_backbone import (
+    EfficientNetBackbone,
+)
+
+backbone_cls = EfficientNetBackbone
+
+
+VARIANT_MAP = {
+    "b0": {
+        "stackwise_width_coefficients": [1.0] * 7,
+        "stackwise_depth_coefficients": [1.0] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b1": {
+        "stackwise_width_coefficients": [1.0] * 7,
+        "stackwise_depth_coefficients": [1.1] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b2": {
+        "stackwise_width_coefficients": [1.1] * 7,
+        "stackwise_depth_coefficients": [1.2] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b3": {
+        "stackwise_width_coefficients": [1.2] * 7,
+        "stackwise_depth_coefficients": [1.4] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b4": {
+        "stackwise_width_coefficients": [1.4] * 7,
+        "stackwise_depth_coefficients": [1.8] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b5": {
+        "stackwise_width_coefficients": [1.6] * 7,
+        "stackwise_depth_coefficients": [2.2] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "lite0": {
+        "stackwise_width_coefficients": [1.0] * 7,
+        "stackwise_depth_coefficients": [1.0] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0] * 7,
+        "activation": "relu6",
+    },
+    "el": {
+        "stackwise_width_coefficients": [1.2] * 6,
+        "stackwise_depth_coefficients": [1.4] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "em": {
+        "stackwise_width_coefficients": [1.0] * 6,
+        "stackwise_depth_coefficients": [1.1] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "es": {
+        "stackwise_width_coefficients": [1.0] * 6,
+        "stackwise_depth_coefficients": [1.0] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "rw_m": {
+        "stackwise_width_coefficients": [1.2] * 6,
+        "stackwise_depth_coefficients": [1.2] * 4 + [1.6] * 2,
+        "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
+        "stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
+        "stackwise_input_filters": [24, 24, 48, 64, 128, 160],
+        "stackwise_output_filters": [24, 48, 64, 128, 160, 272],
+        "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [False] * 6,
+        "activation": "silu",
+        "num_features": 1792,
+    },
+    "rw_s": {
+        "stackwise_width_coefficients": [1.0] * 6,
+        "stackwise_depth_coefficients": [1.0] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
+        "stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
+        "stackwise_input_filters": [24, 24, 48, 64, 128, 160],
+        "stackwise_output_filters": [24, 48, 64, 128, 160, 272],
+        "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [False] * 6,
+        "activation": "silu",
+        "num_features": 1792,
+    },
+    "rw_t": {
+        "stackwise_width_coefficients": [0.8] * 6,
+        "stackwise_depth_coefficients": [0.9] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
+        "stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
+        "stackwise_input_filters": [24, 24, 48, 64, 128, 160],
+        "stackwise_output_filters": [24, 48, 64, 128, 160, 256],
+        "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
+        "stackwise_block_types": ["cba"] + ["fused"] * 2 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [False] * 6,
+        "activation": "silu",
+    },
+}
+
+
+def convert_backbone_config(timm_config):
+    timm_architecture = timm_config["architecture"]
+
+    base_kwargs = {
+        "stackwise_kernel_sizes": [3, 3, 5, 3, 5, 5, 3],
+        "stackwise_num_repeats": [1, 2, 2, 3, 3, 4, 1],
+        "stackwise_input_filters": [32, 16, 24, 40, 80, 112, 192],
+        "stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320],
+        "stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2, 1],
+        "stackwise_block_types": ["v1"] * 7,
+        "min_depth": None,
+        "include_stem_padding": True,
+        "use_depth_divisor_as_min_depth": True,
+        "cap_round_filter_decrease": True,
+        "stem_conv_padding": "valid",
+        "batch_norm_momentum": 0.9,
+        "batch_norm_epsilon": 1e-5,
+        "dropout": 0,
+        "projection_activation": None,
+    }
+
+    variant = "_".join(timm_architecture.split("_")[1:])
+
+    if variant not in VARIANT_MAP:
+        raise ValueError(
+            f"Currently, the architecture {timm_architecture} is not supported."
+        )
+
+    base_kwargs.update(VARIANT_MAP[variant])
+
+    return base_kwargs
+
+
+def convert_weights(backbone, loader, timm_config):
+    timm_architecture = timm_config["architecture"]
+    variant = "_".join(timm_architecture.split("_")[1:])
+
+    def port_conv2d(keras_layer, hf_weight_prefix, port_bias=True):
+        loader.port_weight(
+            keras_layer.kernel,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+        )
+
+        if port_bias:
+            loader.port_weight(
+                keras_layer.bias,
+                hf_weight_key=f"{hf_weight_prefix}.bias",
+            )
+
+    def port_depthwise_conv2d(
+        keras_layer,
+        hf_weight_prefix,
+        port_bias=True,
+        depth_multiplier=1,
+    ):
+        def convert_pt_conv2d_kernel(pt_kernel):
+            out_channels, in_channels_per_group, height, width = pt_kernel.shape
+            # PT Convs are depthwise convs if and only if
+            # `in_channels_per_group == 1`
+            assert in_channels_per_group == 1
+            pt_kernel = np.transpose(pt_kernel, (2, 3, 0, 1))
+            in_channels = out_channels // depth_multiplier
+            return np.reshape(
+                pt_kernel, (height, width, in_channels, depth_multiplier)
+            )
+
+        loader.port_weight(
+            keras_layer.kernel,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+            hook_fn=lambda x, _: convert_pt_conv2d_kernel(x),
+        )
+
+        if port_bias:
+            loader.port_weight(
+                keras_layer.bias,
+                hf_weight_key=f"{hf_weight_prefix}.bias",
+            )
+
+    def port_batch_normalization(keras_layer, hf_weight_prefix):
+        loader.port_weight(
+            keras_layer.gamma,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+        )
+        loader.port_weight(
+            keras_layer.beta,
+            hf_weight_key=f"{hf_weight_prefix}.bias",
+        )
+        loader.port_weight(
+            keras_layer.moving_mean,
+            hf_weight_key=f"{hf_weight_prefix}.running_mean",
+        )
+        loader.port_weight(
+            keras_layer.moving_variance,
+            hf_weight_key=f"{hf_weight_prefix}.running_var",
+        )
+        # do we need num batches tracked?
+
+    # Stem
+    port_conv2d(backbone.get_layer("stem_conv"), "conv_stem", port_bias=False)
+    port_batch_normalization(backbone.get_layer("stem_bn"), "bn1")
+
+    # Stages
+    num_stacks = len(backbone.stackwise_kernel_sizes)
+
+    for stack_index in range(num_stacks):
+        block_type = backbone.stackwise_block_types[stack_index]
+        expansion_ratio = backbone.stackwise_expansion_ratios[stack_index]
+        repeats = backbone.stackwise_num_repeats[stack_index]
+        stack_depth_coefficient = backbone.stackwise_depth_coefficients[
+            stack_index
+        ]
+
+        repeats = int(math.ceil(stack_depth_coefficient * repeats))
+
+        se_ratio = VARIANT_MAP[variant]["stackwise_squeeze_and_excite_ratios"][
+            stack_index
+        ]
+
+        for block_idx in range(repeats):
+            conv_pw_count = 0
+            bn_count = 1
+
+            # 97 is the start of the lowercase alphabet.
+            letter_identifier = chr(block_idx + 97)
+
+            keras_block_prefix = f"block{stack_index + 1}{letter_identifier}_"
+            hf_block_prefix = f"blocks.{stack_index}.{block_idx}."
+
+            if block_type == "v1":
+                conv_pw_name_map = ["conv_pw", "conv_pwl"]
+                # Initial Expansion Conv
+                if expansion_ratio != 1:
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "expand_conv"),
+                        hf_block_prefix + conv_pw_name_map[conv_pw_count],
+                        port_bias=False,
+                    )
+                    conv_pw_count += 1
+                    port_batch_normalization(
+                        backbone.get_layer(keras_block_prefix + "expand_bn"),
+                        hf_block_prefix + f"bn{bn_count}",
+                    )
+                    bn_count += 1
+
+                # Depthwise Conv
+                port_depthwise_conv2d(
+                    backbone.get_layer(keras_block_prefix + "dwconv"),
+                    hf_block_prefix + "conv_dw",
+                    port_bias=False,
+                )
+                port_batch_normalization(
+                    backbone.get_layer(keras_block_prefix + "dwconv_bn"),
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "se_reduce"),
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "se_expand"),
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
+                port_conv2d(
+                    backbone.get_layer(keras_block_prefix + "project"),
+                    hf_block_prefix + conv_pw_name_map[conv_pw_count],
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    backbone.get_layer(keras_block_prefix + "project_bn"),
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+            elif block_type == "fused":
+                fused_block_layer = backbone.get_layer(keras_block_prefix)
+
+                # Initial Expansion Conv
+                port_conv2d(
+                    fused_block_layer.conv1,
+                    hf_block_prefix + "conv_exp",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    fused_block_layer.bn1,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        fused_block_layer.se_conv1,
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        fused_block_layer.se_conv2,
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
+                port_conv2d(
+                    fused_block_layer.output_conv,
+                    hf_block_prefix + "conv_pwl",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    fused_block_layer.bn2,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+            elif block_type == "unfused":
+                unfused_block_layer = backbone.get_layer(keras_block_prefix)
+                # Initial Expansion Conv
+                if expansion_ratio != 1:
+                    port_conv2d(
+                        unfused_block_layer.conv1,
+                        hf_block_prefix + "conv_pw",
+                        port_bias=False,
+                    )
+                    conv_pw_count += 1
+                    port_batch_normalization(
+                        unfused_block_layer.bn1,
+                        hf_block_prefix + f"bn{bn_count}",
+                    )
+                    bn_count += 1
+
+                # Depthwise Conv
+                port_depthwise_conv2d(
+                    unfused_block_layer.depthwise,
+                    hf_block_prefix + "conv_dw",
+                    port_bias=False,
+                )
+                port_batch_normalization(
+                    unfused_block_layer.bn2,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        unfused_block_layer.se_conv1,
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        unfused_block_layer.se_conv2,
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
+                port_conv2d(
+                    unfused_block_layer.output_conv,
+                    hf_block_prefix + "conv_pwl",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    unfused_block_layer.bn3,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+            elif block_type == "cba":
+                cba_block_layer = backbone.get_layer(keras_block_prefix)
+                # Initial Expansion Conv
+                port_conv2d(
+                    cba_block_layer.conv1,
+                    hf_block_prefix + "conv",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    cba_block_layer.bn1,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+    # Head/Top
+    port_conv2d(backbone.get_layer("top_conv"), "conv_head", port_bias=False)
+    port_batch_normalization(backbone.get_layer("top_bn"), "bn2")
+
+
+def convert_head(task, loader, timm_config):
+    classifier_prefix = timm_config["pretrained_cfg"]["classifier"]
+    prefix = f"{classifier_prefix}."
+    loader.port_weight(
+        task.output_dense.kernel,
+        hf_weight_key=prefix + "weight",
+        hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
+    )
+    loader.port_weight(
+        task.output_dense.bias,
+        hf_weight_key=prefix + "bias",
+    )
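
A note on the depthwise hook above: PyTorch stores a depthwise kernel as (out_channels, in_channels_per_group, height, width) with in_channels_per_group == 1, while Keras' DepthwiseConv2D expects (height, width, in_channels, depth_multiplier). A minimal standalone sketch of the same transpose-and-reshape, using an invented 8-channel 3x3 kernel (shapes here are illustrative, not from the diff):

    import numpy as np

    # Invented depthwise kernel in PyTorch layout: (out, in/groups, h, w).
    pt_kernel = np.random.rand(8, 1, 3, 3)
    out_channels, in_per_group, h, w = pt_kernel.shape
    assert in_per_group == 1  # depthwise iff one input channel per group

    depth_multiplier = 1
    keras_kernel = np.reshape(
        np.transpose(pt_kernel, (2, 3, 0, 1)),  # -> (h, w, out, 1)
        (h, w, out_channels // depth_multiplier, depth_multiplier),
    )
    print(keras_kernel.shape)  # (3, 3, 8, 1), the Keras depthwise layout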
keras_hub/src/utils/timm/convert_resnet.py

@@ -89,7 +89,7 @@ def convert_weights(backbone, loader, timm_config):
         for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
             if version == "v1":
                 keras_name = f"stack{stack_index}_block{block_idx}"
-                hf_name = f"layer{stack_index+1}.{block_idx}"
+                hf_name = f"layer{stack_index + 1}.{block_idx}"
             else:
                 keras_name = f"stack{stack_index}_block{block_idx}"
                 hf_name = f"stages.{stack_index}.blocks.{block_idx}"
keras_hub/src/utils/timm/convert_vgg.py

@@ -0,0 +1,85 @@
+from typing import Any
+
+import numpy as np
+
+from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
+from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
+
+backbone_cls = VGGBackbone
+
+
+REPEATS_BY_SIZE = {
+    "vgg11": [1, 1, 2, 2, 2],
+    "vgg13": [2, 2, 2, 2, 2],
+    "vgg16": [2, 2, 3, 3, 3],
+    "vgg19": [2, 2, 4, 4, 4],
+}
+
+
+def convert_backbone_config(timm_config):
+    architecture = timm_config["architecture"]
+    stackwise_num_repeats = REPEATS_BY_SIZE[architecture]
+    return dict(
+        stackwise_num_repeats=stackwise_num_repeats,
+        stackwise_num_filters=[64, 128, 256, 512, 512],
+    )
+
+
+def convert_conv2d(
+    model,
+    loader,
+    keras_layer_name: str,
+    hf_layer_name: str,
+):
+    loader.port_weight(
+        model.get_layer(keras_layer_name).kernel,
+        hf_weight_key=f"{hf_layer_name}.weight",
+        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+    )
+    loader.port_weight(
+        model.get_layer(keras_layer_name).bias,
+        hf_weight_key=f"{hf_layer_name}.bias",
+    )
+
+
+def convert_weights(
+    backbone: VGGBackbone,
+    loader,
+    timm_config: dict[Any],
+):
+    architecture = timm_config["architecture"]
+    stackwise_num_repeats = REPEATS_BY_SIZE[architecture]
+
+    hf_index_to_keras_layer_name = {}
+    layer_index = 0
+    for block_index, repeats_in_block in enumerate(stackwise_num_repeats):
+        for repeat_index in range(repeats_in_block):
+            hf_index = layer_index
+            layer_index += 2  # Conv + activation layers.
+            layer_name = f"block{block_index + 1}_conv{repeat_index + 1}"
+            hf_index_to_keras_layer_name[hf_index] = layer_name
+        layer_index += 1  # Pooling layer after blocks.
+
+    for hf_index, keras_layer_name in hf_index_to_keras_layer_name.items():
+        convert_conv2d(
+            backbone, loader, keras_layer_name, f"features.{hf_index}"
+        )
+
+
+def convert_head(
+    task: VGGImageClassifier,
+    loader,
+    timm_config: dict[Any],
+):
+    convert_conv2d(task.head, loader, "fc1", "pre_logits.fc1")
+    convert_conv2d(task.head, loader, "fc2", "pre_logits.fc2")
+
+    loader.port_weight(
+        task.head.get_layer("predictions").kernel,
+        hf_weight_key="head.fc.weight",
+        hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
+    )
+    loader.port_weight(
+        task.head.get_layer("predictions").bias,
+        hf_weight_key="head.fc.bias",
+    )
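
The index bookkeeping in convert_weights mirrors how timm/torchvision build VGG's features as one flat sequential stack: each conv is followed by an activation (two slots), and each block ends with a pooling layer (one slot). A quick sketch, assuming the vgg11 layout from REPEATS_BY_SIZE, that reproduces the familiar conv indices 0, 3, 6, 8, ...:

    # Recompute the HF `features.{N}` conv indices for vgg11.
    repeats = [1, 1, 2, 2, 2]  # REPEATS_BY_SIZE["vgg11"]
    index, mapping = 0, {}
    for block, n in enumerate(repeats):
        for rep in range(n):
            mapping[index] = f"block{block + 1}_conv{rep + 1}"
            index += 2  # conv + activation occupy two slots
        index += 1  # pooling layer closes the block
    print(sorted(mapping))  # [0, 3, 6, 8, 11, 13, 16, 18]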
keras_hub/src/utils/timm/preset_loader.py

@@ -4,7 +4,9 @@ from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.utils.preset_utils import PresetLoader
 from keras_hub.src.utils.preset_utils import jax_memory_cleanup
 from keras_hub.src.utils.timm import convert_densenet
+from keras_hub.src.utils.timm import convert_efficientnet
 from keras_hub.src.utils.timm import convert_resnet
+from keras_hub.src.utils.timm import convert_vgg
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader


@@ -14,8 +16,12 @@ class TimmPresetLoader(PresetLoader):
         architecture = self.config["architecture"]
         if "resnet" in architecture:
             self.converter = convert_resnet
-        if "densenet" in architecture:
+        elif "densenet" in architecture:
             self.converter = convert_densenet
+        elif "vgg" in architecture:
+            self.converter = convert_vgg
+        elif "efficientnet" in architecture:
+            self.converter = convert_efficientnet
         else:
             raise ValueError(
                 "KerasHub has no converter for timm models "
@@ -52,20 +58,19 @@ class TimmPresetLoader(PresetLoader):
         pretrained_cfg = self.config.get("pretrained_cfg", None)
         if not pretrained_cfg or "input_size" not in pretrained_cfg:
             return None
-        # This assumes the same basic setup for all timm preprocessing, and that
-        # all our image conversion will be via a `ResizingImageConverter. We may
+        # This assumes the same basic setup for all timm preprocessing. We may
         # need to extend this as we cover more model types.
         input_size = pretrained_cfg["input_size"]
         mean = pretrained_cfg["mean"]
-        variance = [s**2 for s in pretrained_cfg["std"]]
+        std = pretrained_cfg["std"]
+        scale = [1.0 / 255.0 / s for s in std]
+        offset = [-m / s for m, s in zip(mean, std)]
         interpolation = pretrained_cfg["interpolation"]
         if interpolation not in ("bilinear", "nearest", "bicubic"):
             interpolation = "bilinear"  # Unsupported interpolation type.
         return cls(
-            width=input_size[1],
-            height=input_size[2],
-            scale=1 / 255.0,
-            mean=mean,
-            variance=variance,
+            image_size=input_size[1:],
+            scale=scale,
+            offset=offset,
             interpolation=interpolation,
         )
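
The switch from mean/variance to scale/offset is refactored normalization algebra: timm preprocessing computes (x / 255 - mean) / std per channel, which equals the single affine transform x * scale + offset with scale = 1 / (255 * std) and offset = -mean / std, exactly as computed above. A small check of that identity, using the standard ImageNet statistics as example values:

    import numpy as np

    mean = np.array([0.485, 0.456, 0.406])  # example: ImageNet mean
    std = np.array([0.229, 0.224, 0.225])  # example: ImageNet std
    x = np.random.uniform(0, 255, size=(4, 4, 3))  # fake image in 0-255 range

    scale = 1.0 / 255.0 / std
    offset = -mean / std
    assert np.allclose((x / 255.0 - mean) / std, x * scale + offset)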
keras_hub/src/utils/transformers/convert_llama3.py

@@ -107,10 +107,26 @@ def convert_tokenizer(cls, preset, **kwargs):
     vocab = tokenizer_config["model"]["vocab"]
     merges = tokenizer_config["model"]["merges"]

-    bot = tokenizer_config["added_tokens"][0]  # begin of text
-    eot = tokenizer_config["added_tokens"][1]  # end of text
-
-    vocab[bot["content"]] = bot["id"]
-    vocab[eot["content"]] = eot["id"]
+    # Load all special tokens with the exception of "reserved" ones.
+    special_tokens = set()
+    for token in tokenizer_config["added_tokens"]:
+        if not token["content"].startswith("<|reserved_special_token_"):
+            vocab[token["content"]] = token["id"]
+            special_tokens.add(token["content"])
+
+    # Load text start and stop tokens from the config.
+    # Llama3 uses the <|end_of_text|> end token for regular models
+    # but uses <|eot_id|> for instruction-tuned variants.
+    tokenizer_config2 = load_json(preset, "tokenizer_config.json")
+    bos_token = tokenizer_config2["bos_token"]
+    eos_token = tokenizer_config2["eos_token"]
+
+    kwargs.update(
+        {
+            "bos_token": bos_token,
+            "eos_token": eos_token,
+            "misc_special_tokens": special_tokens,
+        }
+    )

     return cls(vocabulary=vocab, merges=merges, **kwargs)
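
To see what the reserved-token filter keeps, here is a toy run over a hand-written added_tokens list (the ids and ordering are illustrative, not copied from a real checkpoint):

    added_tokens = [
        {"id": 128000, "content": "<|begin_of_text|>"},
        {"id": 128001, "content": "<|end_of_text|>"},
        {"id": 128002, "content": "<|reserved_special_token_0|>"},
        {"id": 128009, "content": "<|eot_id|>"},
    ]
    vocab, special_tokens = {}, set()
    for token in added_tokens:
        if not token["content"].startswith("<|reserved_special_token_"):
            vocab[token["content"]] = token["id"]
            special_tokens.add(token["content"])
    print(special_tokens)  # reserved slots skipped; bos/eos/eot tokens survive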