ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
ai_edge_torch/generative/test/test_quantize.py
@@ -0,0 +1,162 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import unittest
+
+ from parameterized import parameterized
+ import torch
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.examples.test_models import toy_model  # NOQA
+ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
+ from ai_edge_torch.generative.quantize import quant_recipe
+ from ai_edge_torch.generative.quantize import quant_recipe_utils
+ from ai_edge_torch.generative.quantize import quant_recipes
+ from ai_edge_torch.generative.quantize.quant_attrs import Algorithm
+ from ai_edge_torch.generative.quantize.quant_attrs import Dtype
+ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
+ from ai_edge_torch.generative.quantize.quant_attrs import Mode
+ from ai_edge_torch.quantize import quant_config
+ from ai_edge_torch.testing import model_coverage
+
+
+ class TestVerifyRecipes(unittest.TestCase):
+   """Unit tests that check for model quantization recipes."""
+
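+   # None of these activation/weight dtype pairs belongs to a supported
+   # scheme (see ai_edge_torch/generative/quantize/supported_schemes.py), so
+   # verify() should raise ValueError for every mode/algorithm/granularity
+   # combination.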
+   @parameterized.expand(
+       [
+           (Dtype.FP32, Dtype.FP32),
+           (Dtype.INT8, Dtype.INT8),
+           (Dtype.INT8, Dtype.FP16),
+           (Dtype.FP16, Dtype.INT8),
+           (Dtype.FP16, Dtype.FP16),
+       ]
+   )
+   def test_verify_invalid_recipes(
+       self,
+       activation,
+       weight,
+   ):
+     for m in Mode:
+       for a in Algorithm:
+         for g in Granularity:
+           with self.assertRaises(ValueError):
+             quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
+
+   @parameterized.expand(
+       [
+           (
+               Dtype.FP32,
+               Dtype.INT8,
+               Mode.DYNAMIC_RANGE,
+               Algorithm.MIN_MAX,
+               Granularity.CHANNELWISE,
+           ),
+           (
+               Dtype.FP32,
+               Dtype.INT8,
+               Mode.WEIGHT_ONLY,
+               Algorithm.MIN_MAX,
+               Granularity.CHANNELWISE,
+           ),
+           (
+               Dtype.FP32,
+               Dtype.FP16,
+               Mode.WEIGHT_ONLY,
+               Algorithm.FLOAT_CAST,
+               Granularity.NONE,
+           ),
+       ]
+   )
+   def test_verify_valid_recipes(
+       self,
+       activation,
+       weight,
+       mode,
+       algo,
+       granularity,
+   ):
+     quant_recipe.LayerQuantRecipe(activation, weight, mode, algo, granularity).verify()
+
+
+ class TestQuantizeConvert(unittest.TestCase):
+   """Test conversion with quantization."""
+
+   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
+     return quant_config.QuantConfig(
+         generative_recipe=quant_recipe.GenerativeQuantRecipe(
+             attention=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+         )
+     )
+
+   def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
+     return quant_config.QuantConfig(
+         generative_recipe=quant_recipe.GenerativeQuantRecipe(
+             feedforward=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+         )
+     )
+
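+   # Each pair below is (recipe, expected compression): the expected ratio of
+   # quantized flatbuffer size to float32 flatbuffer size for the toy model,
+   # checked against the converted sizes to within delta=0.01.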
+   @parameterized.expand(
+       [
+           (quant_recipes.full_fp16_recipe(), 0.65),
+           (quant_recipes.full_int8_dynamic_recipe(), 0.47),
+           (_attention_int8_dynamic_recipe(), 0.89),
+           (_feedforward_int8_dynamic_recipe(), 0.72),
+       ]
+   )
+   def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
+     config = toy_model.get_model_config()
+     pytorch_model = toy_model.ToySingleLayerModel(config)
+     idx = torch.unsqueeze(torch.arange(0, 100), 0)
+     input_pos = torch.arange(0, 100)
+
+     quantized_model = ai_edge_torch.convert(
+         pytorch_model, (idx, input_pos), quant_config=quant_config
+     )
+     float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+     self.assertAlmostEqual(
+         len(quantized_model._tflite_model) / len(float_model._tflite_model),
+         expected_compression,
+         delta=0.01,
+     )
+
+   def test_quantize_convert_compare_toy(self):
+     self.skipTest("b/338288901")
+     config = toy_model_with_kv_cache.get_model_config()
+     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
+     idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+         [10], dtype=torch.int64
+     )
+
+     quant_config = quant_recipes.full_fp16_recipe()
+     quantized_model = ai_edge_torch.convert(
+         pytorch_model, (idx, input_pos), quant_config=quant_config
+     )
+     float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+
+     self.assertLess(len(quantized_model._tflite_model), len(float_model._tflite_model))
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             quantized_model,
+             pytorch_model,
+             (idx, input_pos),
+             num_valid_inputs=1,
+             atol=1e-3,
+             rtol=1e-3,
+         )
+     )
+
+
+ if __name__ == "__main__":
+   unittest.main()
ai_edge_torch/generative/utilities/__init__.py
@@ -0,0 +1,15 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # This module contains common utility functions.
ai_edge_torch/generative/utilities/loader.py
@@ -0,0 +1,328 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Common utility functions for data loading etc.
+ from dataclasses import dataclass
+ import glob
+ import os
+ from typing import Callable, Dict, List, Tuple
+
+ from safetensors import safe_open
+ import torch
+
+ from ai_edge_torch.generative.layers import model_config
+
+
+ def load_safetensors(full_path: str):
+   """Loads safetensors into a single state dictionary.
+
+   Args:
+     full_path (string): A directory that contains safetensors files, or the
+       path to a single safetensors file.
+
+   Returns:
+     A state dictionary containing the loaded tensors.
+
+   Raises:
+     ValueError: If no tensors are loaded from the provided directory or file.
+   """
+   pattern = (
+       os.path.join(full_path, "*.safetensors")
+       if os.path.isdir(full_path)
+       else full_path
+   )
+   files = glob.glob(pattern)
+
+   tensors = {}
+   for file in files:
+     with safe_open(file, framework="pt") as fp:
+       for k in fp.keys():
+         assert k not in tensors
+         tensors[k] = fp.get_tensor(k)
+
+   if not tensors:
+     raise ValueError("Failed to load SafeTensors.")
+   return tensors
+
+
+ def load_pytorch_statedict(full_path: str):
+   """Loads PyTorch state dictionary binaries into a single state dictionary.
+
+   Args:
+     full_path (string): A directory that contains .bin or .pt files, or the
+       path to a single such file.
+
+   Returns:
+     A state dictionary containing the loaded tensors.
+
+   Raises:
+     ValueError: If no tensors are loaded from the provided directory or file.
+   """
+   files = []
+   patterns = []
+   if os.path.isdir(full_path):
+     patterns.append(os.path.join(full_path, "*.bin"))
+     patterns.append(os.path.join(full_path, "*.pt"))
+   else:
+     patterns.append(full_path)
+   for pattern in patterns:
+     files.extend(glob.glob(pattern))
+
+   tensors = {}
+   for file in files:
+     this_file_tensors = torch.load(file)
+     for k in this_file_tensors:
+       assert k not in tensors
+     tensors.update(this_file_tensors)
+
+   if not tensors:
+     raise ValueError("Failed to load torch bin files.")
+   return tensors
+
+
+ class ModelLoader:
+   """A utility class for loading and converting model checkpoints to the
+   Edge Generative API layer format.
+   """
+
+   @dataclass
+   class TensorNames:
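+     # Each field holds a str.format() template for a tensor name in the
+     # source checkpoint; "{}" is filled with the layer index by the _map_*
+     # helpers (e.g. a hypothetical "model.layers.{}.self_attn.q_proj").
+     # Fields that do not apply to a given model can be left as None.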
+     attn_query_proj: str = None
+     attn_key_proj: str = None
+     attn_value_proj: str = None
+     attn_fused_qkv_proj: str = None
+     attn_output_proj: str = None
+
+     ff_up_proj: str = None
+     ff_down_proj: str = None
+     ff_gate_proj: str = None
+
+     pre_attn_norm: str = None
+     pre_ff_norm: str = None
+     embedding: str = None
+     embedding_position: str = None
+     final_norm: str = None
+     lm_head: str = None
+
+   def __init__(self, file_name: str, names: TensorNames) -> None:
+     """ModelLoader constructor. Can be used to load multiple models of the
+     same type.
+
+     Args:
+       file_name (str): Path to the checkpoint. Can be a directory or an
+         exact file.
+       names (TensorNames): An instance of `TensorNames` to determine mappings.
+     """
+     self._file_name = file_name
+     self._names = names
+     self._loader = self._get_loader()
+
+   def load(
+       self, model: torch.nn.Module, strict: bool = True
+   ) -> Tuple[List[str], List[str]]:
+     """Loads the model from the checkpoint.
+
+     Args:
+       model (torch.nn.Module): The PyTorch model to load the checkpoint into.
+       strict (bool, optional): Whether the converted keys are strictly
+         matched. Defaults to True.
+
+     Returns:
+       missing_keys (List[str]): a list of str containing the missing keys.
+       unexpected_keys (List[str]): a list of str containing the unexpected keys.
+
+     Raises:
+       ValueError: If conversion results in unmapped tensors and strict mode is
+         enabled.
+     """
+     state = self._loader(self._file_name)
+     converted_state = dict()
+     if self._names.embedding is not None:
+       converted_state["tok_embedding.weight"] = state.pop(
+           f"{self._names.embedding}.weight"
+       )
+     if self._names.embedding_position is not None:
+       converted_state["tok_embedding_position"] = state.pop(
+           f"{self._names.embedding_position}"
+       )
+     if self._names.lm_head is not None:
+       converted_state["lm_head.weight"] = state.pop(f"{self._names.lm_head}.weight")
+       if model.config.lm_head_use_bias:
+         converted_state["lm_head.bias"] = state.pop(f"{self._names.lm_head}.bias")
+     if self._names.final_norm is not None:
+       final_norm_name = self._names.final_norm
+       converted_state["final_norm.weight"] = state.pop(f"{final_norm_name}.weight")
+       if f"{final_norm_name}.bias" in state:
+         converted_state["final_norm.bias"] = state.pop(f"{final_norm_name}.bias")
+
+     for i in range(model.config.num_layers):
+       self._map_norm(i, model.config, state, converted_state)
+       self._map_feedforward(i, model.config, state, converted_state)
+       self._map_attention(i, model.config, state, converted_state)
+
+     if strict and state:
+       raise ValueError(
+           f"Failed to map all tensors. Remaining tensors are: {list(state.keys())}"
+       )
+     return model.load_state_dict(converted_state, strict=strict)
+
+   def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
+     """A best-effort method for finding the appropriate state loader.
+
+     Raises:
+       ValueError: If it fails to find an appropriate loader.
+
+     Returns:
+       Callable[[str], Dict[str, torch.Tensor]]: State loader to be used.
+     """
+     if os.path.isdir(self._file_name):
+       if glob.glob(os.path.join(self._file_name, "*.safetensors")):
+         return load_safetensors
+       if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
+           os.path.join(self._file_name, "*.pt")
+       ):
+         return load_pytorch_statedict
+
+     if self._file_name.endswith(".safetensors"):
+       return load_safetensors
+
+     if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
+       return load_pytorch_statedict
+
+     raise ValueError("File format not supported.")
+
+   def _map_feedforward(
+       self,
+       idx: int,
+       config: model_config.ModelConfig,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+   ):
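+     # For a SEQUENTIAL feed forward, the checkpoint's up/down projections map
+     # to ff.w1/ff.w2. Otherwise the feed forward is treated as gated: gate
+     # maps to ff.w1, down to ff.w2, and up to ff.w3.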
+     prefix = f"transformer_blocks.{idx}"
+     if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
+       ff_up_proj_name = self._names.ff_up_proj.format(idx)
+       ff_down_proj_name = self._names.ff_down_proj.format(idx)
+       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
+           f"{ff_down_proj_name}.weight"
+       )
+       if config.ff_config.use_bias:
+         converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_up_proj_name}.bias")
+         converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
+     else:
+       ff_up_proj_name = self._names.ff_up_proj.format(idx)
+       ff_down_proj_name = self._names.ff_down_proj.format(idx)
+       ff_gate_proj_name = self._names.ff_gate_proj.format(idx)
+       converted_state[f"{prefix}.ff.w3.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
+           f"{ff_down_proj_name}.weight"
+       )
+       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
+           f"{ff_gate_proj_name}.weight"
+       )
+       if config.ff_config.use_bias:
+         converted_state[f"{prefix}.ff.w3.bias"] = state.pop(f"{ff_up_proj_name}.bias")
+         converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
+         converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_gate_proj_name}.bias")
+
+   def _map_attention(
+       self,
+       idx: int,
+       config: model_config.ModelConfig,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+   ):
+     prefix = f"transformer_blocks.{idx}"
+     if self._names.attn_fused_qkv_proj:
+       fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
+       converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
+           f"{fused_qkv_name}.weight"
+       )
+     else:
+       q_name = self._names.attn_query_proj.format(idx)
+       k_name = self._names.attn_key_proj.format(idx)
+       v_name = self._names.attn_value_proj.format(idx)
+       converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
+           config,
+           state.pop(f"{q_name}.weight"),
+           state.pop(f"{k_name}.weight"),
+           state.pop(f"{v_name}.weight"),
+       )
+     if config.attn_config.qkv_use_bias:
+       if self._names.attn_fused_qkv_proj:
+         converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
+             f"{fused_qkv_name}.bias"
+         )
+       else:
+         converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
+             config,
+             state.pop(f"{q_name}.bias"),
+             state.pop(f"{k_name}.bias"),
+             state.pop(f"{v_name}.bias"),
+         )
+
+     o_name = self._names.attn_output_proj.format(idx)
+     converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
+         f"{o_name}.weight"
+     )
+     if config.attn_config.output_proj_use_bias:
+       converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
+           f"{o_name}.bias"
+       )
+
+   def _map_norm(
+       self,
+       idx: int,
+       config: model_config.ModelConfig,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+   ):
+     prefix = f"transformer_blocks.{idx}"
+     if self._names.pre_attn_norm is not None:
+       pre_attn_norm_name = self._names.pre_attn_norm.format(idx)
+       converted_state[f"{prefix}.pre_atten_norm.weight"] = state.pop(
+           f"{pre_attn_norm_name}.weight"
+       )
+       if f"{pre_attn_norm_name}.bias" in state:
+         converted_state[f"{prefix}.pre_atten_norm.bias"] = state.pop(
+             f"{pre_attn_norm_name}.bias"
+         )
+
+     if self._names.pre_ff_norm is not None:
+       pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
+       converted_state[f"{prefix}.pre_ff_norm.weight"] = state.pop(
+           f"{pre_ff_norm_name}.weight"
+       )
+       if f"{pre_ff_norm_name}.bias" in state:
+         converted_state[f"{prefix}.pre_ff_norm.bias"] = state.pop(
+             f"{pre_ff_norm_name}.bias"
+         )
+
+   def _fuse_qkv(
+       self,
+       config: model_config.ModelConfig,
+       q: torch.Tensor,
+       k: torch.Tensor,
+       v: torch.Tensor,
+   ) -> torch.Tensor:
+     if config.attn_config.qkv_fused_interleaved:
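+       # Interleaved fused-QKV layout: q is split into per-group chunks of
+       # q_per_kv query heads, k and v into per-head chunks, and the chunks
+       # are concatenated group by group as [q_0, k_0, v_0, q_1, k_1, v_1, ...],
+       # matching checkpoints that store QKV grouped by KV head.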
+       q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
+       qs = torch.split(q, config.head_dim * q_per_kv)
+       ks = torch.split(k, config.head_dim)
+       vs = torch.split(v, config.head_dim)
+       cycled = [t for group in zip(qs, ks, vs) for t in group]
+       return torch.cat(cycled)
+     else:
+       return torch.cat([q, k, v], dim=0)
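
For reference, a minimal sketch of how this loader is typically wired up. The TensorNames templates and the model object below are illustrative assumptions, not APIs confirmed by this diff; real mappings live alongside each example model:

    from ai_edge_torch.generative.utilities import loader as loading_utils

    # Hypothetical LLaMA-style checkpoint layout; "{}" is the layer index.
    TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
        attn_query_proj="model.layers.{}.self_attn.q_proj",
        attn_key_proj="model.layers.{}.self_attn.k_proj",
        attn_value_proj="model.layers.{}.self_attn.v_proj",
        attn_output_proj="model.layers.{}.self_attn.o_proj",
        ff_up_proj="model.layers.{}.mlp.up_proj",
        ff_down_proj="model.layers.{}.mlp.down_proj",
        ff_gate_proj="model.layers.{}.mlp.gate_proj",
        pre_attn_norm="model.layers.{}.input_layernorm",
        pre_ff_norm="model.layers.{}.post_attention_layernorm",
        embedding="model.embed_tokens",
        final_norm="model.norm",
        lm_head="lm_head",
    )

    model = build_decoder(config)  # hypothetical nn.Module assembled from generative.layers
    loader = loading_utils.ModelLoader("/path/to/checkpoint_dir", TENSOR_NAMES)
    missing, unexpected = loader.load(model, strict=True)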