ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (91) hide show
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
@@ -0,0 +1,290 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Common utility functions for data loading etc.
16
+ from dataclasses import dataclass
17
+ import glob
18
+ import os
19
+ from typing import Callable, Dict
20
+
21
+ from safetensors import safe_open
22
+ import torch
23
+
24
+ from ai_edge_torch.generative.layers import model_config
25
+
26
+
27
def load_safetensors(full_path: str):
    """Loads safetensors into a single state dictionary.

    Args:
        full_path (string): a directory that contains `*.safetensors` files,
            or the path of a single file. A non-directory value is passed to
            `glob.glob` as-is, so it may also be a glob pattern.

    Returns:
        A state dictionary containing the loaded tensors.

    Raises:
        ValueError: If no tensors are loaded from the provided directory or
            file, or if the same tensor name appears in more than one file.
    """
    pattern = (
        os.path.join(full_path, "*.safetensors")
        if os.path.isdir(full_path)
        else full_path
    )

    tensors = {}
    for file in glob.glob(pattern):
        with safe_open(file, framework="pt") as fp:
            for k in fp.keys():
                # Raise explicitly instead of using `assert`: asserts are
                # removed under `python -O`, which would let a later file
                # silently overwrite a tensor loaded from an earlier one.
                if k in tensors:
                    raise ValueError(
                        f"Duplicate tensor name '{k}' found in '{file}'."
                    )
                tensors[k] = fp.get_tensor(k)

    if not tensors:
        raise ValueError("Failed to load SafeTensors.")
    return tensors
58
+
59
+
60
def load_pytorch_statedict(full_path: str):
    """Loads state dictionary binaries into a single state dictionary.

    Args:
        full_path (string): a directory that contains `*.bin` files, or the
            path of a single file. A non-directory value is passed to
            `glob.glob` as-is, so it may also be a glob pattern.

    Returns:
        A state dictionary containing the loaded tensors.

    Raises:
        ValueError: If no tensors are loaded from the provided directory or
            file, or if the same tensor name appears in more than one file.
    """
    pattern = (
        os.path.join(full_path, "*.bin")
        if os.path.isdir(full_path)
        else full_path
    )

    tensors = {}
    for file in glob.glob(pattern):
        # NOTE(security): `torch.load` is pickle-based; only point this at
        # checkpoints from a trusted source.
        # Map to CPU so a checkpoint saved from an accelerator can still be
        # read on a host that does not have that device.
        this_file_tensors = torch.load(file, map_location=torch.device("cpu"))

        # Raise explicitly instead of using `assert`: asserts are removed
        # under `python -O`, which would let `update` silently overwrite
        # tensors loaded from an earlier file.
        duplicates = [k for k in this_file_tensors if k in tensors]
        if duplicates:
            raise ValueError(
                f"Duplicate tensor name(s) {duplicates} found in '{file}'."
            )
        tensors.update(this_file_tensors)

    if not tensors:
        raise ValueError("Failed to load torch bin files.")
    return tensors
87
+
88
+
89
class ModelLoader:
    """A utility class for loading and converting model checkpoints to the
    Edge Generative API layer format.
    """

    @dataclass
    class TensorNames:
        # Checkpoint tensor-name prefixes (without the ".weight" / ".bias"
        # suffix, which the loader appends itself).
        #
        # The attention / feed-forward / pre-norm entries are per-layer
        # templates: the loader calls `.format(layer_index)` on them, so they
        # are expected to contain a `{}` placeholder for the layer index.
        attn_query_proj: str
        attn_key_proj: str
        attn_value_proj: str
        attn_output_proj: str

        ff_up_proj: str
        ff_down_proj: str
        # Only read when the feed-forward type is not SEQUENTIAL.
        ff_gate_proj: str = None

        # Optional per-layer norms; skipped when left as None.
        pre_attn_norm: str = None
        pre_ff_norm: str = None
        # Optional model-level tensors; used verbatim (no layer index) and
        # skipped when left as None.
        embedding: str = None
        final_norm: str = None
        lm_head: str = None

    def __init__(self, file_name: str, names: TensorNames) -> None:
        """ModelLoader constructor. Can be used to load multiple models of the
        same type.

        Args:
            file_name (str): Path to the checkpoint. Can be a directory or an
                exact file.
            names (TensorNames): An instance of `TensorNames` to determine
                mappings.

        Raises:
            ValueError: If no loader matches `file_name` (see `_get_loader`).
        """
        self._file_name = file_name
        self._names = names
        self._loader = self._get_loader()

    def load(self, model: torch.nn.Module, strict: bool = True):
        """Load the model from the checkpoint.

        Tensors are popped out of the loaded checkpoint as they are mapped, so
        anything left over afterwards is a tensor that had no mapping.

        Args:
            model (torch.nn.Module): The pytorch model that needs to be loaded.
                Its `config` attribute drives which tensors are mapped.
            strict (bool, optional): Whether the converted keys are strictly
                matched. Defaults to True.

        Raises:
            ValueError: If conversion results in unmapped tensors and strict
                mode is enabled.
            KeyError: If a tensor the mapping requires is missing from the
                checkpoint.
        """
        state = self._loader(self._file_name)
        converted_state = dict()
        if self._names.embedding is not None:
            converted_state["tok_embedding.weight"] = state.pop(
                f"{self._names.embedding}.weight"
            )
        if self._names.lm_head is not None:
            converted_state["lm_head.weight"] = state.pop(
                f"{self._names.lm_head}.weight"
            )
            if model.config.lm_head_use_bias:
                converted_state["lm_head.bias"] = state.pop(
                    f"{self._names.lm_head}.bias"
                )
        if self._names.final_norm is not None:
            final_norm_name = self._names.final_norm
            converted_state["final_norm.weight"] = state.pop(
                f"{final_norm_name}.weight"
            )
            # The norm bias is optional: copy it only if the checkpoint has it.
            if f"{final_norm_name}.bias" in state:
                converted_state["final_norm.bias"] = state.pop(
                    f"{final_norm_name}.bias"
                )

        for i in range(model.config.num_layers):
            self._map_norm(i, model.config, state, converted_state)
            self._map_feedforward(i, model.config, state, converted_state)
            self._map_attention(i, model.config, state, converted_state)

        if strict and state:
            raise ValueError(
                f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
            )
        model.load_state_dict(converted_state, strict=strict)

    def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
        """A best effort method for finding appropriate state loader.

        For a directory, safetensors files take precedence over bin files.
        Otherwise the choice is made from the file-name suffix.

        Raises:
            ValueError: If it fails to find an appropriate loader.

        Returns:
            Callable[[str], Dict[str, torch.Tensor]]: State loader to be used.
        """
        if os.path.isdir(self._file_name):
            if glob.glob(os.path.join(self._file_name, "*.safetensors")):
                return load_safetensors
            if glob.glob(os.path.join(self._file_name, "*.bin")):
                return load_pytorch_statedict

        if self._file_name.endswith(".safetensors"):
            return load_safetensors

        if self._file_name.endswith(".bin"):
            return load_pytorch_statedict

        raise ValueError("File format not supported.")

    def _map_feedforward(
        self,
        idx: int,
        config: model_config.ModelConfig,
        state: Dict[str, torch.Tensor],
        converted_state: Dict[str, torch.Tensor],
    ):
        """Moves the feed-forward tensors of layer `idx` into `converted_state`.

        SEQUENTIAL feed-forward: up -> w1, down -> w2.
        Any other type (gated): up -> w3, down -> w2, gate -> w1.

        Args:
            idx (int): Index of the transformer block being mapped.
            config (model_config.ModelConfig): Model configuration; its
                `ff_config.type` and `ff_config.use_bias` are read.
            state (Dict[str, torch.Tensor]): Source checkpoint; mapped entries
                are popped from it.
            converted_state (Dict[str, torch.Tensor]): Destination state dict.

        Raises:
            ValueError: If a gated feed-forward is configured but
                `TensorNames.ff_gate_proj` was not provided.
        """
        prefix = f"transformer_blocks.{idx}"
        ff_up_proj_name = self._names.ff_up_proj.format(idx)
        ff_down_proj_name = self._names.ff_down_proj.format(idx)

        # (destination layer name, source tensor-name prefix) pairs.
        if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
            mapping = [("w1", ff_up_proj_name), ("w2", ff_down_proj_name)]
        else:
            if self._names.ff_gate_proj is None:
                # Fail with a clear message instead of an AttributeError on
                # `None.format(...)`.
                raise ValueError(
                    "TensorNames.ff_gate_proj must be set for a gated"
                    " feed-forward configuration."
                )
            ff_gate_proj_name = self._names.ff_gate_proj.format(idx)
            mapping = [
                ("w3", ff_up_proj_name),
                ("w2", ff_down_proj_name),
                ("w1", ff_gate_proj_name),
            ]

        for target, source in mapping:
            converted_state[f"{prefix}.ff.{target}.weight"] = state.pop(
                f"{source}.weight"
            )
        if config.ff_config.use_bias:
            for target, source in mapping:
                converted_state[f"{prefix}.ff.{target}.bias"] = state.pop(
                    f"{source}.bias"
                )

    def _map_attention(
        self,
        idx: int,
        config: model_config.ModelConfig,
        state: Dict[str, torch.Tensor],
        converted_state: Dict[str, torch.Tensor],
    ):
        """Moves the attention tensors of layer `idx` into `converted_state`.

        The separate query / key / value projections are fused into a single
        `atten_func.attn` tensor (see `_fuse_qkv`); the output projection is
        copied as `atten_func.proj`.

        Args:
            idx (int): Index of the transformer block being mapped.
            config (model_config.ModelConfig): Model configuration; its
                `attn_config.qkv_use_bias` and
                `attn_config.output_proj_use_bias` are read.
            state (Dict[str, torch.Tensor]): Source checkpoint; mapped entries
                are popped from it.
            converted_state (Dict[str, torch.Tensor]): Destination state dict.
        """
        prefix = f"transformer_blocks.{idx}"
        q_name = self._names.attn_query_proj.format(idx)
        k_name = self._names.attn_key_proj.format(idx)
        v_name = self._names.attn_value_proj.format(idx)
        converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
            config,
            state.pop(f"{q_name}.weight"),
            state.pop(f"{k_name}.weight"),
            state.pop(f"{v_name}.weight"),
        )
        if config.attn_config.qkv_use_bias:
            converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
                config,
                state.pop(f"{q_name}.bias"),
                state.pop(f"{k_name}.bias"),
                state.pop(f"{v_name}.bias"),
            )

        o_name = self._names.attn_output_proj.format(idx)
        converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(
            f"{o_name}.weight"
        )
        if config.attn_config.output_proj_use_bias:
            converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(
                f"{o_name}.bias"
            )

    def _map_norm(
        self,
        idx: int,
        config: model_config.ModelConfig,
        state: Dict[str, torch.Tensor],
        converted_state: Dict[str, torch.Tensor],
    ):
        """Moves the pre-attention / pre-feed-forward norm tensors of layer
        `idx` into `converted_state`.

        A norm is skipped entirely when its name was not provided. Its bias is
        copied only when present in the checkpoint.

        Args:
            idx (int): Index of the transformer block being mapped.
            config (model_config.ModelConfig): Unused here; kept so all
                `_map_*` helpers share one signature.
            state (Dict[str, torch.Tensor]): Source checkpoint; mapped entries
                are popped from it.
            converted_state (Dict[str, torch.Tensor]): Destination state dict.
        """
        prefix = f"transformer_blocks.{idx}"
        if self._names.pre_attn_norm is not None:
            pre_attn_norm_name = self._names.pre_attn_norm.format(idx)
            converted_state[f"{prefix}.pre_atten_norm.weight"] = state.pop(
                f"{pre_attn_norm_name}.weight"
            )
            if f"{pre_attn_norm_name}.bias" in state:
                converted_state[f"{prefix}.pre_atten_norm.bias"] = state.pop(
                    f"{pre_attn_norm_name}.bias"
                )

        if self._names.pre_ff_norm is not None:
            pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
            converted_state[f"{prefix}.pre_ff_norm.weight"] = state.pop(
                f"{pre_ff_norm_name}.weight"
            )
            if f"{pre_ff_norm_name}.bias" in state:
                converted_state[f"{prefix}.pre_ff_norm.bias"] = state.pop(
                    f"{pre_ff_norm_name}.bias"
                )

    def _fuse_qkv(
        self,
        config: model_config.ModelConfig,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Interleaves query, key and value tensors into one fused tensor.

        Each tensor is split along dim 0: `q` into chunks of
        `head_dim * q_per_kv`, `k` and `v` into chunks of `head_dim`. The
        chunks are then concatenated along dim 0 in the order
        q0, k0, v0, q1, k1, v1, ...

        Args:
            config (model_config.ModelConfig): Model configuration; `head_dim`,
                `attn_config.num_heads` and `attn_config.num_query_groups` are
                read.
            q (torch.Tensor): Query projection weight or bias.
            k (torch.Tensor): Key projection weight or bias.
            v (torch.Tensor): Value projection weight or bias.

        Returns:
            torch.Tensor: The fused tensor.
        """
        # Number of query heads that share one key/value group.
        q_per_kv = (
            config.attn_config.num_heads // config.attn_config.num_query_groups
        )
        qs = torch.split(q, config.head_dim * q_per_kv)
        ks = torch.split(k, config.head_dim)
        vs = torch.split(v, config.head_dim)
        # `zip` stops at the shortest sequence, so the three splits are
        # presumably expected to yield the same number of chunks.
        cycled = [t for group in zip(qs, ks, vs) for t in group]
        return torch.cat(cycled)