ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl

Files changed (38)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  3. ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
  4. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/llama/llama.py +29 -25
  6. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  7. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  8. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  9. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  10. ai_edge_torch/generative/examples/phi/phi3.py +26 -23
  11. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  12. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  13. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  14. ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
  15. ai_edge_torch/generative/examples/smollm/verify.py +18 -2
  16. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  17. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  18. ai_edge_torch/generative/layers/attention.py +45 -37
  19. ai_edge_torch/generative/layers/lora.py +557 -0
  20. ai_edge_torch/generative/layers/model_config.py +6 -2
  21. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
  22. ai_edge_torch/generative/test/test_lora.py +147 -0
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  24. ai_edge_torch/generative/utilities/converter.py +100 -47
  25. ai_edge_torch/generative/utilities/model_builder.py +23 -14
  26. ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
  27. ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
  28. ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
  29. ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
  30. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  31. ai_edge_torch/odml_torch/export.py +6 -2
  32. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  33. ai_edge_torch/version.py +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
  36. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
  37. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
  38. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/test/test_lora.py ADDED
@@ -0,0 +1,147 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """A suite of tests to validate LoRA utilities."""
+
+ from ai_edge_torch.generative.layers import lora as lora_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import torch
+ from absl.testing import absltest as googletest
+ from tensorflow.python.platform import resource_loader  # pylint: disable=g-direct-tensorflow-import
+
+
+ class TestLora(googletest.TestCase):
+   """Tests for LoRA utilities."""
+
+   def test_safetensors_builder(self):
+     """Converts a safetensors file to a LoRA module."""
+
+     tensor_names = lora_utils.LoRATensorNames(
+         attn_query_w_a=(
+             "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
+         ),
+         attn_query_w_b=(
+             "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight"
+         ),
+         attn_key_w_a=(
+             "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight"
+         ),
+         attn_key_w_b=(
+             "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight"
+         ),
+         attn_value_w_a=(
+             "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight"
+         ),
+         attn_value_w_b=(
+             "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight"
+         ),
+         attn_output_w_a=(
+             "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight"
+         ),
+         attn_output_w_b=(
+             "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight"
+         ),
+     )
+
+     safetensors_file = resource_loader.get_path_to_datafile(
+         "fixtures/test_lora_rank16.safetensors"
+     )
+     config = self._get_test_config(
+         num_layers=1,
+         head_dim=8,
+         num_query_groups=1,
+         kv_cache_max_len=16,
+     )
+     lora = lora_utils.LoRA.from_safetensors(
+         safetensors_file,
+         scale=1.0,
+         lora_tensor_names=tensor_names,
+         config=config,
+     )
+     self.assertEqual(lora.get_rank(), 16)
+
+   def test_torch_export(self):
+     """Tests the export of the LoRA module."""
+
+     class TestModel(torch.nn.Module):
+
+       def forward(self, x: torch.Tensor, lora: lora_utils.LoRA) -> torch.Tensor:
+         x += lora_utils.apply_lora(x, lora.adapters[0].attention.query)
+         return x
+
+     n = 1
+     head_dim = 2
+     num_query_groups = 1
+     key_length = 4
+     config = self._get_test_config(
+         num_layers=n,
+         head_dim=head_dim,
+         num_query_groups=num_query_groups,
+         kv_cache_max_len=key_length,
+     )
+     inputs = torch.zeros((n, 1, head_dim))
+     lora = lora_utils.LoRA.zeros(rank=16, config=config)
+     model = TestModel()
+     exported_program = torch.export.export(model, (inputs, lora))
+     input_specs = exported_program.graph_signature.input_specs
+     # 9 inputs: 1 for x, 2 for query lora, 2 for key lora, 2 for value lora,
+     # 2 for output lora.
+     self.assertLen(input_specs, 9)
+     self.assertEqual(input_specs[0].arg.name, "x")
+     self.assertEqual(input_specs[1].arg.name, "lora_atten_q_a_prime_weight_0")
+     self.assertEqual(input_specs[2].arg.name, "lora_atten_q_b_prime_weight_0")
+     self.assertEqual(input_specs[3].arg.name, "lora_atten_k_a_prime_weight_0")
+     self.assertEqual(input_specs[4].arg.name, "lora_atten_k_b_prime_weight_0")
+     self.assertEqual(input_specs[5].arg.name, "lora_atten_v_a_prime_weight_0")
+     self.assertEqual(input_specs[6].arg.name, "lora_atten_v_b_prime_weight_0")
+     self.assertEqual(input_specs[7].arg.name, "lora_atten_o_a_prime_weight_0")
+     self.assertEqual(input_specs[8].arg.name, "lora_atten_o_b_prime_weight_0")
+
+   def test_lora_tflite_serialization(self):
+     """Tests the serialization of the LoRA module."""
+     config = self._get_test_config(
+         num_layers=2,
+         head_dim=8,
+         num_query_groups=1,
+         kv_cache_max_len=16,
+     )
+     lora = lora_utils.LoRA.random(rank=16, config=config)
+     flatbuffer_model = lora.to_tflite()
+     recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
+     self.assertEqual(lora, recovered_lora)
+
+   def _get_test_config(
+       self, num_layers, head_dim, num_query_groups, kv_cache_max_len
+   ):
+     """Returns a test model config."""
+     attn_config = cfg.AttentionConfig(
+         num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
+     )
+     block_config = cfg.TransformerBlockConfig(
+         attn_config=attn_config, ff_config=None
+     )
+     config = cfg.ModelConfig(
+         kv_cache_max_len=kv_cache_max_len,
+         embedding_dim=head_dim,
+         block_configs=block_config,
+         num_layers=num_layers,
+         max_seq_len=None,
+         vocab_size=None,
+     )
+     return config
+
+
+ if __name__ == "__main__":
+   googletest.main()
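
For context, these tests exercise the standard LoRA formulation: a frozen weight W is augmented with a low-rank product B·A, where A has shape (r, d_in) and B has shape (d_out, r), so an adapter contributes x·Aᵀ·Bᵀ on top of the base projection. A minimal sketch of that math under the usual convention (this is illustrative, not the library's exact apply_lora implementation):

    import torch

    def apply_lora_sketch(
        x: torch.Tensor, w_a: torch.Tensor, w_b: torch.Tensor, scale: float = 1.0
    ) -> torch.Tensor:
      # x: (..., d_in), w_a: (r, d_in), w_b: (d_out, r).
      # Returns the low-rank delta that is added to the frozen projection.
      return (x @ w_a.T @ w_b.T) * scale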
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -150,6 +150,16 @@ class TestModelConversion(googletest.TestCase):
        ai_edge_torch.config.in_oss,
        reason="tests with custom ops are not supported in oss",
    )
+   def test_smollm2(self):
+     config = smollm.get_fake_model_config_v2()
+     pytorch_model = smollm.SmolLM2(config).eval()
+     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
+   @googletest.skipIf(
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
+   )
+
    def test_openelm(self):
      config = openelm.get_fake_model_config()
      pytorch_model = openelm.OpenELM(config).eval()
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -15,16 +15,15 @@
 
  """Common utility functions for model conversion."""
 
- from functools import partial
- from typing import Any, Union
-
+ import os
+ from typing import Optional, Union
  from ai_edge_torch._convert import converter as converter_utils
+ from ai_edge_torch.generative.layers import lora as lora_utils
  import ai_edge_torch.generative.layers.kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  from ai_edge_torch.generative.quantize import quant_recipes
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
  import torch
- import torch.nn as nn
 
 
  class ExportableModule(torch.nn.Module):
@@ -41,11 +40,13 @@ class ExportableModule(torch.nn.Module):
 
  def convert_to_tflite(
      pytorch_model: torch.nn.Module,
-     tflite_path: str,
+     output_path: str,
+     output_name_prefix: str,
      prefill_seq_len: Union[int, list[int]],
      pixel_values_size: torch.Size = None,
      quantize: bool = True,
      config: cfg.ModelConfig = None,
+     lora_ranks: Optional[list[int]] = None,
      export_config: ExportConfig = None,
  ):
    """Converts a nn.Module model to multi-signature tflite model.
@@ -79,21 +80,65 @@ def convert_to_tflite(
 
    Args:
      pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
-     tflite_path (str): The tflite file path to export.
-     prefill_seq_len (Union[int, list[int]]): A list of prefill lengths to
-       export.
+     output_path (str): The path to export the tflite model.
+     output_name_prefix (str): The prefix of the tflite model name.
+     prefill_seq_len (Union[int, list[int]]): The prefill sequence length to
+       use. If a list, the model will have multiple prefill signatures.
      pixel_values_size (torch.Size, optional): The size of pixel values to pass
        to the model. If None, the model is not expected to take pixel values.
      quantize (bool, optional): Whether the model should be quantized. Defaults
        to True.
      config (cfg.ModelConfig, optional): The model config used to configure KV
        cache. If None, it uses the config of the pytorch_model.
+     lora_ranks (list[int], optional): The ranks of the LoRA layers. If None,
+       no LoRA signatures will be added.
    """
+   # pylint: disable=protected-access
+   torch._dynamo.config.cache_size_limit = 64
+
+   config = config if config else pytorch_model.config
    prefill_seq_lens = (
        [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
    )
+   loras = [None]
+   if lora_ranks is not None:
+     for rank in lora_ranks:
+       lora = lora_utils.LoRA.zeros(rank, config)
+       loras.append(lora)
+
+   quant_suffix = 'q8' if quantize else 'f32'
+   kv_size = config.kv_cache_max_len
+   lora_suffix = (
+       '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
+   )
+   output_filename = (
+       f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
+   )
+   output_file = os.path.join(output_path, output_filename)
+
+   _export_helper(
+       pytorch_model,
+       output_file,
+       prefill_seq_lens,
+       pixel_values_size,
+       quantize,
+       config,
+       loras,
+       export_config,
+   )
 
- # Tensors used to trace the model graph during conversion.
+
+ def _export_helper(
+     pytorch_model: torch.nn.Module,
+     output_file: str,
+     prefill_seq_lens: list[int],
+     pixel_values_size: torch.Size,
+     quantize: bool,
+     config: cfg.ModelConfig,
+     loras: list[None | lora_utils.LoRA],
+     export_config: ExportConfig,
+ ):
+   """Helper function to export a model to tflite."""
    prefill_tokens_list = []
    prefill_input_pos_list = []
    for seq_len in prefill_seq_lens:
@@ -108,9 +153,7 @@ def convert_to_tflite(
 
    decode_token = torch.tensor([[0]], dtype=torch.int)
    decode_input_pos = torch.tensor([0], dtype=torch.int)
-   kv = kv_utils.KVCache.from_model_config(
-       config if config else pytorch_model.config
-   )
+   kv = kv_utils.KVCache.from_model_config(config)
 
    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
 
@@ -119,44 +162,54 @@ def convert_to_tflite(
    mod = ExportableModule(pytorch_model, export_config=export_config)
 
    converter = converter_utils.Converter()
-   for i in range(len(prefill_seq_lens)):
-     prefill_seq_len = prefill_seq_lens[i]
-     prefill_tokens = prefill_tokens_list[i]
-     prefill_input_pos = prefill_input_pos_list[i]
-     if i == 0 and len(prefill_seq_lens) == 1:
-       prefill_signature_name = 'prefill'
-     else:
-       prefill_signature_name = f'prefill_{prefill_seq_len}'
-     converter.add_signature(
-         prefill_signature_name,
-         mod,
-         sample_kwargs={
-             'tokens': prefill_tokens,
-             'input_pos': prefill_input_pos,
-             'kv_cache': kv,
-         },
-     )
-     if prefill_pixel_values is not None:
+   for lora in loras:
+     for i in range(len(prefill_seq_lens)):
+       prefill_seq_len = prefill_seq_lens[i]
+       prefill_tokens = prefill_tokens_list[i]
+       prefill_input_pos = prefill_input_pos_list[i]
+       if i == 0 and len(prefill_seq_lens) == 1:
+         prefill_signature_name = 'prefill'
+       else:
+         prefill_signature_name = f'prefill_{prefill_seq_len}'
+
+       sample_kwargs = {
+           'tokens': prefill_tokens,
+           'input_pos': prefill_input_pos,
+           'kv_cache': kv,
+       }
+       if lora is not None:
+         prefill_signature_name += f'_lora_r{lora.get_rank()}'
+         sample_kwargs['lora'] = lora
+
        converter.add_signature(
-           prefill_signature_name + '_pixel',
+           prefill_signature_name,
            mod,
-           sample_kwargs={
-               'tokens': prefill_tokens,
-               'input_pos': prefill_input_pos,
-               'kv_cache': kv,
-               'pixel_values': prefill_pixel_values,
-           },
+           sample_kwargs=sample_kwargs,
        )
 
-   converter.add_signature(
-       'decode',
-       mod,
-       sample_kwargs={
-           'tokens': decode_token,
-           'input_pos': decode_input_pos,
-           'kv_cache': kv,
-       },
-   )
+       if prefill_pixel_values is not None:
+         converter.add_signature(
+             prefill_signature_name + '_pixel',
+             mod,
+             sample_kwargs={
+                 **sample_kwargs,
+                 'pixel_values': prefill_pixel_values,
+             },
+         )
+
+     sample_kwargs = {
+         'tokens': decode_token,
+         'input_pos': decode_input_pos,
+         'kv_cache': kv,
+     }
+     if lora is not None:
+       sample_kwargs['lora'] = lora
+
+     converter.add_signature(
+         'decode' if lora is None else f'decode_lora_r{lora.get_rank()}',
+         mod,
+         sample_kwargs=sample_kwargs,
+     )
 
    edge_model = converter.convert(quant_config=quant_config)
-   edge_model.export(tflite_path)
+   edge_model.export(output_file)
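
Net effect of the converter changes: tflite_path is split into output_path plus output_name_prefix, the final filename is derived from the quantization mode, KV-cache size, and LoRA ranks, and each LoRA rank adds prefill/decode signatures suffixed with _lora_r<rank>. A usage sketch under those assumptions (the model and paths are illustrative, not from this diff):

    from ai_edge_torch.generative.utilities import converter

    converter.convert_to_tflite(
        pytorch_model,                # a supported generative nn.Module
        output_path='/tmp/models',    # a directory, no longer a full .tflite path
        output_name_prefix='smollm',
        prefill_seq_len=[128, 512],   # one prefill signature per length
        quantize=True,                # 'q8' filename suffix, 'f32' otherwise
        lora_ranks=[16],              # adds prefill_*_lora_r16 and decode_lora_r16
    )
    # With config.kv_cache_max_len == 1024 this writes
    # /tmp/models/smollm_q8_ekv1024_lora16.tflite.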
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -22,12 +22,15 @@ from typing import Optional, Tuple
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers import lora as lora_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
  from torch import nn
 
+
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
      ff_down_proj="model.layers.{}.mlp.down_proj",
@@ -85,13 +88,6 @@ class DecoderOnlyModel(nn.Module):
          config.embedding_dim,
          config.final_norm_config,
      )
-     # ROPE parameters for all attn_configs are the same. Take the first one.
-     attn_config = config.block_config(0).attn_config
-     self.rope_cache = attn_utils.build_rope_cache(
-         size=config.kv_cache_max,
-         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-         base=attn_config.rotary_base,
-     )
      self.mask_cache = attn_utils.build_causal_mask_cache(
          size=config.kv_cache_max,
      )
@@ -103,6 +99,7 @@ class DecoderOnlyModel(nn.Module):
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
+     lora: Optional[lora_utils.LoRA] = None,
      export_config: Optional[ExportConfig] = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    _, seq_len = tokens.size()
@@ -113,13 +110,23 @@ class DecoderOnlyModel(nn.Module):
 
    # token embeddings of shape (b, t, n_embd)
    input_embeds = self.tok_embedding(tokens)
-   cos, sin = self.rope_cache
-   rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
+
+   # ROPE parameters for all attn_configs are the same. Take the first one.
+   attn_config = self.config.block_config(0).attn_config
+   n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+   rope = self.config.build_rope(
+       input_pos=input_pos,
+       n_elem=n_elem,
+       base=attn_config.rotary_base,
+       head_dim=attn_config.head_dim,
+       # input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
+   )
+
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.kv_cache_max]
 
    return self.forward_with_embeds(
-       input_embeds, rope, mask, input_pos, kv_cache, export_config
+       input_embeds, rope, mask, input_pos, kv_cache, lora, export_config
    )
 
  def forward_with_embeds(
@@ -129,6 +136,7 @@ class DecoderOnlyModel(nn.Module):
      mask: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
+     lora: Optional[lora_utils.LoRA] = None,
      export_config: Optional[ExportConfig] = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    """Forwards the model with input embeddings."""
@@ -141,13 +149,14 @@ class DecoderOnlyModel(nn.Module):
    if self.config.embedding_scale is not None:
      x = x * self.config.embedding_scale
 
-   updated_kv_entires = []
+   updated_kv_entries = []
    for i, block in enumerate(self.transformer_blocks):
      kv_entry = kv_cache.caches[i] if kv_cache else None
-     x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
+     lora_adapter = lora.adapters[i] if lora else None
+     x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter)
      if kv_entry:
-       updated_kv_entires.append(kv_entry)
-   updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+       updated_kv_entries.append(kv_entry)
+   updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
    if export_config is not None:
      if (
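
DecoderOnlyModel.forward now threads an optional LoRA object through every transformer block, selecting adapters[i] per layer (and the typo updated_kv_entires is fixed to updated_kv_entries along the way). A minimal calling sketch, assuming model is a DecoderOnlyModel and reusing the kv_utils/lora_utils imports above; the shapes are illustrative:

    tokens = torch.zeros((1, 8), dtype=torch.int)
    input_pos = torch.arange(0, 8, dtype=torch.int)
    kv = kv_utils.KVCache.from_model_config(model.config)
    lora = lora_utils.LoRA.zeros(rank=16, config=model.config)

    out = model.forward(tokens, input_pos, kv, lora=lora)  # adapters applied
    out = model.forward(tokens, input_pos, kv)             # lora=None, base weights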
ai_edge_torch/hlfb/mark_pattern/__init__.py CHANGED
@@ -17,7 +17,7 @@ from typing import Any
  import uuid
 
  from ai_edge_torch import lowertools
- from ai_edge_torch.hlfb.mark_pattern import passes
+ from ai_edge_torch.hlfb.mark_pattern import fx_utils
  from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
  import torch
 
@@ -87,7 +87,7 @@ def mark_pattern(
        m.meta["ORIGINAL_NODE"] = n
 
    # Sanitize graph_module to match in the same way as pattern's graph_module.
-   graph_module_to_match = passes.remove_clone_ops(graph_module_to_match)
+   graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)
 
    match_with_attrs = pattern.match(graph_module_to_match)
 
@@ -111,13 +111,25 @@ def mark_pattern(
          is_input=True,
      )
 
-     # Only replace input by the marker node for those nodes used in the pattern.
+     # Only replace input by the marker node for those nodes used in the
+     # pattern.
      in_pattern_nodes = set(match.nodes_map.values())
      for user in input_node.users.keys():
-       if user in in_pattern_nodes:
-         user.meta["ORIGINAL_NODE"].replace_input_with(
-             input_node.meta["ORIGINAL_NODE"], new_input_node
-         )
+       if user not in in_pattern_nodes:
+         continue
+
+       user.meta["ORIGINAL_NODE"].replace_input_with(
+           input_node.meta["ORIGINAL_NODE"], new_input_node
+       )
+       # Pattern matching graph sanitization may remove clone ops, which means
+       # the user's input in the original graph may be a clone op. When
+       # replacing the input with the marker node, we need to further try
+       # replacing the input of the clone op that connects to the user.
+       for original_user_input in user.meta["ORIGINAL_NODE"].all_input_nodes:
+         if fx_utils.is_clone_op(original_user_input):
+           original_user_input.replace_input_with(
+               input_node.meta["ORIGINAL_NODE"], new_input_node
+           )
 
    for i, pattern_output_node in enumerate(pattern.output_nodes):
      output_node = match.nodes_map[pattern_output_node]
ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} RENAMED
@@ -12,11 +12,18 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- """Passes to clean up the model graph for pattern matching."""
+ """FX graph utilities for pattern matching clean ups."""
 
  import torch
 
 
+ def is_clone_op(node: torch.fx.Node) -> bool:
+   """Checks if the node is a clone op."""
+   return (
+       node.op == "call_function" and node.target == torch.ops.aten.clone.default
+   )
+
+
  def remove_clone_ops(gm: torch.fx.GraphModule):
    """Removes clone ops from the graph.
 
@@ -32,7 +39,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
      The graph module with clone ops removed.
    """
    for node in gm.graph.nodes:
-     if node.op == "call_function" and node.name.startswith("clone"):
+     if is_clone_op(node):
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
 
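
Matching on node.target is sturdier than the old node.name.startswith("clone") check: FX node names (clone, clone_1, ...) depend on tracing order, while the target is the op itself. A small self-contained sketch (the traced module is hypothetical):

    import torch
    from ai_edge_torch.hlfb.mark_pattern import fx_utils

    class M(torch.nn.Module):
      def forward(self, x):
        return torch.ops.aten.clone.default(x * x) + x

    gm = torch.fx.symbolic_trace(M())
    clones = [n for n in gm.graph.nodes if fx_utils.is_clone_op(n)]
    assert len(clones) == 1  # found via the target, whatever the node is named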
ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED
@@ -18,13 +18,14 @@ import dataclasses
  from typing import Any, Callable, Optional, Union
 
  from ai_edge_torch import fx_pass_base
- from ai_edge_torch.hlfb.mark_pattern import passes
+ from ai_edge_torch.hlfb.mark_pattern import fx_utils
  import torch
- from torch.export.graph_signature import TensorArgument
- from torch.fx import Graph
- from torch.fx import GraphModule
- from torch.fx.passes.utils.matcher_utils import InternalMatch
- from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
+
+ Graph = torch.fx.Graph
+ GraphModule = torch.fx.GraphModule
+ TensorArgument = torch.export.graph_signature.TensorArgument
+ InternalMatch = torch.fx.passes.utils.matcher_utils.InternalMatch
+ SubgraphMatcher = torch.fx.passes.utils.matcher_utils.SubgraphMatcher
 
 
  def _are_equal(x: Any, y: Any) -> bool:
@@ -219,8 +220,8 @@ class Pattern:
      # Sanitize graph_module for more precise pattern matching.
      # The graph_module to match against this pattern should apply equivalent
      # sanitization.
-     self.graph_module = passes.remove_clone_ops(self.graph_module)
-     self.graph_module = passes.remove_dangling_args(self.graph_module)
+     self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
+     self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
 
      # Builds list of ordered input and output nodes.
      self.graph_nodes_map = {}
ai_edge_torch/hlfb/test/test_mark_pattern.py CHANGED
@@ -58,6 +58,32 @@ class TestMarkPattern(googletest.TestCase):
        {"stablehlo.custom_call @mark_tensor": 6},
    )
 
+   def test_mark_pattern_with_clone_inputs(self):
+
+     class TestModel(torch.nn.Module):
+
+       def forward(self, x):
+         return torch.ops.aten.clone.default(x * x) + x
+
+     pattern = pattern_module.Pattern(
+         "test.add",
+         lambda a, b: a + b,
+         export_args=(torch.rand(2, 2), torch.rand(2, 2)),
+     )
+
+     model = TestModel().eval()
+     args = (torch.rand(20, 20),)
+     exported_program = torch.export.export(model, args)
+     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+     mlir = _export_stablehlo_mlir(exported_program)
+
+     lowertools.assert_string_count(
+         self,
+         mlir,
+         {'stablehlo.composite "test.add"': 1},
+         {"stablehlo.custom_call @mark_tensor": 3},
+     )
+
    def test_mark_pattern_with_attr_builder(self):
      class TestModel(torch.nn.Module):
 
ai_edge_torch/odml_torch/_torch_future.py CHANGED
@@ -73,3 +73,16 @@ def safe_run_decompositions(exported_program, decomp_table=None):
        node.target = lambda self, size: torch.reshape(self.contiguous(), size)
 
    return exported_program.run_decompositions(decomp_table)
+
+
+ def dummy_decomp_table():
+   """Builds a dummy decomp table for run_decompositions without any decompositions.
+
+   Compatible with torch<=2.5.
+
+   Returns:
+     Decomp table for ExportedProgram.run_decompositions.
+   """
+   return {
+       torch._ops.OperatorBase(): lambda: None,
+   }
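
A plausible reading of why a dummy table is needed (the diff itself does not say): on torch<=2.5, an empty table passed to run_decompositions can fall back to the default decompositions, so a table whose single key matches no real operator forces a plain retrace with nothing decomposed. Its use mirrors the export.py hunk below:

    # Retrace/canonicalize an ExportedProgram without decomposing any ops.
    ep = torch.export.export(model, example_args)
    ep = ep.run_decompositions(_torch_future.dummy_decomp_table())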
ai_edge_torch/odml_torch/export.py CHANGED
@@ -238,6 +238,9 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
    def in_i32(x: int):
      return -2147483648 <= x <= 2147483647
 
+   def to_int32(x: torch.Tensor):
+     return torch.ops.aten._to_copy.default(x, dtype=torch.int32)
+
    def rewrite_arange(node: torch.fx.Node):
      tensor_meta = node.meta.get("tensor_meta", None)
      if not tensor_meta:
@@ -249,7 +252,7 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
      if not (in_i32(start) and in_i32(end)):
        return
      op = node.target
-     node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
+     node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs))
 
    graph_module = exported_program.graph_module
    for node in graph_module.graph.nodes:
@@ -305,8 +308,9 @@ def exported_program_to_mlir(
 
    _convert_i64_to_i32(exported_program)
 
+   # No decompositions, just retracing/canonicalization.
    exported_program = _torch_future.safe_run_decompositions(
-       exported_program, lowerings.decompositions()
+       exported_program, _torch_future.dummy_decomp_table()
    )
 
    # Passes below mutate the exported program to a state not executable by torch.
ai_edge_torch/odml_torch/lowerings/decomp.py CHANGED
@@ -55,6 +55,10 @@ def decompositions():
        ],
    )
 
+   # Override noop aten op decompositions for faster run_decompositions.
+   decompositions[torch.ops.aten.alias.default] = lambda x: x
+   decompositions[torch.ops.aten.detach.default] = lambda x: x
+
    # Override _safe_softmax decompositions with regular softmax.
    # _safe_softmax introduces additional check-select ops to guard extreme
    # input values to softmax, which could make the converted model inefficient
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20250107"
+ __version__ = "0.3.0.dev20250109"