ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl

Files changed (169)
  1. ai_edge_torch/__init__.py +5 -4
  2. ai_edge_torch/_convert/conversion.py +112 -0
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +94 -48
  5. ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
  8. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
  9. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
  10. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  19. ai_edge_torch/_convert/signature.py +66 -0
  20. ai_edge_torch/_convert/test/test_convert.py +495 -0
  21. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  22. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  23. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
  24. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
  25. ai_edge_torch/config.py +27 -0
  26. ai_edge_torch/conftest.py +20 -0
  27. ai_edge_torch/debug/culprit.py +72 -40
  28. ai_edge_torch/debug/test/test_culprit.py +7 -5
  29. ai_edge_torch/debug/test/test_search_model.py +8 -7
  30. ai_edge_torch/debug/utils.py +14 -3
  31. ai_edge_torch/fx_pass_base.py +101 -0
  32. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
  33. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
  34. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
  35. ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
  36. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  37. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
  38. ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
  39. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
  40. ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
  41. ai_edge_torch/generative/examples/openelm/verify.py +64 -0
  42. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  43. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  44. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
  45. ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
  46. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  47. ai_edge_torch/generative/examples/phi/verify.py +65 -0
  48. ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
  49. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  50. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
  51. ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
  52. ai_edge_torch/generative/examples/smollm/verify.py +62 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  54. ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
  55. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
  56. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
  57. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
  58. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  59. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
  60. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  61. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  62. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  63. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  64. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  65. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  66. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
  67. ai_edge_torch/generative/examples/t5/t5.py +208 -159
  68. ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
  69. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  70. ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
  71. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
  72. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  73. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
  74. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
  75. ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
  76. ai_edge_torch/generative/fx_passes/__init__.py +4 -5
  77. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
  78. ai_edge_torch/generative/layers/attention.py +141 -102
  79. ai_edge_torch/generative/layers/attention_utils.py +53 -12
  80. ai_edge_torch/generative/layers/builder.py +37 -7
  81. ai_edge_torch/generative/layers/feed_forward.py +39 -14
  82. ai_edge_torch/generative/layers/kv_cache.py +162 -50
  83. ai_edge_torch/generative/layers/model_config.py +84 -30
  84. ai_edge_torch/generative/layers/normalization.py +185 -7
  85. ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
  86. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
  87. ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
  88. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  89. ai_edge_torch/generative/layers/unet/model_config.py +17 -15
  90. ai_edge_torch/generative/quantize/example.py +7 -8
  91. ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
  92. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
  93. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  94. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  95. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
  96. ai_edge_torch/generative/test/test_model_conversion.py +124 -188
  97. ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
  98. ai_edge_torch/generative/test/test_quantize.py +76 -60
  99. ai_edge_torch/generative/test/utils.py +54 -0
  100. ai_edge_torch/generative/utilities/converter.py +82 -0
  101. ai_edge_torch/generative/utilities/loader.py +120 -57
  102. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
  103. ai_edge_torch/generative/utilities/t5_loader.py +110 -81
  104. ai_edge_torch/generative/utilities/verifier.py +247 -0
  105. ai_edge_torch/hlfb/__init__.py +1 -1
  106. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
  107. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  108. ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
  109. ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
  110. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
  111. ai_edge_torch/lowertools/__init__.py +18 -0
  112. ai_edge_torch/lowertools/_shim.py +80 -0
  113. ai_edge_torch/lowertools/common_utils.py +142 -0
  114. ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
  115. ai_edge_torch/lowertools/test_utils.py +60 -0
  116. ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
  117. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
  118. ai_edge_torch/model.py +53 -18
  119. ai_edge_torch/odml_torch/__init__.py +20 -0
  120. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  121. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  122. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  123. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  124. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  125. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  126. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  127. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  128. ai_edge_torch/odml_torch/export.py +357 -0
  129. ai_edge_torch/odml_torch/export_utils.py +168 -0
  130. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  131. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
  132. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  133. ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
  134. ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
  135. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  136. ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
  137. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
  138. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  139. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  140. ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
  141. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  142. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  143. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  144. ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
  145. ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
  146. ai_edge_torch/quantize/quant_config.py +13 -9
  147. ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
  148. ai_edge_torch/version.py +16 -0
  149. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
  150. ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
  151. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
  152. ai_edge_torch/convert/conversion.py +0 -117
  153. ai_edge_torch/convert/conversion_utils.py +0 -400
  154. ai_edge_torch/convert/fx_passes/__init__.py +0 -59
  155. ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
  156. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
  157. ai_edge_torch/convert/test/test_convert.py +0 -311
  158. ai_edge_torch/convert/test/test_convert_composites.py +0 -192
  159. ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
  160. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
  161. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
  162. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
  163. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  164. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
  165. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  166. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  167. /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
  168. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
  169. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/utilities/loader.py
@@ -18,11 +18,10 @@ import glob
 import os
 from typing import Callable, Dict, List, Tuple
 
+from ai_edge_torch.generative.layers import model_config
 from safetensors import safe_open
 import torch
 
-from ai_edge_torch.generative.layers import model_config
-
 
 def load_safetensors(full_path: str):
   """Loads safetensors into a single state dictionary.
@@ -73,7 +72,7 @@ def load_pytorch_statedict(full_path: str):
   patterns = []
   if os.path.isdir(full_path):
     patterns.append(os.path.join(full_path, "*.bin"))
-    patterns.append(os.path.join(full_path, "*.pt"))
+    patterns.append(os.path.join(full_path, "*pt"))
   else:
     patterns.append(full_path)
   for pattern in patterns:
@@ -93,9 +92,7 @@ def load_pytorch_statedict(full_path: str):
 
 
 class ModelLoader:
-  """A utility class for loading and converting model checkpoints to the
-  Edge Generative API layer format.
-  """
+  """Utility for loading model checkpoints to the Edge Generative API layer."""
 
   @dataclass
   class TensorNames:
@@ -104,25 +101,30 @@ class ModelLoader:
     attn_value_proj: str = None
     attn_fused_qkv_proj: str = None
     attn_output_proj: str = None
+    attn_query_norm: str = None
+    attn_key_norm: str = None
 
     ff_up_proj: str = None
     ff_down_proj: str = None
     ff_gate_proj: str = None
 
     pre_attn_norm: str = None
+    post_attn_norm: str = None
     pre_ff_norm: str = None
+    post_ff_norm: str = None
     embedding: str = None
     embedding_position: str = None
     final_norm: str = None
     lm_head: str = None
 
   def __init__(self, file_name: str, names: TensorNames) -> None:
-    """ModelLoader constructor. Can be used to load multiple models of the same
-    type.
+    """ModelLoader constructor.
+
+    Can be used to load multiple models of the same type.
 
     Args:
-      file_name (str): Path to the checkpoint. Can be a directory or an
-        exact file.
+      file_name (str): Path to the checkpoint. Can be a directory or an exact
+        file.
       names (TensorNames): An instance of `TensorNames` to determine mappings.
     """
     self._file_name = file_name
@@ -141,13 +143,15 @@ class ModelLoader:
 
     Returns:
      missing_keys (List[str]): a list of str containing the missing keys.
-      unexpected_keys (List[str]): a list of str containing the unexpected keys.
+      unexpected_keys (List[str]): a list of str containing the unexpected
+        keys.
 
    Raises:
      ValueError: If conversion results in unmapped tensors and strict mode is
        enabled.
    """
    state = self._loader(self._file_name)
+    state = state["model_state_dict"] if "model_state_dict" in state else state
    converted_state = dict()
    if self._names.embedding is not None:
      converted_state["tok_embedding.weight"] = state.pop(
@@ -158,14 +162,22 @@ class ModelLoader:
          f"{self._names.embedding_position}"
      )
    if self._names.lm_head is not None:
-      converted_state["lm_head.weight"] = state.pop(f"{self._names.lm_head}.weight")
+      converted_state["lm_head.weight"] = state.pop(
+          f"{self._names.lm_head}.weight"
+      )
      if model.config.lm_head_use_bias:
-        converted_state["lm_head.bias"] = state.pop(f"{self._names.lm_head}.bias")
+        converted_state["lm_head.bias"] = state.pop(
+            f"{self._names.lm_head}.bias"
+        )
    if self._names.final_norm is not None:
      final_norm_name = self._names.final_norm
-      converted_state["final_norm.weight"] = state.pop(f"{final_norm_name}.weight")
+      converted_state["final_norm.weight"] = state.pop(
+          f"{final_norm_name}.weight"
+      )
      if f"{final_norm_name}.bias" in state:
-        converted_state["final_norm.bias"] = state.pop(f"{final_norm_name}.bias")
+        converted_state["final_norm.bias"] = state.pop(
+            f"{final_norm_name}.bias"
+        )
 
    for i in range(model.config.num_layers):
      self._map_norm(i, model.config, state, converted_state)
@@ -191,17 +203,17 @@ class ModelLoader:
    if glob.glob(os.path.join(self._file_name, "*.safetensors")):
      return load_safetensors
    if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
-        os.path.join(self._file_name, "*.pt")
+        os.path.join(self._file_name, "*pt")
    ):
      return load_pytorch_statedict
 
    if self._file_name.endswith(".safetensors"):
      return load_safetensors
 
-    if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
+    if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
      return load_pytorch_statedict
 
-    raise ValueError(f"File format not supported.")
+    raise ValueError("File format not supported.")
 
  def _map_feedforward(
      self,
@@ -211,31 +223,66 @@ class ModelLoader:
      converted_state: Dict[str, torch.Tensor],
  ):
    prefix = f"transformer_blocks.{idx}"
-    if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
+    ff_config = config.block_config(idx).ff_config
+    if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
      ff_up_proj_name = self._names.ff_up_proj.format(idx)
      ff_down_proj_name = self._names.ff_down_proj.format(idx)
-      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
+          f"{ff_up_proj_name}.weight"
+      )
      converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
          f"{ff_down_proj_name}.weight"
      )
-      if config.ff_config.use_bias:
-        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_up_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
+      if ff_config.use_bias:
+        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+            f"{ff_up_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+            f"{ff_down_proj_name}.bias"
+        )
    else:
      ff_up_proj_name = self._names.ff_up_proj.format(idx)
      ff_down_proj_name = self._names.ff_down_proj.format(idx)
      ff_gate_proj_name = self._names.ff_gate_proj.format(idx)
-      converted_state[f"{prefix}.ff.w3.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+      converted_state[f"{prefix}.ff.w3.weight"] = state.pop(
+          f"{ff_up_proj_name}.weight"
+      )
      converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
          f"{ff_down_proj_name}.weight"
      )
      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
          f"{ff_gate_proj_name}.weight"
      )
-      if config.ff_config.use_bias:
-        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(f"{ff_up_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_gate_proj_name}.bias")
+      if ff_config.use_bias:
+        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
+            f"{ff_up_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+            f"{ff_down_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+            f"{ff_gate_proj_name}.bias"
+        )
+
+    if self._names.pre_ff_norm is not None:
+      pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
+      converted_state[f"{prefix}.ff.pre_ff_norm.weight"] = state.pop(
+          f"{pre_ff_norm_name}.weight"
+      )
+      if f"{pre_ff_norm_name}.bias" in state:
+        converted_state[f"{prefix}.ff.pre_ff_norm.bias"] = state.pop(
+            f"{pre_ff_norm_name}.bias"
+        )
+
+    if self._names.post_ff_norm is not None:
+      post_ff_norm_name = self._names.post_ff_norm.format(idx)
+      converted_state[f"{prefix}.ff.post_ff_norm.weight"] = state.pop(
+          f"{post_ff_norm_name}.weight"
+      )
+      if f"{post_ff_norm_name}.bias" in state:
+        converted_state[f"{prefix}.ff.post_ff_norm.bias"] = state.pop(
+            f"{post_ff_norm_name}.bias"
+        )
 
  def _map_attention(
      self,
@@ -245,6 +292,7 @@ class ModelLoader:
      converted_state: Dict[str, torch.Tensor],
  ):
    prefix = f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
    if self._names.attn_fused_qkv_proj:
      fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
@@ -254,32 +302,47 @@ class ModelLoader:
      q_name = self._names.attn_query_proj.format(idx)
      k_name = self._names.attn_key_proj.format(idx)
      v_name = self._names.attn_value_proj.format(idx)
-      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
-          config,
-          state.pop(f"{q_name}.weight"),
-          state.pop(f"{k_name}.weight"),
-          state.pop(f"{v_name}.weight"),
+      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = (
+          self._fuse_qkv(
+              attn_config,
+              state.pop(f"{q_name}.weight"),
+              state.pop(f"{k_name}.weight"),
+              state.pop(f"{v_name}.weight"),
+          )
      )
-    if config.attn_config.qkv_use_bias:
+    if attn_config.qkv_use_bias:
      if self._names.attn_fused_qkv_proj:
        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
            f"{fused_qkv_name}.bias"
        )
      else:
-        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
-            config,
-            state.pop(f"{q_name}.bias"),
-            state.pop(f"{k_name}.bias"),
-            state.pop(f"{v_name}.bias"),
+        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = (
+            self._fuse_qkv(
+                attn_config,
+                state.pop(f"{q_name}.bias"),
+                state.pop(f"{k_name}.bias"),
+                state.pop(f"{v_name}.bias"),
+            )
        )
 
+    if self._names.attn_query_norm is not None:
+      attn_query_norm_name = self._names.attn_query_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.query_norm.weight"] = state.pop(
+          f"{attn_query_norm_name}.weight"
+      )
+    if self._names.attn_key_norm is not None:
+      attn_key_norm_name = self._names.attn_key_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.key_norm.weight"] = state.pop(
+          f"{attn_key_norm_name}.weight"
+      )
+
    o_name = self._names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
-        f"{o_name}.weight"
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
+        state.pop(f"{o_name}.weight")
    )
-    if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
-          f"{o_name}.bias"
+    if attn_config.output_proj_use_bias:
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
+          state.pop(f"{o_name}.bias")
      )
 
  def _map_norm(
@@ -300,28 +363,28 @@ class ModelLoader:
            f"{pre_attn_norm_name}.bias"
        )
 
-    if self._names.pre_ff_norm is not None:
-      pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
-      converted_state[f"{prefix}.pre_ff_norm.weight"] = state.pop(
-          f"{pre_ff_norm_name}.weight"
+    if self._names.post_attn_norm is not None:
+      post_attn_norm_name = self._names.post_attn_norm.format(idx)
+      converted_state[f"{prefix}.post_atten_norm.weight"] = state.pop(
+          f"{post_attn_norm_name}.weight"
      )
-      if f"{pre_ff_norm_name}.bias" in state:
-        converted_state[f"{prefix}.pre_ff_norm.bias"] = state.pop(
-            f"{pre_ff_norm_name}.bias"
+      if f"{post_attn_norm_name}.bias" in state:
+        converted_state[f"{prefix}.post_atten_norm.bias"] = state.pop(
+            f"{post_attn_norm_name}.bias"
        )
 
  def _fuse_qkv(
      self,
-      config: model_config.ModelConfig,
+      attn_config: model_config.AttentionConfig,
      q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
  ) -> torch.Tensor:
-    if config.attn_config.qkv_fused_interleaved:
-      q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
-      qs = torch.split(q, config.head_dim * q_per_kv)
-      ks = torch.split(k, config.head_dim)
-      vs = torch.split(v, config.head_dim)
+    if attn_config.qkv_fused_interleaved:
+      q_per_kv = attn_config.num_heads // attn_config.num_query_groups
+      qs = torch.split(q, attn_config.head_dim * q_per_kv)
+      ks = torch.split(k, attn_config.head_dim)
+      vs = torch.split(v, attn_config.head_dim)
      cycled = [t for group in zip(qs, ks, vs) for t in group]
      return torch.cat(cycled)
    else:
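
A note on the pattern change in `load_pytorch_statedict` and `_get_loader`: relaxing `"*.pt"` to `"*pt"` makes the loader pick up any file whose name ends in `pt` (for example `*.ckpt`), not just `.pt` files. A quick standard-library illustration; the file names below are made up:

```python
import fnmatch

# Hypothetical checkpoint directory listing.
names = ["model.bin", "model.pt", "model.ckpt", "weights.safetensors"]

print(fnmatch.filter(names, "*.pt"))  # ['model.pt']               (old pattern)
print(fnmatch.filter(names, "*pt"))   # ['model.pt', 'model.ckpt'] (new pattern)
```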
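For orientation, here is a hedged sketch of how the `ModelLoader` API touched by these hunks is typically driven. The tensor-name templates and checkpoint path are illustrative stand-ins, and `build_model()` is a hypothetical factory; only `ModelLoader`, `TensorNames`, and `load` come from the diff above:

```python
from ai_edge_torch.generative.utilities import loader

# Illustrative Hugging-Face-style name templates; "{}" is filled with the
# block index by ModelLoader via str.format(idx).
TENSOR_NAMES = loader.ModelLoader.TensorNames(
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    pre_ff_norm="model.layers.{}.post_attention_layernorm",
    embedding="model.embed_tokens",
    final_norm="model.norm",
)

model = build_model()  # hypothetical: an nn.Module built on the generative layers
ml = loader.ModelLoader("/path/to/checkpoint", TENSOR_NAMES)
# Per the docstring above, load returns (missing_keys, unexpected_keys) and
# raises ValueError on unmapped tensors when strict mode is enabled.
missing, unexpected = ml.load(model, strict=False)
```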
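Finally, the `_fuse_qkv` hunk narrows the helper's dependency from the whole `ModelConfig` to just the attention config, where `head_dim` now lives. Below is a minimal self-contained sketch of the interleaved fusion it implements, with a hypothetical `AttnCfg` standing in for `model_config.AttentionConfig`:

```python
import dataclasses

import torch


@dataclasses.dataclass
class AttnCfg:
  """Hypothetical stand-in for model_config.AttentionConfig."""

  num_heads: int
  num_query_groups: int
  head_dim: int


def fuse_qkv_interleaved(cfg: AttnCfg, q, k, v):
  """Interleaves per-KV-group Q, K, V rows into one fused projection weight."""
  q_per_kv = cfg.num_heads // cfg.num_query_groups
  qs = torch.split(q, cfg.head_dim * q_per_kv)  # one Q chunk per KV group
  ks = torch.split(k, cfg.head_dim)
  vs = torch.split(v, cfg.head_dim)
  # Cycle group-wise: [q0, k0, v0, q1, k1, v1, ...]
  return torch.cat([t for group in zip(qs, ks, vs) for t in group])


# E.g. 4 query heads sharing 2 KV groups, head_dim=8, model width 32.
cfg = AttnCfg(num_heads=4, num_query_groups=2, head_dim=8)
q = torch.randn(cfg.num_heads * cfg.head_dim, 32)
k = torch.randn(cfg.num_query_groups * cfg.head_dim, 32)
v = torch.randn(cfg.num_query_groups * cfg.head_dim, 32)
assert fuse_qkv_interleaved(cfg, q, k, v).shape == (64, 32)
```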