keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,373 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+
16
+ from keras_hub.src.models.bart.bart_backbone import BartBackbone
17
+ from keras_hub.src.utils.preset_utils import get_file
18
+
19
+ backbone_cls = BartBackbone
20
+
21
+
22
def convert_backbone_config(transformers_config):
    """Translate a Hugging Face BART `config.json` dict into backbone kwargs.

    The Keras BART backbone takes a single depth, head count and feed-forward
    width that are shared by the encoder and the decoder. The values are
    therefore read from the encoder-side HF keys, after checking that the
    decoder-side keys (when present) agree.

    Args:
        transformers_config: The parsed HF BART config dictionary.

    Returns:
        A dict of keyword arguments for the Keras BART backbone.

    Raises:
        KeyError: if a required HF config key is missing.
        ValueError: if the HF config gives the encoder and decoder a
            different number of layers, attention heads, or feed-forward
            width, which a shared-setting backbone cannot represent.
    """
    # Without this check an asymmetric checkpoint (e.g. a distilled BART with
    # a shallower decoder) would be silently built with encoder-sized decoder
    # settings, and weight porting would later fail on non-existent keys.
    for encoder_key, decoder_key in (
        ("encoder_layers", "decoder_layers"),
        ("encoder_attention_heads", "decoder_attention_heads"),
        ("encoder_ffn_dim", "decoder_ffn_dim"),
    ):
        encoder_value = transformers_config.get(encoder_key)
        decoder_value = transformers_config.get(decoder_key)
        if (
            encoder_value is not None
            and decoder_value is not None
            and encoder_value != decoder_value
        ):
            raise ValueError(
                "BART checkpoints with different encoder and decoder "
                f"settings are not supported: `{encoder_key}`={encoder_value} "
                f"but `{decoder_key}`={decoder_value}."
            )

    return {
        "vocabulary_size": transformers_config["vocab_size"],
        "num_layers": transformers_config["num_hidden_layers"],
        "num_heads": transformers_config["encoder_attention_heads"],
        "hidden_dim": transformers_config["d_model"],
        "intermediate_dim": transformers_config["encoder_ffn_dim"],
        "dropout": transformers_config["dropout"],
        "max_sequence_length": transformers_config["max_position_embeddings"],
    }
32
+
33
+
34
# Hugging Face's BART position tables carry two extra leading rows
# (`BartLearnedPositionalEmbedding` is built with `offset=2`). The Keras
# position embeddings have no counterpart for them, so they are dropped.
_HF_POSITION_OFFSET = 2


def _drop_position_offset(hf_tensor, keras_shape):
    """Strip the HF offset rows from a position table and match `keras_shape`."""
    return np.reshape(hf_tensor[_HF_POSITION_OFFSET:, :], keras_shape)


def _linear_to_attention_param(hf_tensor, keras_shape):
    """Convert an HF attention-projection weight or bias to the Keras layout.

    HF linear weights are stored `(out_features, in_features)`. The full
    transpose puts the input axis first, and the result is then reshaped to
    `keras_shape` (for the Keras attention projections this presumably
    splits or merges per-head axes). For 1-D biases the transpose is a no-op
    and only the reshape applies.
    """
    return np.reshape(np.transpose(hf_tensor), keras_shape)


def _linear_to_dense_kernel(hf_tensor, keras_shape):
    """Transpose a 2-D HF linear weight into a dense kernel.

    `keras_shape` is not needed: the transposed tensor is ported as-is.
    """
    del keras_shape
    return np.transpose(hf_tensor, axes=(1, 0))


def _port_layer_norm(loader, layer_norm, hf_prefix):
    """Port an HF LayerNorm (`weight`, `bias`) into `gamma`/`beta`."""
    loader.port_weight(
        keras_variable=layer_norm.gamma,
        hf_weight_key=f"{hf_prefix}.weight",
    )
    loader.port_weight(
        keras_variable=layer_norm.beta,
        hf_weight_key=f"{hf_prefix}.bias",
    )


def _port_attention(loader, attention_layer, hf_prefix):
    """Port the q/k/v/out projections (kernel, then bias) of one attention."""
    projections = (
        (attention_layer.query_dense, "q_proj"),
        (attention_layer.key_dense, "k_proj"),
        (attention_layer.value_dense, "v_proj"),
        (attention_layer.output_dense, "out_proj"),
    )
    for dense, hf_name in projections:
        loader.port_weight(
            keras_variable=dense.kernel,
            hf_weight_key=f"{hf_prefix}.{hf_name}.weight",
            hook_fn=_linear_to_attention_param,
        )
        loader.port_weight(
            keras_variable=dense.bias,
            hf_weight_key=f"{hf_prefix}.{hf_name}.bias",
            hook_fn=_linear_to_attention_param,
        )


def _port_feedforward(loader, transformer_layer, hf_prefix):
    """Port the `fc1`/`fc2` feed-forward weights of one transformer layer."""
    for dense, hf_name in (
        (transformer_layer._feedforward_intermediate_dense, "fc1"),
        (transformer_layer._feedforward_output_dense, "fc2"),
    ):
        loader.port_weight(
            keras_variable=dense.kernel,
            hf_weight_key=f"{hf_prefix}.{hf_name}.weight",
            hook_fn=_linear_to_dense_kernel,
        )
        # Biases already have the same 1-D layout on both sides.
        loader.port_weight(
            keras_variable=dense.bias,
            hf_weight_key=f"{hf_prefix}.{hf_name}.bias",
        )


def convert_weights(backbone, loader, transformers_config):
    """Port every weight of a Hugging Face BART checkpoint into `backbone`.

    Args:
        backbone: The Keras BART backbone being filled. Both its encoder and
            decoder stacks are walked for `backbone.num_layers` layers.
        loader: Weight loader whose `port_weight(keras_variable=...,
            hf_weight_key=..., hook_fn=...)` copies one HF tensor into one
            Keras variable. The hooks used here take
            `(hf_tensor, keras_shape)` and return the tensor to assign.
        transformers_config: The HF config dict. Not read here; kept for
            signature parity with the other transformers converters.
    """
    # Embeddings: the token table is HF's `shared` embedding.
    loader.port_weight(
        keras_variable=backbone.token_embedding.embeddings,
        hf_weight_key="shared.weight",
    )
    loader.port_weight(
        keras_variable=backbone.encoder_position_embedding.position_embeddings,
        hf_weight_key="encoder.embed_positions.weight",
        hook_fn=_drop_position_offset,
    )
    loader.port_weight(
        keras_variable=backbone.decoder_position_embedding.position_embeddings,
        hf_weight_key="decoder.embed_positions.weight",
        hook_fn=_drop_position_offset,
    )

    # Encoder blocks: norms, self-attention, feed-forward.
    for index in range(backbone.num_layers):
        encoder_layer = backbone.encoder_transformer_layers[index]
        hf_encoder_prefix = f"encoder.layers.{index}"

        _port_layer_norm(
            loader,
            encoder_layer._self_attention_layer_norm,
            f"{hf_encoder_prefix}.self_attn_layer_norm",
        )
        _port_layer_norm(
            loader,
            encoder_layer._feedforward_layer_norm,
            f"{hf_encoder_prefix}.final_layer_norm",
        )
        _port_attention(
            loader,
            encoder_layer._self_attention_layer,
            f"{hf_encoder_prefix}.self_attn",
        )
        _port_feedforward(loader, encoder_layer, hf_encoder_prefix)

    # Decoder blocks: norms, self-attention, feed-forward, cross-attention.
    for index in range(backbone.num_layers):
        decoder_layer = backbone.decoder_transformer_layers[index]
        hf_decoder_prefix = f"decoder.layers.{index}"

        _port_layer_norm(
            loader,
            decoder_layer._self_attention_layer_norm,
            f"{hf_decoder_prefix}.self_attn_layer_norm",
        )
        _port_layer_norm(
            loader,
            decoder_layer._feedforward_layer_norm,
            f"{hf_decoder_prefix}.final_layer_norm",
        )
        _port_layer_norm(
            loader,
            decoder_layer._cross_attention_layer_norm,
            f"{hf_decoder_prefix}.encoder_attn_layer_norm",
        )
        _port_attention(
            loader,
            decoder_layer._self_attention_layer,
            f"{hf_decoder_prefix}.self_attn",
        )
        _port_feedforward(loader, decoder_layer, hf_decoder_prefix)
        # HF names the decoder's cross-attention module `encoder_attn`.
        _port_attention(
            loader,
            decoder_layer._cross_attention_layer,
            f"{hf_decoder_prefix}.encoder_attn",
        )

    # Embedding layer norms of both stacks.
    _port_layer_norm(
        loader,
        backbone.encoder_embeddings_layer_norm,
        "encoder.layernorm_embedding",
    )
    _port_layer_norm(
        loader,
        backbone.decoder_embeddings_layer_norm,
        "decoder.layernorm_embedding",
    )
364
+
365
+
366
def convert_tokenizer(cls, preset, **kwargs):
    """Instantiate a BART tokenizer from the vocabulary files of a preset.

    Args:
        cls: The tokenizer class to build; it receives the fetched files as
            the `vocabulary` and `merges` keyword arguments.
        preset: Preset handle passed through to `get_file`.
        **kwargs: Additional keyword arguments forwarded unchanged to `cls`.

    Returns:
        Whatever `cls(...)` returns.
    """
    # Fetch the byte-pair vocabulary first, then the merge rules.
    tokenizer_assets = {
        "vocabulary": get_file(preset, "vocab.json"),
        "merges": get_file(preset, "merges.txt"),
    }
    return cls(**tokenizer_assets, **kwargs)
@@ -13,12 +13,12 @@
13
13
  # limitations under the License.
14
14
  import numpy as np
15
15
 
16
- from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
16
+ from keras_hub.src.models.bert.bert_backbone import BertBackbone
17
17
  from keras_hub.src.utils.preset_utils import HF_TOKENIZER_CONFIG_FILE
18
18
  from keras_hub.src.utils.preset_utils import get_file
19
- from keras_hub.src.utils.preset_utils import jax_memory_cleanup
20
- from keras_hub.src.utils.preset_utils import load_config
21
- from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
19
+ from keras_hub.src.utils.preset_utils import load_json
20
+
21
+ backbone_cls = BertBackbone
22
22
 
23
23
 
24
24
  def convert_backbone_config(transformers_config):
@@ -154,20 +154,10 @@ def convert_weights(backbone, loader, transformers_config):
154
154
  )
155
155
 
156
156
 
157
- def load_bert_backbone(cls, preset, load_weights):
158
- transformers_config = load_config(preset, HF_CONFIG_FILE)
159
- keras_config = convert_backbone_config(transformers_config)
160
- backbone = cls(**keras_config)
161
- if load_weights:
162
- jax_memory_cleanup(backbone)
163
- with SafetensorLoader(preset) as loader:
164
- convert_weights(backbone, loader, transformers_config)
165
- return backbone
166
-
167
-
168
- def load_bert_tokenizer(cls, preset):
169
- transformers_config = load_config(preset, HF_TOKENIZER_CONFIG_FILE)
157
+ def convert_tokenizer(cls, preset, **kwargs):
158
+ transformers_config = load_json(preset, HF_TOKENIZER_CONFIG_FILE)
170
159
  return cls(
171
160
  get_file(preset, "vocab.txt"),
172
161
  lowercase=transformers_config["do_lower_case"],
162
+ **kwargs,
173
163
  )
@@ -13,12 +13,14 @@
13
13
  # limitations under the License.
14
14
  import numpy as np
15
15
 
16
- from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
16
+ from keras_hub.src.models.distil_bert.distil_bert_backbone import (
17
+ DistilBertBackbone,
18
+ )
17
19
  from keras_hub.src.utils.preset_utils import HF_TOKENIZER_CONFIG_FILE
18
20
  from keras_hub.src.utils.preset_utils import get_file
19
- from keras_hub.src.utils.preset_utils import jax_memory_cleanup
20
- from keras_hub.src.utils.preset_utils import load_config
21
- from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
21
+ from keras_hub.src.utils.preset_utils import load_json
22
+
23
+ backbone_cls = DistilBertBackbone
22
24
 
23
25
 
24
26
  def convert_backbone_config(transformers_config):
@@ -33,7 +35,7 @@ def convert_backbone_config(transformers_config):
33
35
  }
34
36
 
35
37
 
36
- def convert_weights(backbone, loader):
38
+ def convert_weights(backbone, loader, transformers_config):
37
39
  # Embeddings
38
40
  loader.port_weight(
39
41
  keras_variable=backbone.get_layer(
@@ -162,23 +164,11 @@ def convert_weights(backbone, loader):
162
164
  hf_weight_key="distilbert.embeddings.LayerNorm.bias",
163
165
  )
164
166
 
165
- return backbone
166
-
167
-
168
- def load_distilbert_backbone(cls, preset, load_weights):
169
- transformers_config = load_config(preset, HF_CONFIG_FILE)
170
- keras_config = convert_backbone_config(transformers_config)
171
- backbone = cls(**keras_config)
172
- if load_weights:
173
- jax_memory_cleanup(backbone)
174
- with SafetensorLoader(preset) as loader:
175
- convert_weights(backbone, loader)
176
- return backbone
177
-
178
167
 
179
- def load_distilbert_tokenizer(cls, preset):
180
- transformers_config = load_config(preset, HF_TOKENIZER_CONFIG_FILE)
168
+ def convert_tokenizer(cls, preset, **kwargs):
169
+ transformers_config = load_json(preset, HF_TOKENIZER_CONFIG_FILE)
181
170
  return cls(
182
171
  get_file(preset, "vocab.txt"),
183
172
  lowercase=transformers_config["do_lower_case"],
173
+ **kwargs,
184
174
  )
@@ -13,11 +13,10 @@
13
13
  # limitations under the License.
14
14
  import numpy as np
15
15
 
16
- from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
16
+ from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
17
17
  from keras_hub.src.utils.preset_utils import get_file
18
- from keras_hub.src.utils.preset_utils import jax_memory_cleanup
19
- from keras_hub.src.utils.preset_utils import load_config
20
- from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
18
+
19
+ backbone_cls = GemmaBackbone
21
20
 
22
21
 
23
22
  def convert_backbone_config(transformers_config):
@@ -169,19 +168,6 @@ def convert_weights(backbone, loader, transformers_config):
169
168
  hf_weight_key="model.norm.weight",
170
169
  )
171
170
 
172
- return backbone
173
-
174
-
175
- def load_gemma_backbone(cls, preset, load_weights):
176
- transformers_config = load_config(preset, HF_CONFIG_FILE)
177
- keras_config = convert_backbone_config(transformers_config)
178
- backbone = cls(**keras_config)
179
- if load_weights:
180
- jax_memory_cleanup(backbone)
181
- with SafetensorLoader(preset) as loader:
182
- convert_weights(backbone, loader, transformers_config)
183
- return backbone
184
-
185
171
 
186
- def load_gemma_tokenizer(cls, preset):
187
- return cls(get_file(preset, "tokenizer.model"))
172
+ def convert_tokenizer(cls, preset, **kwargs):
173
+ return cls(get_file(preset, "tokenizer.model"), **kwargs)
@@ -13,11 +13,10 @@
13
13
  # limitations under the License.
14
14
  import numpy as np
15
15
 
16
- from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
16
+ from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
17
17
  from keras_hub.src.utils.preset_utils import get_file
18
- from keras_hub.src.utils.preset_utils import jax_memory_cleanup
19
- from keras_hub.src.utils.preset_utils import load_config
20
- from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
18
+
19
+ backbone_cls = GPT2Backbone
21
20
 
22
21
 
23
22
  def convert_backbone_config(transformers_config):
@@ -163,24 +162,12 @@ def convert_weights(backbone, loader, transformers_config):
163
162
  hf_weight_key="ln_f.bias",
164
163
  )
165
164
 
166
- return backbone
167
-
168
-
169
- def load_gpt2_backbone(cls, preset, load_weights):
170
- transformers_config = load_config(preset, HF_CONFIG_FILE)
171
- keras_config = convert_backbone_config(transformers_config)
172
- backbone = cls(**keras_config)
173
- if load_weights:
174
- jax_memory_cleanup(backbone)
175
- with SafetensorLoader(preset) as loader:
176
- convert_weights(backbone, loader, transformers_config)
177
- return backbone
178
-
179
165
 
180
- def load_gpt2_tokenizer(cls, preset):
166
+ def convert_tokenizer(cls, preset, **kwargs):
181
167
  vocab_file = get_file(preset, "vocab.json")
182
168
  merges_file = get_file(preset, "merges.txt")
183
169
  return cls(
184
170
  vocabulary=vocab_file,
185
171
  merges=merges_file,
172
+ **kwargs,
186
173
  )
@@ -13,10 +13,10 @@
13
13
  # limitations under the License.
14
14
  import numpy as np
15
15
 
16
- from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
17
- from keras_hub.src.utils.preset_utils import jax_memory_cleanup
18
- from keras_hub.src.utils.preset_utils import load_config
19
- from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
16
+ from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
17
+ from keras_hub.src.utils.preset_utils import load_json
18
+
19
+ backbone_cls = Llama3Backbone
20
20
 
21
21
 
22
22
  def convert_backbone_config(transformers_config):
@@ -111,19 +111,8 @@ def convert_weights(backbone, loader, transformers_config):
111
111
  return backbone
112
112
 
113
113
 
114
- def load_llama3_backbone(cls, preset, load_weights):
115
- transformers_config = load_config(preset, HF_CONFIG_FILE)
116
- keras_config = convert_backbone_config(transformers_config)
117
- backbone = cls(**keras_config)
118
- if load_weights:
119
- jax_memory_cleanup(backbone)
120
- with SafetensorLoader(preset) as loader:
121
- convert_weights(backbone, loader, transformers_config)
122
- return backbone
123
-
124
-
125
- def load_llama3_tokenizer(cls, preset):
126
- tokenizer_config = load_config(preset, "tokenizer.json")
114
+ def convert_tokenizer(cls, preset, **kwargs):
115
+ tokenizer_config = load_json(preset, "tokenizer.json")
127
116
  vocab = tokenizer_config["model"]["vocab"]
128
117
  merges = tokenizer_config["model"]["merges"]
129
118
 
@@ -133,4 +122,4 @@ def load_llama3_tokenizer(cls, preset):
133
122
  vocab[bot["content"]] = bot["id"]
134
123
  vocab[eot["content"]] = eot["id"]
135
124
 
136
- return cls(vocabulary=vocab, merges=merges)
125
+ return cls(vocabulary=vocab, merges=merges, **kwargs)