ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  3. ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
  4. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/llama/llama.py +29 -25
  6. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  7. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  8. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  9. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  10. ai_edge_torch/generative/examples/phi/phi3.py +26 -23
  11. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  12. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  13. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  14. ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
  15. ai_edge_torch/generative/examples/smollm/verify.py +18 -2
  16. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  17. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  18. ai_edge_torch/generative/layers/attention.py +45 -37
  19. ai_edge_torch/generative/layers/lora.py +557 -0
  20. ai_edge_torch/generative/layers/model_config.py +6 -2
  21. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
  22. ai_edge_torch/generative/test/test_lora.py +147 -0
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  24. ai_edge_torch/generative/utilities/converter.py +100 -47
  25. ai_edge_torch/generative/utilities/model_builder.py +23 -14
  26. ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
  27. ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
  28. ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
  29. ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
  30. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  31. ai_edge_torch/odml_torch/export.py +6 -2
  32. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  33. ai_edge_torch/version.py +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
  36. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
  37. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
  38. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/test/test_lora.py ADDED
@@ -0,0 +1,147 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """A suite of tests to validate LoRA utilities."""
+
+ from ai_edge_torch.generative.layers import lora as lora_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import torch
+ from absl.testing import absltest as googletest
+ from tensorflow.python.platform import resource_loader  # pylint: disable=g-direct-tensorflow-import
+
+
+ class TestLora(googletest.TestCase):
+   """Tests for LoRA utilities."""
+
+   def test_safetensors_builder(self):
+     """Converts a safetensors file to a LoRA module."""
+
+     tensor_names = lora_utils.LoRATensorNames(
+         attn_query_w_a=(
+             "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
+         ),
+         attn_query_w_b=(
+             "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight"
+         ),
+         attn_key_w_a=(
+             "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight"
+         ),
+         attn_key_w_b=(
+             "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight"
+         ),
+         attn_value_w_a=(
+             "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight"
+         ),
+         attn_value_w_b=(
+             "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight"
+         ),
+         attn_output_w_a=(
+             "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight"
+         ),
+         attn_output_w_b=(
+             "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight"
+         ),
+     )
+
+     safetensors_file = resource_loader.get_path_to_datafile(
+         "fixtures/test_lora_rank16.safetensors"
+     )
+     config = self._get_test_config(
+         num_layers=1,
+         head_dim=8,
+         num_query_groups=1,
+         kv_cache_max_len=16,
+     )
+     lora = lora_utils.LoRA.from_safetensors(
+         safetensors_file,
+         scale=1.0,
+         lora_tensor_names=tensor_names,
+         config=config,
+     )
+     self.assertEqual(lora.get_rank(), 16)
+
+   def test_torch_export(self):
+     """Tests the export of the LoRA module."""
+
+     class TestModel(torch.nn.Module):
+
+       def forward(self, x: torch.Tensor, lora: lora_utils.LoRA) -> torch.Tensor:
+         x += lora_utils.apply_lora(x, lora.adapters[0].attention.query)
+         return x
+
+     n = 1
+     head_dim = 2
+     num_query_groups = 1
+     key_length = 4
+     config = self._get_test_config(
+         num_layers=n,
+         head_dim=head_dim,
+         num_query_groups=num_query_groups,
+         kv_cache_max_len=key_length,
+     )
+     inputs = torch.zeros((n, 1, head_dim))
+     lora = lora_utils.LoRA.zeros(rank=16, config=config)
+     model = TestModel()
+     exported_program = torch.export.export(model, (inputs, lora))
+     input_specs = exported_program.graph_signature.input_specs
+     # 9 inputs: 1 for x, 2 for query lora, 2 for key lora, 2 for value lora,
+     # 2 for output lora.
+     self.assertLen(input_specs, 9)
+     self.assertEqual(input_specs[0].arg.name, "x")
+     self.assertEqual(input_specs[1].arg.name, "lora_atten_q_a_prime_weight_0")
+     self.assertEqual(input_specs[2].arg.name, "lora_atten_q_b_prime_weight_0")
+     self.assertEqual(input_specs[3].arg.name, "lora_atten_k_a_prime_weight_0")
+     self.assertEqual(input_specs[4].arg.name, "lora_atten_k_b_prime_weight_0")
+     self.assertEqual(input_specs[5].arg.name, "lora_atten_v_a_prime_weight_0")
+     self.assertEqual(input_specs[6].arg.name, "lora_atten_v_b_prime_weight_0")
+     self.assertEqual(input_specs[7].arg.name, "lora_atten_o_a_prime_weight_0")
+     self.assertEqual(input_specs[8].arg.name, "lora_atten_o_b_prime_weight_0")
+
+   def test_lora_tflite_serialization(self):
+     """Tests the serialization of the LoRA module."""
+     config = self._get_test_config(
+         num_layers=2,
+         head_dim=8,
+         num_query_groups=1,
+         kv_cache_max_len=16,
+     )
+     lora = lora_utils.LoRA.random(rank=16, config=config)
+     flatbuffer_model = lora.to_tflite()
+     recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
+     self.assertEqual(lora, recovered_lora)
+
+   def _get_test_config(
+       self, num_layers, head_dim, num_query_groups, kv_cache_max_len
+   ):
+     """Returns a test model config."""
+     attn_config = cfg.AttentionConfig(
+         num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
+     )
+     block_config = cfg.TransformerBlockConfig(
+         attn_config=attn_config, ff_config=None
+     )
+     config = cfg.ModelConfig(
+         kv_cache_max_len=kv_cache_max_len,
+         embedding_dim=head_dim,
+         block_configs=block_config,
+         num_layers=num_layers,
+         max_seq_len=None,
+         vocab_size=None,
+     )
+     return config
+
+
+ if __name__ == "__main__":
+   googletest.main()
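
Note on the tensor-name templates above: the `{}` slot in each `LoRATensorNames` field is a per-layer index that the loader fills in. As a minimal illustrative sketch (only the template string is taken from the test; the layer count is a made-up value), expanding the query-projection A-matrix keys looks like:

# Illustrative only; `num_layers` is a hypothetical value, not from this diff.
template = "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
num_layers = 2
keys = [template.format(i) for i in range(num_layers)]
# keys[0] == "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
# keys[1] == "base_model.model.model.layers.1.self_attn.q_proj.lora_A.weight"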
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -150,6 +150,16 @@ class TestModelConversion(googletest.TestCase):
        ai_edge_torch.config.in_oss,
        reason="tests with custom ops are not supported in oss",
    )
+
+   def test_smollm2(self):
+     config = smollm.get_fake_model_config_v2()
+     pytorch_model = smollm.SmolLM2(config).eval()
+     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+   @googletest.skipIf(
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
+   )
+
    def test_openelm(self):
      config = openelm.get_fake_model_config()
      pytorch_model = openelm.OpenELM(config).eval()
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -15,16 +15,15 @@

  """Common utility functions for model conversion."""

- from functools import partial
- from typing import Any, Union
-
+ import os
+ from typing import Optional, Union
  from ai_edge_torch._convert import converter as converter_utils
+ from ai_edge_torch.generative.layers import lora as lora_utils
  import ai_edge_torch.generative.layers.kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  from ai_edge_torch.generative.quantize import quant_recipes
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
  import torch
- import torch.nn as nn


  class ExportableModule(torch.nn.Module):
@@ -41,11 +40,13 @@ class ExportableModule(torch.nn.Module):

  def convert_to_tflite(
      pytorch_model: torch.nn.Module,
-     tflite_path: str,
+     output_path: str,
+     output_name_prefix: str,
      prefill_seq_len: Union[int, list[int]],
      pixel_values_size: torch.Size = None,
      quantize: bool = True,
      config: cfg.ModelConfig = None,
+     lora_ranks: Optional[list[int]] = None,
      export_config: ExportConfig = None,
  ):
    """Converts a nn.Module model to multi-signature tflite model.
@@ -79,21 +80,65 @@

    Args:
      pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
-     tflite_path (str): The tflite file path to export.
-     prefill_seq_len (Union[int, list[int]]): A list of prefill lengths to
-       export.
+     output_path (str): The path to export the tflite model.
+     output_name_prefix (str): The prefix of the tflite model name.
+     prefill_seq_len (Union[int, list[int]]): The prefill sequence length to
+       use. If a list, the model will have multiple prefill signatures.
      pixel_values_size (torch.Size, optional): The size of pixel values to pass
        to the model. If None, the model is not expected to take pixel values.
      quantize (bool, optional): Whether the model should be quantized. Defaults
        to True.
      config (cfg.ModelConfig, optional): The model config used to configure KV
        cache. If None, it uses the config of the pytorch_model.
+     lora_ranks (list[int], optional): The ranks of the LoRA layers. If None,
+       no LoRA signatures will be added.
    """
+   # pylint: disable=protected-access
+   torch._dynamo.config.cache_size_limit = 64
+
+   config = config if config else pytorch_model.config
    prefill_seq_lens = (
        [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
    )
+   loras = [None]
+   if lora_ranks is not None:
+     for rank in lora_ranks:
+       lora = lora_utils.LoRA.zeros(rank, config)
+       loras.append(lora)
+
+   quant_suffix = 'q8' if quantize else 'f32'
+   kv_size = config.kv_cache_max_len
+   lora_suffix = (
+       '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
+   )
+   output_filename = (
+       f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
+   )
+   output_file = os.path.join(output_path, output_filename)
+
+   _export_helper(
+       pytorch_model,
+       output_file,
+       prefill_seq_lens,
+       pixel_values_size,
+       quantize,
+       config,
+       loras,
+       export_config,
+   )

-   # Tensors used to trace the model graph during conversion.
+
+ def _export_helper(
+     pytorch_model: torch.nn.Module,
+     output_file: str,
+     prefill_seq_lens: list[int],
+     pixel_values_size: torch.Size,
+     quantize: bool,
+     config: cfg.ModelConfig,
+     loras: list[None | lora_utils.LoRA],
+     export_config: ExportConfig,
+ ):
+   """Helper function to export a model to tflite."""
    prefill_tokens_list = []
    prefill_input_pos_list = []
    for seq_len in prefill_seq_lens:
@@ -108,9 +153,7 @@

    decode_token = torch.tensor([[0]], dtype=torch.int)
    decode_input_pos = torch.tensor([0], dtype=torch.int)
-   kv = kv_utils.KVCache.from_model_config(
-       config if config else pytorch_model.config
-   )
+   kv = kv_utils.KVCache.from_model_config(config)

    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None

@@ -119,44 +162,54 @@
    mod = ExportableModule(pytorch_model, export_config=export_config)

    converter = converter_utils.Converter()
-   for i in range(len(prefill_seq_lens)):
-     prefill_seq_len = prefill_seq_lens[i]
-     prefill_tokens = prefill_tokens_list[i]
-     prefill_input_pos = prefill_input_pos_list[i]
-     if i == 0 and len(prefill_seq_lens) == 1:
-       prefill_signature_name = 'prefill'
-     else:
-       prefill_signature_name = f'prefill_{prefill_seq_len}'
-     converter.add_signature(
-         prefill_signature_name,
-         mod,
-         sample_kwargs={
-             'tokens': prefill_tokens,
-             'input_pos': prefill_input_pos,
-             'kv_cache': kv,
-         },
-     )
-     if prefill_pixel_values is not None:
+   for lora in loras:
+     for i in range(len(prefill_seq_lens)):
+       prefill_seq_len = prefill_seq_lens[i]
+       prefill_tokens = prefill_tokens_list[i]
+       prefill_input_pos = prefill_input_pos_list[i]
+       if i == 0 and len(prefill_seq_lens) == 1:
+         prefill_signature_name = 'prefill'
+       else:
+         prefill_signature_name = f'prefill_{prefill_seq_len}'
+
+       sample_kwargs = {
+           'tokens': prefill_tokens,
+           'input_pos': prefill_input_pos,
+           'kv_cache': kv,
+       }
+       if lora is not None:
+         prefill_signature_name += f'_lora_r{lora.get_rank()}'
+         sample_kwargs['lora'] = lora
+
        converter.add_signature(
-           prefill_signature_name + '_pixel',
+           prefill_signature_name,
            mod,
-           sample_kwargs={
-               'tokens': prefill_tokens,
-               'input_pos': prefill_input_pos,
-               'kv_cache': kv,
-               'pixel_values': prefill_pixel_values,
-           },
+           sample_kwargs=sample_kwargs,
        )

-   converter.add_signature(
-       'decode',
-       mod,
-       sample_kwargs={
-           'tokens': decode_token,
-           'input_pos': decode_input_pos,
-           'kv_cache': kv,
-       },
-   )
+       if prefill_pixel_values is not None:
+         converter.add_signature(
+             prefill_signature_name + '_pixel',
+             mod,
+             sample_kwargs={
+                 **sample_kwargs,
+                 'pixel_values': prefill_pixel_values,
+             },
+         )
+
+     sample_kwargs = {
+         'tokens': decode_token,
+         'input_pos': decode_input_pos,
+         'kv_cache': kv,
+     }
+     if lora is not None:
+       sample_kwargs['lora'] = lora
+
+     converter.add_signature(
+         'decode' if lora is None else f'decode_lora_r{lora.get_rank()}',
+         mod,
+         sample_kwargs=sample_kwargs,
+     )

    edge_model = converter.convert(quant_config=quant_config)
-   edge_model.export(tflite_path)
+   edge_model.export(output_file)
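
For orientation, a hedged usage sketch of the reworked entry point; the model, output path, name prefix, and rank below are illustrative, not taken from this diff. Per the naming logic above, this call would write my_model_q8_ekv{config.kv_cache_max_len}_lora16.tflite containing prefill_128, prefill_512, and decode signatures plus prefill_128_lora_r16, prefill_512_lora_r16, and decode_lora_r16:

from ai_edge_torch.generative.utilities import converter

# `pytorch_model` is a placeholder for any generative model built on
# ai_edge_torch.generative layers.
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/models',        # illustrative
    output_name_prefix='my_model',    # illustrative
    prefill_seq_len=[128, 512],       # exports two prefill signatures
    quantize=True,                    # 'q8' filename suffix ('f32' if False)
    lora_ranks=[16],                  # adds *_lora_r16 signatures
)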
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -22,12 +22,15 @@ from typing import Optional, Tuple
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers import lora as lora_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
  from torch import nn

+
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
      ff_down_proj="model.layers.{}.mlp.down_proj",
@@ -85,13 +88,6 @@ class DecoderOnlyModel(nn.Module):
          config.embedding_dim,
          config.final_norm_config,
      )
-     # ROPE parameters for all attn_configs are the same. Take the first one.
-     attn_config = config.block_config(0).attn_config
-     self.rope_cache = attn_utils.build_rope_cache(
-         size=config.kv_cache_max,
-         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-         base=attn_config.rotary_base,
-     )
      self.mask_cache = attn_utils.build_causal_mask_cache(
          size=config.kv_cache_max,
      )
@@ -103,6 +99,7 @@ class DecoderOnlyModel(nn.Module):
        tokens: torch.Tensor,
        input_pos: torch.Tensor,
        kv_cache: kv_utils.KVCache,
+       lora: Optional[lora_utils.LoRA] = None,
        export_config: Optional[ExportConfig] = None,
    ) -> dict[torch.Tensor, kv_utils.KVCache]:
      _, seq_len = tokens.size()
@@ -113,13 +110,23 @@ class DecoderOnlyModel(nn.Module):

      # token embeddings of shape (b, t, n_embd)
      input_embeds = self.tok_embedding(tokens)
-     cos, sin = self.rope_cache
-     rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
+
+     # ROPE parameters for all attn_configs are the same. Take the first one.
+     attn_config = self.config.block_config(0).attn_config
+     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+     rope = self.config.build_rope(
+         input_pos=input_pos,
+         n_elem=n_elem,
+         base=attn_config.rotary_base,
+         head_dim=attn_config.head_dim,
+         # input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
+     )
+
      mask = self.mask_cache.index_select(2, input_pos)
      mask = mask[:, :, :, : self.config.kv_cache_max]

      return self.forward_with_embeds(
-         input_embeds, rope, mask, input_pos, kv_cache, export_config
+         input_embeds, rope, mask, input_pos, kv_cache, lora, export_config
      )

    def forward_with_embeds(
@@ -129,6 +136,7 @@ class DecoderOnlyModel(nn.Module):
        mask: torch.Tensor,
        input_pos: torch.Tensor,
        kv_cache: kv_utils.KVCache,
+       lora: Optional[lora_utils.LoRA] = None,
        export_config: Optional[ExportConfig] = None,
    ) -> dict[torch.Tensor, kv_utils.KVCache]:
      """Forwards the model with input embeddings."""
@@ -141,13 +149,14 @@ class DecoderOnlyModel(nn.Module):
      if self.config.embedding_scale is not None:
        x = x * self.config.embedding_scale

-     updated_kv_entires = []
+     updated_kv_entries = []
      for i, block in enumerate(self.transformer_blocks):
        kv_entry = kv_cache.caches[i] if kv_cache else None
-       x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
+       lora_adapter = lora.adapters[i] if lora else None
+       x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter)
        if kv_entry:
-         updated_kv_entires.append(kv_entry)
-     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+         updated_kv_entries.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

      if export_config is not None:
        if (
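
Each transformer block now receives its own `lora.adapters[i]`. As a minimal sketch of the standard low-rank update this wiring carries (an assumption about the math behind `apply_lora`, not code from this diff): for adapter matrices A of shape (rank, dim) and B of shape (out_dim, rank), the delta added to a projection output is scale * (x @ A.T) @ B.T.

import torch

def low_rank_delta(x, a, b, scale=1.0):
  # x: (..., dim); a: (rank, dim); b: (out_dim, rank).
  # The usual LoRA correction term added to a frozen base projection.
  return scale * (x @ a.T) @ b.T

x = torch.randn(1, 4, 8)
a = torch.randn(16, 8)  # rank 16, matching the tests above
b = torch.randn(8, 16)
assert low_rank_delta(x, a, b).shape == x.shape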
ai_edge_torch/hlfb/mark_pattern/__init__.py CHANGED
@@ -17,7 +17,7 @@ from typing import Any
  import uuid

  from ai_edge_torch import lowertools
- from ai_edge_torch.hlfb.mark_pattern import passes
+ from ai_edge_torch.hlfb.mark_pattern import fx_utils
  from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
  import torch

@@ -87,7 +87,7 @@ def mark_pattern(
        m.meta["ORIGINAL_NODE"] = n

    # Sanitize graph_module to match in the same way as pattern's graph_module.
-   graph_module_to_match = passes.remove_clone_ops(graph_module_to_match)
+   graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)

    match_with_attrs = pattern.match(graph_module_to_match)

@@ -111,13 +111,25 @@ def mark_pattern(
            is_input=True,
        )

-       # Only replace input by the marker node for those nodes used in the pattern.
+       # Only replace input by the marker node for those nodes used in the
+       # pattern.
        in_pattern_nodes = set(match.nodes_map.values())
        for user in input_node.users.keys():
-         if user in in_pattern_nodes:
-           user.meta["ORIGINAL_NODE"].replace_input_with(
-               input_node.meta["ORIGINAL_NODE"], new_input_node
-           )
+         if user not in in_pattern_nodes:
+           continue
+
+         user.meta["ORIGINAL_NODE"].replace_input_with(
+             input_node.meta["ORIGINAL_NODE"], new_input_node
+         )
+         # Pattern matching graph sanitization may remove clone ops, which means
+         # the user's input in the original graph may be a clone op. When
+         # replacing the input with the marker node, we need to further try
+         # replacing the input of the clone op that connects to the user.
+         for original_user_input in user.meta["ORIGINAL_NODE"].all_input_nodes:
+           if fx_utils.is_clone_op(original_user_input):
+             original_user_input.replace_input_with(
+                 input_node.meta["ORIGINAL_NODE"], new_input_node
+             )

    for i, pattern_output_node in enumerate(pattern.output_nodes):
      output_node = match.nodes_map[pattern_output_node]
ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} RENAMED
@@ -12,11 +12,18 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- """Passes to clean up the model graph for pattern matching."""
+ """FX graph utilities for pattern matching clean ups."""

  import torch


+ def is_clone_op(node: torch.fx.Node) -> bool:
+   """Checks if the node is a clone op."""
+   return (
+       node.op == "call_function" and node.target == torch.ops.aten.clone.default
+   )
+
+
  def remove_clone_ops(gm: torch.fx.GraphModule):
    """Removes clone ops from the graph.

@@ -32,7 +39,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
      The graph module with clone ops removed.
    """
    for node in gm.graph.nodes:
-     if node.op == "call_function" and node.name.startswith("clone"):
+     if is_clone_op(node):
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)

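The move from `node.name.startswith("clone")` to `is_clone_op` is a correctness fix: FX node names are merely unique labels, while `node.target` identifies the op itself. An illustrative check (the function `f` is made up for the demo; `make_fx` is used to obtain aten-level nodes):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
  return torch.ops.aten.clone.default(x * x) + x

gm = make_fx(f)(torch.rand(2, 2))
clones = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target == torch.ops.aten.clone.default
]
assert len(clones) == 1  # matched by target, independent of node naming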
ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED
@@ -18,13 +18,14 @@ import dataclasses
  from typing import Any, Callable, Optional, Union

  from ai_edge_torch import fx_pass_base
- from ai_edge_torch.hlfb.mark_pattern import passes
+ from ai_edge_torch.hlfb.mark_pattern import fx_utils
  import torch
- from torch.export.graph_signature import TensorArgument
- from torch.fx import Graph
- from torch.fx import GraphModule
- from torch.fx.passes.utils.matcher_utils import InternalMatch
- from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
+
+ Graph = torch.fx.Graph
+ GraphModule = torch.fx.GraphModule
+ TensorArgument = torch.export.graph_signature.TensorArgument
+ InternalMatch = torch.fx.passes.utils.matcher_utils.InternalMatch
+ SubgraphMatcher = torch.fx.passes.utils.matcher_utils.SubgraphMatcher


  def _are_equal(x: Any, y: Any) -> bool:
@@ -219,8 +220,8 @@ class Pattern:
      # Sanitize graph_module for more precise pattern matching.
      # The graph_module to match against this pattern should apply equivalent
      # sanitization.
-     self.graph_module = passes.remove_clone_ops(self.graph_module)
-     self.graph_module = passes.remove_dangling_args(self.graph_module)
+     self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
+     self.graph_module = fx_utils.remove_dangling_args(self.graph_module)

      # Builds list of ordered input and output nodes.
      self.graph_nodes_map = {}
ai_edge_torch/hlfb/test/test_mark_pattern.py CHANGED
@@ -58,6 +58,32 @@ class TestMarkPattern(googletest.TestCase):
          {"stablehlo.custom_call @mark_tensor": 6},
      )

+   def test_mark_pattern_with_clone_inputs(self):
+
+     class TestModel(torch.nn.Module):
+
+       def forward(self, x):
+         return torch.ops.aten.clone.default(x * x) + x
+
+     pattern = pattern_module.Pattern(
+         "test.add",
+         lambda a, b: a + b,
+         export_args=(torch.rand(2, 2), torch.rand(2, 2)),
+     )
+
+     model = TestModel().eval()
+     args = (torch.rand(20, 20),)
+     exported_program = torch.export.export(model, args)
+     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+     mlir = _export_stablehlo_mlir(exported_program)
+
+     lowertools.assert_string_count(
+         self,
+         mlir,
+         {'stablehlo.composite "test.add"': 1},
+         {"stablehlo.custom_call @mark_tensor": 3},
+     )
+
    def test_mark_pattern_with_attr_builder(self):
      class TestModel(torch.nn.Module):

ai_edge_torch/odml_torch/_torch_future.py CHANGED
@@ -73,3 +73,16 @@ def safe_run_decompositions(exported_program, decomp_table=None):
        node.target = lambda self, size: torch.reshape(self.contiguous(), size)

    return exported_program.run_decompositions(decomp_table)
+
+
+ def dummy_decomp_table():
+   """Build dummy decomp table for run_decompositions without any decompositions.
+
+   Compatible for torch<=2.5.
+
+   Returns:
+     Decomp table for ExportedProgram.run_decompositions.
+   """
+   return {
+       torch._ops.OperatorBase(): lambda: None,
+   }
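
The table's single key is a freshly constructed `torch._ops.OperatorBase()`, which should never compare equal to a real operator, so `run_decompositions` retraces and canonicalizes the program without decomposing any op; presumably an empty table would not have the same effect on torch<=2.5. A hedged usage sketch, mirroring the export.py change below (`exported_program` stands for any existing `torch.export.ExportedProgram`):

from ai_edge_torch.odml_torch import _torch_future

# Retrace/canonicalize only; no op gets decomposed.
exported_program = exported_program.run_decompositions(
    _torch_future.dummy_decomp_table()
)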
ai_edge_torch/odml_torch/export.py CHANGED
@@ -238,6 +238,9 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
    def in_i32(x: int):
      return -2147483648 <= x <= 2147483647

+   def to_int32(x: torch.Tensor):
+     return torch.ops.aten._to_copy.default(x, dtype=torch.int32)
+
    def rewrite_arange(node: torch.fx.Node):
      tensor_meta = node.meta.get("tensor_meta", None)
      if not tensor_meta:
@@ -249,7 +252,7 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
      if not (in_i32(start) and in_i32(end)):
        return
      op = node.target
-     node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
+     node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs))

    graph_module = exported_program.graph_module
    for node in graph_module.graph.nodes:
@@ -305,8 +308,9 @@ def exported_program_to_mlir(

    _convert_i64_to_i32(exported_program)

+   # No decompositions but just retracing/canonicalization.
    exported_program = _torch_future.safe_run_decompositions(
-       exported_program, lowerings.decompositions()
+       exported_program, _torch_future.dummy_decomp_table()
    )

    # Passes below mutate the exported program to a state not executable by torch.
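
Both the old and new `rewrite_arange` forms cast the arange result to int32; the difference is that `aten._to_copy` keeps the cast expressed as an aten op rather than a tensor method, which the lowering pipeline can presumably handle more uniformly. A quick illustrative equivalence check:

import torch

x = torch.arange(0, 4)  # int64 by default
y = torch.ops.aten._to_copy.default(x, dtype=torch.int32)
assert y.dtype == torch.int32
assert torch.equal(y, x.type(torch.int32))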
ai_edge_torch/odml_torch/lowerings/decomp.py CHANGED
@@ -55,6 +55,10 @@ def decompositions():
        ],
    )

+   # Override noop aten op decompositions for faster run_decompositions.
+   decompositions[torch.ops.aten.alias.default] = lambda x: x
+   decompositions[torch.ops.aten.detach.default] = lambda x: x
+
    # Override _safe_softmax decompositions with regular softmax.
    # _safe_softmax introduces additional check-select ops to guard extreme
    # input values to softmax, which could make the converted model inefficient
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.3.0.dev20250107"
+ __version__ = "0.3.0.dev20250109"