ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (91):
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
ai_edge_torch/generative/utilities/t5_loader.py
@@ -0,0 +1,467 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Common utility functions for data loading etc.
+ from dataclasses import dataclass
+ import glob
+ import os
+ from typing import Callable, Dict
+
+ from safetensors import safe_open
+ import torch
+
+ from ai_edge_torch.generative.layers import model_config
+
+
+ def load_safetensors(full_path: str):
+   """Loads safetensors into a single state dictionary.
+
+   Args:
+     full_path (string): the safetensor filename or directory that contains the
+       safetensor files.
+
+   Returns:
+     A state dictionary containing loaded tensors.
+
+   Raises:
+     ValueError: If no tensors are loaded from the provided directory or file.
+   """
+   pattern = (
+       os.path.join(full_path, "*.safetensors")
+       if os.path.isdir(full_path)
+       else full_path
+   )
+   files = []
+   for file in glob.glob(pattern):
+     files.append(file)
+
+   tensors = {}
+   for file in files:
+     with safe_open(file, framework="pt") as fp:
+       for k in fp.keys():
+         assert k not in tensors
+         tensors[k] = fp.get_tensor(k)
+
+   if not tensors:
+     raise ValueError("Failed to load SafeTensors.")
+   return tensors
+
+
+ def load_pytorch_statedict(full_path: str):
+   """Loads state dictionary binaries into a single state dictionary.
+
+   Args:
+     full_path (string): the bin filename or directory that contains the bin
+       files.
+
+   Returns:
+     A state dictionary containing loaded tensors.
+
+   Raises:
+     ValueError: If no tensors are loaded from the provided directory or file.
+   """
+   pattern = os.path.join(full_path, "*.bin") if os.path.isdir(full_path) else full_path
+   files = []
+   for file in glob.glob(pattern):
+     files.append(file)
+
+   tensors = {}
+   for file in files:
+     this_file_tensors = torch.load(file, map_location=torch.device("cpu"))
+     for k in this_file_tensors:
+       assert k not in tensors
+     tensors.update(this_file_tensors)
+
+   if not tensors:
+     raise ValueError("Failed to load torch bin files.")
+   return tensors
+
+
+ class ModelLoader:
+   """A utility class for loading and converting model checkpoints to ODML
+   transformer layer format.
+   """
+
+   @dataclass
+   class TensorNames:
+     attn_query_proj: str = None
+     attn_key_proj: str = None
+     attn_value_proj: str = None
+     attn_output_proj: str = None
+     relative_attn_bias: str = None
+
+     cross_attn_query_proj: str = None
+     cross_attn_key_proj: str = None
+     cross_attn_value_proj: str = None
+     cross_attn_output_proj: str = None
+
+     ff_up_proj: str = None
+     ff_down_proj: str = None
+     ff_gate_proj: str = None
+
+     pre_attn_norm: str = None
+     pre_cross_attn_norm: str = None
+     pre_ff_norm: str = None
+     embedding: str = None
+     final_norm: str = None
+     lm_head: str = None
+
+   def __init__(self, file_name: str, names: TensorNames) -> None:
+     """ModelLoader constructor. Can be used to load multiple models of the same
+     type.
+
+     Args:
+       file_name (str): Path to the checkpoint. Can be a directory or an
+         exact file.
+       names (TensorNames): An instance of `TensorNames` to determine mappings.
+     """
+     self._file_name = file_name
+     self._names = names
+     self._loader = self._get_loader()
+
+   def load(
+       self, model: torch.nn.Module, strict: bool = True, fuse_attention: bool = True
+   ):
+     """Load the model from the checkpoint.
+
+     Args:
+       model (torch.nn.Module): The pytorch model that needs to be loaded.
+       strict (bool, optional): Whether the converted keys are strictly
+         matched. Defaults to True.
+
+     Raises:
+       ValueError: If conversion results in unmapped tensors and strict mode is
+         enabled.
+     """
+     state = self._loader(self._file_name)
+
+     if isinstance(self._names, ModelLoader.TensorNames):
+       converted_state = self._do_load(
+           model, state, self._names, fuse_attention=fuse_attention
+       )
+     elif isinstance(self._names, dict):
+       converted_state = {}
+       for additional_prefix, names in self._names.items():
+         local_converted_state = self._do_load(
+             model,
+             state,
+             self._names[additional_prefix],
+             additional_prefix,
+             fuse_attention=fuse_attention,
+         )
+         converted_state.update(local_converted_state)
+     else:
+       raise ValueError(f"Unknown type for names: {type(self._names)}")
+
+     if strict and state:
+       raise ValueError(
+           f"Failed to map all tensors. Remaining tensors are: {list(state.keys())}"
+       )
+     model.load_state_dict(converted_state, strict=strict)
+
+   def _do_load(self, model, state, names, additional_prefix="", fuse_attention=True):
+     """Load the model from the checkpoint.
+
+     Args:
+       model (torch.nn.Module): The pytorch model that needs to be loaded.
+       state (Dict[str, torch.Tensor]): The pytorch state dictionary.
+       names (TensorNames): The TensorNames for the model we are loading.
+
+     Returns:
+       Dict[str, torch.Tensor]: Map of name to tensor for loading.
+     """
+     converted_state = dict()
+     if names.embedding is not None:
+       converted_state["tok_embedding.weight"] = state.pop(f"{names.embedding}.weight")
+     if names.lm_head is not None:
+       converted_state["lm_head.weight"] = state.pop(f"{names.lm_head}.weight")
+       if model.config.lm_head_use_bias:
+         converted_state["lm_head.bias"] = state.pop(f"{names.lm_head}.bias")
+     if names.final_norm is not None:
+       final_norm_name = names.final_norm
+       prefix = additional_prefix
+       converted_state[f"{prefix}final_norm.weight"] = state.pop(
+           f"{final_norm_name}.weight"
+       )
+       if f"{final_norm_name}.bias" in state:
+         converted_state["final_norm.bias"] = state.pop(f"{final_norm_name}.bias")
+
+     if names.relative_attn_bias:
+       rel_attn_name = names.relative_attn_bias
+       prefix = additional_prefix + "transformer_blocks.0"
+       converted_state[f"{prefix}.atten_func.relative_attention_bias.weight"] = (
+           state.pop(f"{rel_attn_name}.weight")
+       )
+
+     for i in range(model.config.num_layers):
+       self._map_norm(i, model.config, state, converted_state, names, additional_prefix)
+       self._map_feedforward(
+           i, model.config, state, converted_state, names, additional_prefix
+       )
+       self._map_attention(
+           i,
+           model.config,
+           state,
+           converted_state,
+           names,
+           additional_prefix,
+           fuse_attention,
+       )
+       self._map_cross_attention(
+           i,
+           model.config,
+           state,
+           converted_state,
+           names,
+           additional_prefix,
+           fuse_attention,
+       )
+
+     return converted_state
+
+   def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
+     """A best effort method for finding appropriate state loader.
+
+     Raises:
+       ValueError: If it fails to find an appropriate loader.
+
+     Returns:
+       Callable[[str], Dict[str, torch.Tensor]]: State loader to be used.
+     """
+     if os.path.isdir(self._file_name):
+       if glob.glob(os.path.join(self._file_name, "*.safetensors")):
+         return load_safetensors
+       if glob.glob(os.path.join(self._file_name, "*.bin")):
+         return load_pytorch_statedict
+
+     if self._file_name.endswith(".safetensors"):
+       return load_safetensors
+
+     if self._file_name.endswith(".bin"):
+       return load_pytorch_statedict
+
+     raise ValueError("File format not supported.")
+
+   def _map_feedforward(
+       self,
+       idx: int,
+       config: model_config.ModelConfig,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       names: TensorNames,
+       additional_prefix: str = "",
+   ):
+     prefix = additional_prefix + f"transformer_blocks.{idx}"
+     if names.ff_up_proj is None or names.ff_down_proj is None:
+       return
+     if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
+       ff_up_proj_name = names.ff_up_proj.format(idx)
+       ff_down_proj_name = names.ff_down_proj.format(idx)
+       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
+           f"{ff_down_proj_name}.weight"
+       )
+       if config.ff_config.use_bias:
+         converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_up_proj_name}.bias")
+         converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
+     else:
+       if names.ff_gate_proj is not None:
+         ff_up_proj_name = names.ff_up_proj.format(idx)
+         ff_down_proj_name = names.ff_down_proj.format(idx)
+         ff_gate_proj_name = names.ff_gate_proj.format(idx)
+         converted_state[f"{prefix}.ff.w3.weight"] = state.pop(
+             f"{ff_up_proj_name}.weight"
+         )
+         converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
+             f"{ff_down_proj_name}.weight"
+         )
+         converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
+             f"{ff_gate_proj_name}.weight"
+         )
+         if config.ff_config.use_bias:
+           converted_state[f"{prefix}.ff.w3.bias"] = state.pop(f"{ff_up_proj_name}.bias")
+           converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+               f"{ff_down_proj_name}.bias"
+           )
+           converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+               f"{ff_gate_proj_name}.bias"
+           )
+
+   def _map_attention(
+       self,
+       idx: int,
+       config: model_config.ModelConfig,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       names: TensorNames,
+       additional_prefix: str = "",
+       fuse_attention: bool = True,
+   ):
+     if (
+         names.attn_query_proj is None
+         or names.attn_key_proj is None
+         or names.attn_value_proj is None
+     ):
+       return
+     prefix = additional_prefix + f"transformer_blocks.{idx}"
+     q_name = names.attn_query_proj.format(idx)
+     k_name = names.attn_key_proj.format(idx)
+     v_name = names.attn_value_proj.format(idx)
+     # model.encoder.transformer_blocks[0].atten_func.q.weight
+     if fuse_attention:
+       converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
+           config,
+           state.pop(f"{q_name}.weight"),
+           state.pop(f"{k_name}.weight"),
+           state.pop(f"{v_name}.weight"),
+       )
+       if config.attn_config.qkv_use_bias:
+         converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
+             config,
+             state.pop(f"{q_name}.bias"),
+             state.pop(f"{k_name}.bias"),
+             state.pop(f"{v_name}.bias"),
+         )
+     else:
+       converted_state[f"{prefix}.atten_func.q.weight"] = state.pop(f"{q_name}.weight")
+       converted_state[f"{prefix}.atten_func.k.weight"] = state.pop(f"{k_name}.weight")
+       converted_state[f"{prefix}.atten_func.v.weight"] = state.pop(f"{v_name}.weight")
+       if config.attn_config.qkv_use_bias:
+         converted_state[f"{prefix}.atten_func.q.bias"] = state.pop(f"{q_name}.bias")
+         converted_state[f"{prefix}.atten_func.k.bias"] = state.pop(f"{k_name}.bias")
+         converted_state[f"{prefix}.atten_func.v.bias"] = state.pop(f"{v_name}.bias")
+
+     o_name = names.attn_output_proj.format(idx)
+     converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
+     if config.attn_config.output_proj_use_bias:
+       converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
+
+   def _map_cross_attention(
+       self,
+       idx: int,
+       config: model_config.ModelConfig,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       names: TensorNames,
+       additional_prefix: str = "",
+       fuse_attention: bool = True,
+   ):
+     if (
+         names.cross_attn_query_proj is None
+         or names.cross_attn_key_proj is None
+         or names.cross_attn_value_proj is None
+     ):
+       return
+     prefix = additional_prefix + f"transformer_blocks.{idx}"
+     q_name = names.cross_attn_query_proj.format(idx)
+     k_name = names.cross_attn_key_proj.format(idx)
+     v_name = names.cross_attn_value_proj.format(idx)
+
+     if fuse_attention:
+       converted_state[f"{prefix}.cross_atten_func.attn.weight"] = self._fuse_qkv(
+           config,
+           state.pop(f"{q_name}.weight"),
+           state.pop(f"{k_name}.weight"),
+           state.pop(f"{v_name}.weight"),
+       )
+       if config.attn_config.qkv_use_bias:
+         converted_state[f"{prefix}.cross_atten_func.attn.bias"] = self._fuse_qkv(
+             config,
+             state.pop(f"{q_name}.bias"),
+             state.pop(f"{k_name}.bias"),
+             state.pop(f"{v_name}.bias"),
+         )
+     else:
+       converted_state[f"{prefix}.cross_atten_func.q.weight"] = state.pop(
+           f"{q_name}.weight"
+       )
+       converted_state[f"{prefix}.cross_atten_func.k.weight"] = state.pop(
+           f"{k_name}.weight"
+       )
+       converted_state[f"{prefix}.cross_atten_func.v.weight"] = state.pop(
+           f"{v_name}.weight"
+       )
+       if config.attn_config.qkv_use_bias:
+         converted_state[f"{prefix}.cross_atten_func.q.bias"] = state.pop(
+             f"{q_name}.bias"
+         )
+         converted_state[f"{prefix}.cross_atten_func.k.bias"] = state.pop(
+             f"{k_name}.bias"
+         )
+         converted_state[f"{prefix}.cross_atten_func.v.bias"] = state.pop(
+             f"{v_name}.bias"
+         )
+
+     o_name = names.cross_attn_output_proj.format(idx)
+     converted_state[f"{prefix}.cross_atten_func.proj.weight"] = state.pop(
+         f"{o_name}.weight"
+     )
+     if config.attn_config.output_proj_use_bias:
+       converted_state[f"{prefix}.cross_atten_func.proj.bias"] = state.pop(
+           f"{o_name}.bias"
+       )
+
+   def _map_norm(
+       self,
+       idx: int,
+       config: model_config.ModelConfig,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       names: TensorNames,
+       additional_prefix: str = "",
+   ):
+     prefix = additional_prefix + f"transformer_blocks.{idx}"
+     if names.pre_attn_norm is not None:
+       pre_attn_norm_name = names.pre_attn_norm.format(idx)
+       converted_state[f"{prefix}.atten_func.pre_atten_norm.weight"] = state.pop(
+           f"{pre_attn_norm_name}.weight"
+       )
+       if f"{pre_attn_norm_name}.bias" in state:
+         converted_state[f"{prefix}.atten_func.pre_atten_norm.bias"] = state.pop(
+             f"{pre_attn_norm_name}.bias"
+         )
+
+     if names.pre_cross_attn_norm:
+       pre_cross_attn_norm_name = names.pre_cross_attn_norm.format(idx)
+       converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.weight"] = state.pop(
+           f"{pre_cross_attn_norm_name}.weight"
+       )
+       if f"{pre_cross_attn_norm_name}.bias" in state:
+         converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.bias"] = state.pop(
+             f"{pre_cross_attn_norm_name}.bias"
+         )
+
+     if names.pre_ff_norm is not None:
+       pre_ff_norm_name = names.pre_ff_norm.format(idx)
+       converted_state[f"{prefix}.pre_ff_norm.weight"] = state.pop(
+           f"{pre_ff_norm_name}.weight"
+       )
+       if f"{pre_ff_norm_name}.bias" in state:
+         converted_state[f"{prefix}.pre_ff_norm.bias"] = state.pop(
+             f"{pre_ff_norm_name}.bias"
+         )
+
+   def _fuse_qkv(
+       self,
+       config: model_config.ModelConfig,
+       q: torch.Tensor,
+       k: torch.Tensor,
+       v: torch.Tensor,
+   ) -> torch.Tensor:
+     # Interleave per-group query, key and value slices into one fused QKV
+     # tensor; with grouped-query attention, each group holds
+     # num_heads // num_query_groups query heads per KV head.
+     q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
+     qs = torch.split(q, config.head_dim * q_per_kv)
+     ks = torch.split(k, config.head_dim)
+     vs = torch.split(v, config.head_dim)
+     cycled = [t for group in zip(qs, ks, vs) for t in group]
+     return torch.cat(cycled)
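
For orientation, the sketch below shows how this loader is typically driven. The checkpoint path and the T5-style tensor names are illustrative assumptions, not values taken from this diff; the real mappings and model builders live under ai_edge_torch/generative/examples/.

import ai_edge_torch.generative.utilities.t5_loader as loading_utils

# Hypothetical mapping from checkpoint tensor names ("{}" is filled with the
# layer index) to the ODML transformer layer names expected by ModelLoader.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    attn_query_proj="encoder.block.{}.layer.0.SelfAttention.q",
    attn_key_proj="encoder.block.{}.layer.0.SelfAttention.k",
    attn_value_proj="encoder.block.{}.layer.0.SelfAttention.v",
    attn_output_proj="encoder.block.{}.layer.0.SelfAttention.o",
    relative_attn_bias="encoder.block.0.layer.0.SelfAttention.relative_attention_bias",
    pre_attn_norm="encoder.block.{}.layer.0.layer_norm",
    ff_up_proj="encoder.block.{}.layer.1.DenseReluDense.wi",
    ff_down_proj="encoder.block.{}.layer.1.DenseReluDense.wo",
    pre_ff_norm="encoder.block.{}.layer.1.layer_norm",
    embedding="shared",
    final_norm="encoder.final_layer_norm",
)

loader = loading_utils.ModelLoader("/path/to/t5/checkpoint_dir", TENSOR_NAMES)
# loader.load(model, strict=False, fuse_attention=False) would then populate a
# torch.nn.Module built from ai_edge_torch.generative.layers; strict=False is
# needed when the checkpoint contains tensors (e.g. decoder weights) that this
# mapping leaves unmapped.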
ai_edge_torch/hlfb/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
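
This module re-exports torch_xla's StableHLOCompositeBuilder, the hand-annotation counterpart to the mark_pattern machinery below. A rough sketch of the intended use follows; the exact constructor and mark_inputs/mark_outputs signatures are assumptions based on how the builder is exercised elsewhere in this package, not an API documented in this diff.

import torch
import torch.nn as nn
from ai_edge_torch.hlfb import StableHLOCompositeBuilder

class ApproxGelu(nn.Module):
  def forward(self, x):
    # Everything between mark_inputs and mark_outputs is intended to be wrapped
    # into a single StableHLO composite op named "aten.gelu.default" at lowering.
    builder = StableHLOCompositeBuilder(
        name="aten.gelu.default", attr={"approximate": "tanh"}
    )
    x = builder.mark_inputs(x)
    y = nn.functional.gelu(x, approximate="tanh")
    return builder.mark_outputs(y)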
ai_edge_torch/hlfb/mark_pattern/__init__.py
@@ -0,0 +1,139 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import copy
+ from typing import Any
+ import uuid
+
+ import torch
+ from torch_xla.experimental import xla_marker
+
+ from ai_edge_torch.hlfb.mark_pattern import passes
+ from ai_edge_torch.hlfb.mark_pattern.pattern import Pattern
+ from ai_edge_torch.hlfb.mark_pattern.pattern import ScalarAttrTracker  # NOQA
+
+
+ @torch._dynamo.assume_constant_result
+ def _get_uuid() -> str:
+   return uuid.uuid4().hex
+
+
+ # TODO: Move to a general fx utils file.
+ def _prepose_placeholder_nodes(graph: torch.fx.Graph):
+   nodes = [node for node in graph.nodes if node.op == "placeholder"] + [
+       node for node in graph.nodes if node.op != "placeholder"
+   ]
+
+   for a, b in zip(nodes, nodes[1:]):
+     if a.next is not b:
+       a.append(b)
+   return graph
+
+
+ def _insert_marker(
+     graph_module: torch.fx.GraphModule,
+     node: torch.fx.Node,
+     name: str,
+     pos: int,
+     id: str,
+     is_input: bool,
+     attr: dict[str, Any] = None,
+ ):
+   attr = xla_marker.serialize_composite_attr(attr) if attr else None
+   with graph_module.graph.inserting_after(node):
+     new_node = graph_module.graph.call_function(
+         torch.ops.xla.mark_tensor,
+         args=(node,),
+         kwargs={
+             "name": name,
+             "pos": pos,
+             "id": id,
+             "is_input": is_input,
+             "attr": attr,
+         },
+     )
+
+   new_node.meta = node.meta
+   return new_node
+
+
+ def mark_pattern(
+     graph_module: torch.fx.GraphModule,
+     pattern: Pattern,
+ ) -> torch.fx.GraphModule:
+   """Marks all occurrences of the pattern graph in the GraphModule with fx pattern matching.
+
+   The marked subgraphs will be lowered into StableHLO composite ops.
+
+   Args:
+     graph_module (torch.fx.GraphModule): GraphModule to be matched and marked.
+     pattern (ai_edge_torch.hlfb.mark_pattern.Pattern): Pattern to match.
+
+   Returns:
+     The modified graph_module with additional marker ops in graph.
+   """
+   # Create a copy of graph_module and sanitize it for pattern matching.
+   graph_module_to_match = copy.deepcopy(graph_module)
+   for n, m in zip(graph_module.graph.nodes, graph_module_to_match.graph.nodes):
+     m.meta["ORIGINAL_NODE"] = n
+
+   # Sanitize graph_module to match in the same way as pattern's graph_module.
+   graph_module_to_match = passes.remove_clone_ops(graph_module_to_match)
+
+   match_with_attrs = pattern.match(graph_module_to_match)
+
+   for match, attr in match_with_attrs:
+     match_id = _get_uuid()
+
+     # NOTE: The current graph rewriter (_insert_marker) does not work perfectly
+     # with consecutive matches, e.g. matching (a + b) on (w + x + y + z). The
+     # rewritten result may be nondeterministic, with false negatives - some
+     # matches may not be marked in the lowering, but the marked ones are
+     # always correct.
+     # TODO(cnchan): fully support mark_pattern on consecutive matches.
+     for i, pattern_input_node in enumerate(pattern.input_nodes):
+       input_node = match.nodes_map[pattern_input_node]
+       new_input_node = _insert_marker(
+           graph_module,
+           input_node.meta["ORIGINAL_NODE"],
+           name=pattern.name,
+           pos=i,
+           id=match_id,
+           is_input=True,
+       )
+
+       # Only replace the input with the marker node for nodes used in the pattern.
+       in_pattern_nodes = set(match.nodes_map.values())
+       for user in input_node.users.keys():
+         if user in in_pattern_nodes:
+           user.meta["ORIGINAL_NODE"].replace_input_with(
+               input_node.meta["ORIGINAL_NODE"], new_input_node
+           )
+
+     for i, pattern_output_node in enumerate(pattern.output_nodes):
+       output_node = match.nodes_map[pattern_output_node]
+       new_output_node = _insert_marker(
+           graph_module,
+           output_node.meta["ORIGINAL_NODE"],
+           name=pattern.name,
+           pos=i,
+           id=match_id,
+           is_input=False,
+           attr=attr,  # torch_xla internal: only the output marker needs attr.
+       )
+       output_node.meta["ORIGINAL_NODE"].replace_all_uses_with(new_output_node)
+       new_output_node.update_arg(0, output_node.meta["ORIGINAL_NODE"])
+
+   graph_module.graph.eliminate_dead_code()
+   _prepose_placeholder_nodes(graph_module.graph)
+
+   graph_module.graph.lint()
+   graph_module.recompile()
+   return graph_module
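
A minimal sketch of how mark_pattern is meant to be called on an exported graph. The Pattern constructor arguments shown here (a name, a callable to trace, and export_args) are assumptions based on how Pattern is exercised in ai_edge_torch/hlfb/test/test_mark_pattern.py, not a documented API.

import torch
from ai_edge_torch.hlfb import mark_pattern
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module

class AddModel(torch.nn.Module):
  def forward(self, a, b):
    return a + b + a

# Pattern to annotate: a single elementwise add, traced from a lambda.
add_pattern = pattern_module.Pattern(
    "test.add",
    lambda a, b: a + b,
    export_args=(torch.rand(2, 2), torch.rand(2, 2)),
)

exported = torch.export.export(AddModel(), (torch.rand(2, 2), torch.rand(2, 2)))
marked_gm = mark_pattern.mark_pattern(exported.graph_module, add_pattern)
marked_gm.print_readable()  # xla.mark_tensor marker ops now bracket each match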
ai_edge_torch/hlfb/mark_pattern/passes.py
@@ -0,0 +1,42 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import torch
+
+
+ def remove_clone_ops(gm: torch.fx.GraphModule):
+   # torch export adds extra aten.clone nodes to produce contiguous-in-memory
+   # tensors, depending on tensor sizes, for runtime efficiency. These
+   # unpredictable clone nodes can break pattern matching, so remove all clones
+   # from the model and pattern graphs.
+   for node in gm.graph.nodes:
+     if node.op == "call_function" and node.name.startswith("clone"):
+       node.replace_all_uses_with(node.args[0])
+       gm.graph.erase_node(node)
+
+   gm.graph.lint()
+   gm.recompile()
+   return gm
+
+
+ def remove_dangling_args(gm: torch.fx.GraphModule):
+   nodes_to_erase = []
+   for node in gm.graph.nodes:
+     if node.op == "placeholder" and len(node.users) == 0:
+       nodes_to_erase.append(node)
+   for node in nodes_to_erase:
+     gm.graph.erase_node(node)
+
+   gm.graph.lint()
+   gm.recompile()
+   return gm
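
Both passes mutate a torch.fx.GraphModule in place and return it, so they chain naturally. A small sketch of running them on an exported graph; the toy module, its inputs, and the unused argument are illustrative assumptions.

import torch
from ai_edge_torch.hlfb.mark_pattern import passes

class Toy(torch.nn.Module):
  def forward(self, x, unused):
    return x.clone() * 2.0

exported = torch.export.export(Toy(), (torch.rand(4), torch.rand(4)))
gm = exported.graph_module
gm = passes.remove_clone_ops(gm)      # drop aten.clone nodes that would break matching
gm = passes.remove_dangling_args(gm)  # drop placeholders left without users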