ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
ai_edge_torch/generative/utilities/t5_loader.py (new file)

@@ -0,0 +1,483 @@

```python
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Common utility functions for data loading etc.
from dataclasses import dataclass
import glob
import os
from typing import Callable, Dict

from safetensors import safe_open
import torch

from ai_edge_torch.generative.layers import model_config


def load_safetensors(full_path: str):
  """Loads safetensors into a single state dictionary.

  Args:
    full_path (string): the safetensors filename, or a directory that contains
      the safetensors files.

  Returns:
    A state dictionary containing the loaded tensors.

  Raises:
    ValueError: If no tensors are loaded from the provided directory or file.
  """
  pattern = (
      os.path.join(full_path, "*.safetensors")
      if os.path.isdir(full_path)
      else full_path
  )
  files = []
  for file in glob.glob(pattern):
    files.append(file)

  tensors = {}
  for file in files:
    with safe_open(file, framework="pt") as fp:
      for k in fp.keys():
        assert k not in tensors
        tensors[k] = fp.get_tensor(k)

  if not tensors:
    raise ValueError("Failed to load SafeTensors.")
  return tensors


def load_pytorch_statedict(full_path: str):
  """Loads state dictionary binaries into a single state dictionary.

  Args:
    full_path (string): the bin filename, or a directory that contains the bin
      files.

  Returns:
    A state dictionary containing the loaded tensors.

  Raises:
    ValueError: If no tensors are loaded from the provided directory or file.
  """
  pattern = os.path.join(full_path, "*.bin") if os.path.isdir(full_path) else full_path
  files = []
  for file in glob.glob(pattern):
    files.append(file)

  tensors = {}
  for file in files:
    this_file_tensors = torch.load(file, map_location=torch.device("cpu"))
    for k in this_file_tensors:
      assert k not in tensors
    tensors.update(this_file_tensors)

  if not tensors:
    raise ValueError("Failed to load torch bin files.")
  return tensors
```
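Both helpers return a flat `Dict[str, torch.Tensor]` keyed by the checkpoint's original tensor names, and both accept either a single file or a directory of shards. A minimal usage sketch (the paths are hypothetical):

```python
# Hypothetical checkpoint locations; substitute your own.
state = load_safetensors("/tmp/t5/model.safetensors")  # single safetensors file
state = load_pytorch_statedict("/tmp/t5_sharded/")     # directory of *.bin shards
print(len(state), next(iter(state)))
```

The file continues with the `ModelLoader` class, whose nested `TensorNames` dataclass describes where each ODML transformer tensor lives in the source checkpoint: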
```python
class ModelLoader:
  """A utility class for loading and converting model checkpoints to ODML
  transformer layer format.
  """

  @dataclass
  class TensorNames:
    attn_query_proj: str = None
    attn_key_proj: str = None
    attn_value_proj: str = None
    attn_output_proj: str = None
    relative_attn_bias: str = None

    cross_attn_query_proj: str = None
    cross_attn_key_proj: str = None
    cross_attn_value_proj: str = None
    cross_attn_output_proj: str = None

    ff_up_proj: str = None
    ff_down_proj: str = None
    ff_gate_proj: str = None

    pre_attn_norm: str = None
    pre_cross_attn_norm: str = None
    pre_ff_norm: str = None
    embedding: str = None
    final_norm: str = None
    lm_head: str = None
```
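Each field is a format string for the corresponding tensor name in the source checkpoint; `{}` stands in for the layer index, which the mapping methods below fill via `.format(idx)`. A sketch with hypothetical T5-style encoder names:

```python
# Hypothetical source-checkpoint names; the real ones depend on the model.
TENSOR_NAMES = ModelLoader.TensorNames(
    attn_query_proj="encoder.block.{}.layer.0.SelfAttention.q",
    attn_key_proj="encoder.block.{}.layer.0.SelfAttention.k",
    attn_value_proj="encoder.block.{}.layer.0.SelfAttention.v",
    attn_output_proj="encoder.block.{}.layer.0.SelfAttention.o",
    pre_attn_norm="encoder.block.{}.layer.0.layer_norm",
    embedding="shared",
    final_norm="encoder.final_layer_norm",
)
```

The constructor and the public `load` entry point follow: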
```python
  def __init__(self, file_name: str, names: TensorNames) -> None:
    """ModelLoader constructor. Can be used to load multiple models of the same
    type.

    Args:
      file_name (str): Path to the checkpoint. Can be a directory or an
        exact file.
      names (TensorNames): An instance of `TensorNames` to determine mappings.
    """
    self._file_name = file_name
    self._names = names
    self._loader = self._get_loader()

  def load(
      self, model: torch.nn.Module, strict: bool = True, fuse_attention: bool = True
  ):
    """Load the model from the checkpoint.

    Args:
      model (torch.nn.Module): The pytorch model that needs to be loaded.
      strict (bool, optional): Whether the converted keys are strictly
        matched. Defaults to True.
      fuse_attention (bool, optional): Whether to fuse the Q, K and V
        projections into a single QKV tensor. Defaults to True.

    Raises:
      ValueError: If conversion results in unmapped tensors and strict mode is
        enabled.
    """
    state = self._loader(self._file_name)

    if isinstance(self._names, ModelLoader.TensorNames):
      converted_state = self._do_load(
          model, state, self._names, fuse_attention=fuse_attention
      )
    elif isinstance(self._names, dict):
      converted_state = {}
      for additional_prefix, names in self._names.items():
        local_converted_state = self._do_load(
            model,
            state,
            self._names[additional_prefix],
            additional_prefix,
            fuse_attention=fuse_attention,
        )
        converted_state.update(local_converted_state)
    else:
      raise ValueError(f"Unknown type for names: {type(self._names)}")

    if strict and state:
      raise ValueError(
          f"Failed to map all tensors. Remaining tensors are: {list(state.keys())}"
      )
    model.load_state_dict(converted_state, strict=strict)
```
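`load` supports either a single `TensorNames` or a dict mapping a submodule prefix to a `TensorNames` (e.g. separate encoder and decoder specs). A minimal end-to-end sketch, assuming a hypothetical model built from these ODML transformer layers whose `config` exposes `num_layers`, `attn_config`, and `ff_config`:

```python
# MyT5Encoder is hypothetical; TENSOR_NAMES as sketched earlier.
model = MyT5Encoder(my_config)
loader = ModelLoader("/tmp/t5_checkpoint/", TENSOR_NAMES)
loader.load(model, strict=False)  # strict=False tolerates unmapped extras
```

The per-checkpoint conversion itself happens in `_do_load`: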
```python
  def _do_load(self, model, state, names, additional_prefix="", fuse_attention=True):
    """Load the model from the checkpoint.

    Args:
      model (torch.nn.Module): The pytorch model that needs to be loaded.
      state (Dict[str, torch.Tensor]): The pytorch state dictionary.
      names (TensorNames): The TensorNames for the model we are loading.

    Returns:
      Dict[str, torch.Tensor]: Map of name to tensor for loading.
    """
    converted_state = dict()
    if names.embedding is not None:
      converted_state["tok_embedding.weight"] = state.pop(f"{names.embedding}.weight")
    if names.lm_head is not None:
      converted_state["lm_head.weight"] = state.pop(f"{names.lm_head}.weight")
      if model.config.lm_head_use_bias:
        converted_state["lm_head.bias"] = state.pop(f"{names.lm_head}.bias")
    if names.final_norm is not None:
      final_norm_name = names.final_norm
      prefix = additional_prefix
      converted_state[f"{prefix}final_norm.weight"] = state.pop(
          f"{final_norm_name}.weight"
      )
      if f"{final_norm_name}.bias" in state:
        converted_state["final_norm.bias"] = state.pop(f"{final_norm_name}.bias")

    if names.relative_attn_bias:
      rel_attn_name = names.relative_attn_bias
      prefix = additional_prefix + "transformer_blocks.0"
      converted_state[f"{prefix}.atten_func.relative_attention_bias.weight"] = (
          state.pop(f"{rel_attn_name}.weight")
      )

    for i in range(model.config.num_layers):
      self._map_norm(i, model.config, state, converted_state, names, additional_prefix)
      self._map_feedforward(
          i, model.config, state, converted_state, names, additional_prefix
      )
      self._map_attention(
          i,
          model.config,
          state,
          converted_state,
          names,
          additional_prefix,
          fuse_attention,
      )
      self._map_cross_attention(
          i,
          model.config,
          state,
          converted_state,
          names,
          additional_prefix,
          fuse_attention,
      )

    return converted_state
```
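For one transformer block, `_do_load` rewrites source keys into ODML layer keys; the translation below is illustrative (unfused attention, SEQUENTIAL feed forward):

```python
# Illustrative key translation for block 0 with fuse_attention=False:
#   attn_query_proj.format(0) + ".weight" -> "transformer_blocks.0.atten_func.q_projection.weight"
#   ff_up_proj.format(0) + ".weight"      -> "transformer_blocks.0.ff.w1.weight"
#   pre_attn_norm.format(0) + ".weight"   -> "transformer_blocks.0.atten_func.pre_atten_norm.weight"
```

The remaining private helpers select the loader and perform the per-layer mappings: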
```python
  def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
    """A best effort method for finding the appropriate state loader.

    Raises:
      ValueError: If it fails to find an appropriate loader.

    Returns:
      Callable[[str], Dict[str, torch.Tensor]]: State loader to be used.
    """
    if os.path.isdir(self._file_name):
      if glob.glob(os.path.join(self._file_name, "*.safetensors")):
        return load_safetensors
      if glob.glob(os.path.join(self._file_name, "*.bin")):
        return load_pytorch_statedict

    if self._file_name.endswith(".safetensors"):
      return load_safetensors

    if self._file_name.endswith(".bin"):
      return load_pytorch_statedict

    raise ValueError("File format not supported.")

  def _map_feedforward(
      self,
      idx: int,
      config: model_config.ModelConfig,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      names: TensorNames,
      additional_prefix: str = "",
  ):
    prefix = additional_prefix + f"transformer_blocks.{idx}"
    if names.ff_up_proj is None or names.ff_down_proj is None:
      return
    if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
      ff_up_proj_name = names.ff_up_proj.format(idx)
      ff_down_proj_name = names.ff_down_proj.format(idx)
      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(f"{ff_up_proj_name}.weight")
      converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
          f"{ff_down_proj_name}.weight"
      )
      if config.ff_config.use_bias:
        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_up_proj_name}.bias")
        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
    else:
      if names.ff_gate_proj is not None:
        ff_up_proj_name = names.ff_up_proj.format(idx)
        ff_down_proj_name = names.ff_down_proj.format(idx)
        ff_gate_proj_name = names.ff_gate_proj.format(idx)
        converted_state[f"{prefix}.ff.w3.weight"] = state.pop(
            f"{ff_up_proj_name}.weight"
        )
        converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
            f"{ff_down_proj_name}.weight"
        )
        converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
            f"{ff_gate_proj_name}.weight"
        )
        if config.ff_config.use_bias:
          converted_state[f"{prefix}.ff.w3.bias"] = state.pop(f"{ff_up_proj_name}.bias")
          converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
              f"{ff_down_proj_name}.bias"
          )
          converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
              f"{ff_gate_proj_name}.bias"
          )

  def _map_attention(
      self,
      idx: int,
      config: model_config.ModelConfig,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      names: TensorNames,
      additional_prefix: str = "",
      fuse_attention: bool = True,
  ):
    if (
        names.attn_query_proj is None
        or names.attn_key_proj is None
        or names.attn_value_proj is None
    ):
      return
    prefix = additional_prefix + f"transformer_blocks.{idx}"
    q_name = names.attn_query_proj.format(idx)
    k_name = names.attn_key_proj.format(idx)
    v_name = names.attn_value_proj.format(idx)
    # model.encoder.transformer_blocks[0].atten_func.q_projection.weight
    if fuse_attention:
      converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
          config,
          state.pop(f"{q_name}.weight"),
          state.pop(f"{k_name}.weight"),
          state.pop(f"{v_name}.weight"),
      )
      if config.attn_config.qkv_use_bias:
        converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
            config,
            state.pop(f"{q_name}.bias"),
            state.pop(f"{k_name}.bias"),
            state.pop(f"{v_name}.bias"),
        )
    else:
      converted_state[f"{prefix}.atten_func.q_projection.weight"] = state.pop(
          f"{q_name}.weight"
      )
      converted_state[f"{prefix}.atten_func.k_projection.weight"] = state.pop(
          f"{k_name}.weight"
      )
      converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
          f"{v_name}.weight"
      )
      if config.attn_config.qkv_use_bias:
        converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
            f"{q_name}.bias"
        )
        converted_state[f"{prefix}.atten_func.k_projection.bias"] = state.pop(
            f"{k_name}.bias"
        )
        converted_state[f"{prefix}.atten_func.v_projection.bias"] = state.pop(
            f"{v_name}.bias"
        )

    o_name = names.attn_output_proj.format(idx)
    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
        f"{o_name}.weight"
    )
    if config.attn_config.output_proj_use_bias:
      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
          f"{o_name}.bias"
      )

  def _map_cross_attention(
      self,
      idx: int,
      config: model_config.ModelConfig,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      names: TensorNames,
      additional_prefix: str = "",
      fuse_attention: bool = True,
  ):
    if (
        names.cross_attn_query_proj is None
        or names.cross_attn_key_proj is None
        or names.cross_attn_value_proj is None
    ):
      return
    prefix = additional_prefix + f"transformer_blocks.{idx}"
    q_name = names.cross_attn_query_proj.format(idx)
    k_name = names.cross_attn_key_proj.format(idx)
    v_name = names.cross_attn_value_proj.format(idx)

    if fuse_attention:
      converted_state[f"{prefix}.cross_atten_func.attn.weight"] = self._fuse_qkv(
          config,
          state.pop(f"{q_name}.weight"),
          state.pop(f"{k_name}.weight"),
          state.pop(f"{v_name}.weight"),
      )
      if config.attn_config.qkv_use_bias:
        converted_state[f"{prefix}.cross_atten_func.attn.bias"] = self._fuse_qkv(
            config,
            state.pop(f"{q_name}.bias"),
            state.pop(f"{k_name}.bias"),
            state.pop(f"{v_name}.bias"),
        )
    else:
      converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = state.pop(
          f"{q_name}.weight"
      )
      converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = state.pop(
          f"{k_name}.weight"
      )
      converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = state.pop(
          f"{v_name}.weight"
      )
      if config.attn_config.qkv_use_bias:
        converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = state.pop(
            f"{q_name}.bias"
        )
        converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = state.pop(
            f"{k_name}.bias"
        )
        converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = state.pop(
            f"{v_name}.bias"
        )

    o_name = names.cross_attn_output_proj.format(idx)
    converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = state.pop(
        f"{o_name}.weight"
    )
    if config.attn_config.output_proj_use_bias:
      converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = state.pop(
          f"{o_name}.bias"
      )

  def _map_norm(
      self,
      idx: int,
      config: model_config.ModelConfig,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      names: TensorNames,
      additional_prefix: str = "",
  ):
    prefix = additional_prefix + f"transformer_blocks.{idx}"
    if names.pre_attn_norm is not None:
      pre_attn_norm_name = names.pre_attn_norm.format(idx)
      converted_state[f"{prefix}.atten_func.pre_atten_norm.weight"] = state.pop(
          f"{pre_attn_norm_name}.weight"
      )
      if f"{pre_attn_norm_name}.bias" in state:
        converted_state[f"{prefix}.atten_func.pre_atten_norm.bias"] = state.pop(
            f"{pre_attn_norm_name}.bias"
        )

    if names.pre_cross_attn_norm:
      pre_cross_attn_norm_name = names.pre_cross_attn_norm.format(idx)
      converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.weight"] = state.pop(
          f"{pre_cross_attn_norm_name}.weight"
      )
      if f"{pre_cross_attn_norm_name}.bias" in state:
        converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.bias"] = state.pop(
            f"{pre_cross_attn_norm_name}.bias"
        )

    if names.pre_ff_norm is not None:
      pre_ff_norm_name = names.pre_ff_norm.format(idx)
      converted_state[f"{prefix}.pre_ff_norm.weight"] = state.pop(
          f"{pre_ff_norm_name}.weight"
      )
      if f"{pre_ff_norm_name}.bias" in state:
        converted_state[f"{prefix}.pre_ff_norm.bias"] = state.pop(
            f"{pre_ff_norm_name}.bias"
        )

  def _fuse_qkv(
      self,
      config: model_config.ModelConfig,
      q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
  ) -> torch.Tensor:
    q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
    qs = torch.split(q, config.head_dim * q_per_kv)
    ks = torch.split(k, config.head_dim)
    vs = torch.split(v, config.head_dim)
    cycled = [t for group in zip(qs, ks, vs) for t in group]
    return torch.cat(cycled)
```
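`_fuse_qkv` interleaves the per-group Q, K and V slices so that a single fused projection emits heads in the order the ODML attention layer expects. A tiny numeric sketch with hypothetical config values (2 query groups, `head_dim=1`, so `q_per_kv=2`):

```python
import torch

q = torch.tensor([1., 2., 3., 4.])  # two query rows per group
k = torch.tensor([10., 20.])        # one key row per group
v = torch.tensor([100., 200.])      # one value row per group

qs = torch.split(q, 2)  # head_dim * q_per_kv = 2
ks = torch.split(k, 1)
vs = torch.split(v, 1)
fused = torch.cat([t for group in zip(qs, ks, vs) for t in group])
print(fused)  # tensor([  1.,   2.,  10., 100.,   3.,   4.,  20., 200.])
```

The next hunk is a one-line re-export: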
ai_edge_torch/hlfb/__init__.py (new file)

@@ -0,0 +1,16 @@

```python
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
```
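This makes torch_xla's composite builder available as `ai_edge_torch.hlfb.StableHLOCompositeBuilder`. A minimal sketch of the usual pattern (the composite name here is illustrative):

```python
import torch
import torch.nn.functional as F
from ai_edge_torch.hlfb import StableHLOCompositeBuilder

class AttentionBlock(torch.nn.Module):
  def forward(self, q, k, v):
    # Wrapping the region in mark_inputs/mark_outputs lowers it as a single
    # StableHLO composite op with the given name during conversion.
    builder = StableHLOCompositeBuilder(name="odml.scaled_dot_product_attention")
    q, k, v = builder.mark_inputs(q, k, v)
    out = F.scaled_dot_product_attention(q, k, v)
    out = builder.mark_outputs(out)
    return out
```

The third hunk implements pattern-based marking over fx graphs: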
ai_edge_torch/hlfb/mark_pattern/__init__.py (new file)

@@ -0,0 +1,139 @@

```python
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
from typing import Any
import uuid

import torch
from torch_xla.experimental import xla_marker

from ai_edge_torch.hlfb.mark_pattern import passes
from ai_edge_torch.hlfb.mark_pattern.pattern import Pattern
from ai_edge_torch.hlfb.mark_pattern.pattern import ScalarAttrTracker  # NOQA


@torch._dynamo.assume_constant_result
def _get_uuid() -> str:
  return uuid.uuid4().hex


# TODO: Move to a general fx utils file.
def _prepose_placeholder_nodes(graph: torch.fx.Graph):
  nodes = [node for node in graph.nodes if node.op == "placeholder"] + [
      node for node in graph.nodes if node.op != "placeholder"
  ]

  for a, b in zip(nodes, nodes[1:]):
    if a.next is not b:
      a.append(b)
  return graph


def _insert_marker(
    graph_module: torch.fx.GraphModule,
    node: torch.fx.Node,
    name: str,
    pos: int,
    id: str,
    is_input: bool,
    attr: dict[str, Any] = None,
):
  attr = xla_marker.serialize_composite_attr(attr) if attr else None
  with graph_module.graph.inserting_after(node):
    new_node = graph_module.graph.call_function(
        torch.ops.xla.mark_tensor,
        args=(node,),
        kwargs={
            "name": name,
            "pos": pos,
            "id": id,
            "is_input": is_input,
            "attr": attr,
        },
    )

  new_node.meta = node.meta
  return new_node


def mark_pattern(
    graph_module: torch.fx.GraphModule,
    pattern: Pattern,
) -> torch.fx.GraphModule:
  """Marks all occurrences of the pattern graph in the GraphModule with fx
  pattern matching.

  The marked subgraphs will be lowered to StableHLO composite ops.

  Args:
    graph_module (torch.fx.GraphModule): GraphModule to be matched and marked.
    pattern (ai_edge_torch.hlfb.mark_pattern.Pattern): Pattern to match.

  Returns:
    The modified graph_module with additional marker ops in graph.
  """
  # Create a copy of graph_module and sanitize it for pattern matching.
  graph_module_to_match = copy.deepcopy(graph_module)
  for n, m in zip(graph_module.graph.nodes, graph_module_to_match.graph.nodes):
    m.meta["ORIGINAL_NODE"] = n

  # Sanitize graph_module to match in the same way as pattern's graph_module.
  graph_module_to_match = passes.remove_clone_ops(graph_module_to_match)

  match_with_attrs = pattern.match(graph_module_to_match)

  for match, attr in match_with_attrs:
    match_id = _get_uuid()

    # NOTE: The current graph rewriter (_insert_marker) does not work perfectly
    # with continuous matches, e.g. matching (a + b) on (w + x + y + z). The
    # rewritten results may be nondeterministic, with false negatives: some
    # matches may not be marked in the lowering, while the marked ones are
    # always correct.
    # TODO(cnchan): completely support mark_pattern on continuous matches.
    for i, pattern_input_node in enumerate(pattern.input_nodes):
      input_node = match.nodes_map[pattern_input_node]
      new_input_node = _insert_marker(
          graph_module,
          input_node.meta["ORIGINAL_NODE"],
          name=pattern.name,
          pos=i,
          id=match_id,
          is_input=True,
      )

      # Only replace input by the marker node for those nodes used in the pattern.
      in_pattern_nodes = set(match.nodes_map.values())
      for user in input_node.users.keys():
        if user in in_pattern_nodes:
          user.meta["ORIGINAL_NODE"].replace_input_with(
              input_node.meta["ORIGINAL_NODE"], new_input_node
          )

    for i, pattern_output_node in enumerate(pattern.output_nodes):
      output_node = match.nodes_map[pattern_output_node]
      new_output_node = _insert_marker(
          graph_module,
          output_node.meta["ORIGINAL_NODE"],
          name=pattern.name,
          pos=i,
          id=match_id,
          is_input=False,
          attr=attr,  # torch_xla internal: only the output marker needs attr.
      )
      output_node.meta["ORIGINAL_NODE"].replace_all_uses_with(new_output_node)
      new_output_node.update_arg(0, output_node.meta["ORIGINAL_NODE"])

  graph_module.graph.eliminate_dead_code()
  _prepose_placeholder_nodes(graph_module.graph)

  graph_module.graph.lint()
  graph_module.recompile()
  return graph_module
```
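A usage sketch, assuming `Pattern` takes a name, a callable to trace, and example `export_args` (see `pattern.py` in this release for the exact signature):

```python
import torch
from torch.export import export
from ai_edge_torch.hlfb import mark_pattern
from ai_edge_torch.hlfb.mark_pattern.pattern import Pattern

class WithGelu(torch.nn.Module):
  def forward(self, x):
    return torch.nn.functional.gelu(x) * 2

# Hypothetical composite name "odml.gelu"; every matched GELU subgraph gets
# wrapped in input/output markers tagged with that name.
pattern = Pattern(
    "odml.gelu",
    lambda x: torch.nn.functional.gelu(x),
    export_args=(torch.rand(2, 2),),
)

ep = export(WithGelu().eval(), (torch.rand(4, 2),))
marked_gm = mark_pattern.mark_pattern(ep.graph_module, pattern)
```

The final hunk holds the sanitization passes that `mark_pattern` relies on: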
ai_edge_torch/hlfb/mark_pattern/passes.py (new file)

@@ -0,0 +1,42 @@

```python
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch


def remove_clone_ops(gm: torch.fx.GraphModule):
  # torch export adds additional aten.clone nodes to produce contiguous
  # in-memory tensors, depending on tensor sizes, for runtime efficiency.
  # These unpredictable clone nodes can break pattern matching, so remove all
  # clones in both the model and pattern graphs.
  for node in gm.graph.nodes:
    if node.op == "call_function" and node.name.startswith("clone"):
      node.replace_all_uses_with(node.args[0])
      gm.graph.erase_node(node)

  gm.graph.lint()
  gm.recompile()
  return gm


def remove_dangling_args(gm: torch.fx.GraphModule):
  nodes_to_erase = []
  for node in gm.graph.nodes:
    if node.op == "placeholder" and len(node.users) == 0:
      nodes_to_erase.append(node)
  for node in nodes_to_erase:
    gm.graph.erase_node(node)

  gm.graph.lint()
  gm.recompile()
  return gm
```
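Both passes mutate the `torch.fx.GraphModule` in place and return it recompiled. A minimal sketch (whether export keeps the explicit clone as an `aten.clone` node can vary by torch version):

```python
import torch
from torch.export import export
from ai_edge_torch.hlfb.mark_pattern import passes

class Tiny(torch.nn.Module):
  def forward(self, x):
    return x.clone() + 1  # export may retain this as an aten.clone node

gm = export(Tiny().eval(), (torch.rand(2),)).graph_module
gm = passes.remove_clone_ops(gm)      # strip aten.clone nodes
gm = passes.remove_dangling_args(gm)  # drop unused placeholders
```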