ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/llama/llama.py +29 -25
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/phi/phi3.py +26 -23
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
- ai_edge_torch/generative/examples/smollm/verify.py +18 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +45 -37
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +23 -14
- ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
- ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
- ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
- ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/test/test_lora.py ADDED
@@ -0,0 +1,147 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A suite of tests to validate LoRA utilities."""
+
+from ai_edge_torch.generative.layers import lora as lora_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import torch
+from absl.testing import absltest as googletest
+from tensorflow.python.platform import resource_loader  # pylint: disable=g-direct-tensorflow-import
+
+
+class TestLora(googletest.TestCase):
+  """Tests for LoRA utilities."""
+
+  def test_safetensors_builder(self):
+    """Converts a safetensors file to a LoRA module."""
+
+    tensor_names = lora_utils.LoRATensorNames(
+        attn_query_w_a=(
+            "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
+        ),
+        attn_query_w_b=(
+            "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight"
+        ),
+        attn_key_w_a=(
+            "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight"
+        ),
+        attn_key_w_b=(
+            "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight"
+        ),
+        attn_value_w_a=(
+            "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight"
+        ),
+        attn_value_w_b=(
+            "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight"
+        ),
+        attn_output_w_a=(
+            "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight"
+        ),
+        attn_output_w_b=(
+            "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight"
+        ),
+    )
+
+    safetensors_file = resource_loader.get_path_to_datafile(
+        "fixtures/test_lora_rank16.safetensors"
+    )
+    config = self._get_test_config(
+        num_layers=1,
+        head_dim=8,
+        num_query_groups=1,
+        kv_cache_max_len=16,
+    )
+    lora = lora_utils.LoRA.from_safetensors(
+        safetensors_file,
+        scale=1.0,
+        lora_tensor_names=tensor_names,
+        config=config,
+    )
+    self.assertEqual(lora.get_rank(), 16)
+
+  def test_torch_export(self):
+    """Tests the export of the LoRA module."""
+
+    class TestModel(torch.nn.Module):
+
+      def forward(self, x: torch.Tensor, lora: lora_utils.LoRA) -> torch.Tensor:
+        x += lora_utils.apply_lora(x, lora.adapters[0].attention.query)
+        return x
+
+    n = 1
+    head_dim = 2
+    num_query_groups = 1
+    key_length = 4
+    config = self._get_test_config(
+        num_layers=n,
+        head_dim=head_dim,
+        num_query_groups=num_query_groups,
+        kv_cache_max_len=key_length,
+    )
+    inputs = torch.zeros((n, 1, head_dim))
+    lora = lora_utils.LoRA.zeros(rank=16, config=config)
+    model = TestModel()
+    exported_program = torch.export.export(model, (inputs, lora))
+    input_specs = exported_program.graph_signature.input_specs
+    # 9 inputs: 1 for x, 2 for query lora, 2 for key lora, 2 for value lora,
+    # 2 for output lora.
+    self.assertLen(input_specs, 9)
+    self.assertEqual(input_specs[0].arg.name, "x")
+    self.assertEqual(input_specs[1].arg.name, "lora_atten_q_a_prime_weight_0")
+    self.assertEqual(input_specs[2].arg.name, "lora_atten_q_b_prime_weight_0")
+    self.assertEqual(input_specs[3].arg.name, "lora_atten_k_a_prime_weight_0")
+    self.assertEqual(input_specs[4].arg.name, "lora_atten_k_b_prime_weight_0")
+    self.assertEqual(input_specs[5].arg.name, "lora_atten_v_a_prime_weight_0")
+    self.assertEqual(input_specs[6].arg.name, "lora_atten_v_b_prime_weight_0")
+    self.assertEqual(input_specs[7].arg.name, "lora_atten_o_a_prime_weight_0")
+    self.assertEqual(input_specs[8].arg.name, "lora_atten_o_b_prime_weight_0")
+
+  def test_lora_tflite_serialization(self):
+    """Tests the serialization of the LoRA module."""
+    config = self._get_test_config(
+        num_layers=2,
+        head_dim=8,
+        num_query_groups=1,
+        kv_cache_max_len=16,
+    )
+    lora = lora_utils.LoRA.random(rank=16, config=config)
+    flatbuffer_model = lora.to_tflite()
+    recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
+    self.assertEqual(lora, recovered_lora)
+
+  def _get_test_config(
+      self, num_layers, head_dim, num_query_groups, kv_cache_max_len
+  ):
+    """Returns a test model config."""
+    attn_config = cfg.AttentionConfig(
+        num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
+    )
+    block_config = cfg.TransformerBlockConfig(
+        attn_config=attn_config, ff_config=None
+    )
+    config = cfg.ModelConfig(
+        kv_cache_max_len=kv_cache_max_len,
+        embedding_dim=head_dim,
+        block_configs=block_config,
+        num_layers=num_layers,
+        max_seq_len=None,
+        vocab_size=None,
+    )
+    return config
+
+
+if __name__ == "__main__":
+  googletest.main()
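Note: apply_lora() in these tests is expected to add the usual low-rank LoRA correction on top of a frozen projection. The sketch below shows only that generic formulation and is not lifted from the library; the a_prime/b_prime names merely echo the exported input-spec names checked in test_torch_export, and all shapes are illustration values.

import torch

d_model, rank, scale = 8, 16, 1.0
x = torch.randn(1, 4, d_model)             # activations entering a projection
a_prime = torch.randn(rank, d_model)       # low-rank "A" factor (cf. lora_atten_q_a_prime_weight_0)
b_prime = torch.randn(d_model, rank)       # low-rank "B" factor (cf. lora_atten_q_b_prime_weight_0)
delta = x @ a_prime.T @ b_prime.T * scale  # correction added to the base projection's output
print(delta.shape)                         # torch.Size([1, 4, 8])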

ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -150,6 +150,16 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
+
+  def test_smollm2(self):
+    config = smollm.get_fake_model_config_v2()
+    pytorch_model = smollm.SmolLM2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()

ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -15,16 +15,15 @@
 
 """Common utility functions for model conversion."""
 
-
-from typing import
-
+import os
+from typing import Optional, Union
 from ai_edge_torch._convert import converter as converter_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
-import torch.nn as nn
 
 
 class ExportableModule(torch.nn.Module):
@@ -41,11 +40,13 @@ class ExportableModule(torch.nn.Module):
 
 def convert_to_tflite(
     pytorch_model: torch.nn.Module,
-
+    output_path: str,
+    output_name_prefix: str,
     prefill_seq_len: Union[int, list[int]],
     pixel_values_size: torch.Size = None,
     quantize: bool = True,
     config: cfg.ModelConfig = None,
+    lora_ranks: Optional[list[int]] = None,
     export_config: ExportConfig = None,
 ):
   """Converts a nn.Module model to multi-signature tflite model.
@@ -79,21 +80,65 @@ def convert_to_tflite(
 
   Args:
     pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
-
-
-
+    output_path (str): The path to export the tflite model.
+    output_name_prefix (str): The prefix of the tflite model name.
+    prefill_seq_len (Union[int, list[int]]): The prefill sequence length to
+      use. If a list, the model will have multiple prefill signatures.
     pixel_values_size (torch.Size, optional): The size of pixel values to pass
       to the model. If None, the model is not expected to take pixel values.
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
     config (cfg.ModelConfig, optional): The model config used to configure KV
      cache. If None, it uses the config of the pytorch_model.
+    lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
+      no LoRA signatures will be added.
   """
+  # pylint: disable=protected-access
+  torch._dynamo.config.cache_size_limit = 64
+
+  config = config if config else pytorch_model.config
   prefill_seq_lens = (
       [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
   )
+  loras = [None]
+  if lora_ranks is not None:
+    for rank in lora_ranks:
+      lora = lora_utils.LoRA.zeros(rank, config)
+      loras.append(lora)
+
+  quant_suffix = 'q8' if quantize else 'f32'
+  kv_size = config.kv_cache_max_len
+  lora_suffix = (
+      '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
+  )
+  output_filename = (
+      f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
+  )
+  output_file = os.path.join(output_path, output_filename)
+
+  _export_helper(
+      pytorch_model,
+      output_file,
+      prefill_seq_lens,
+      pixel_values_size,
+      quantize,
+      config,
+      loras,
+      export_config,
+  )
 
-
+
+def _export_helper(
+    pytorch_model: torch.nn.Module,
+    output_file: str,
+    prefill_seq_lens: list[int],
+    pixel_values_size: torch.Size,
+    quantize: bool,
+    config: cfg.ModelConfig,
+    loras: list[None | lora_utils.LoRA],
+    export_config: ExportConfig,
+):
+  """Helper function to export a model to tflite."""
   prefill_tokens_list = []
   prefill_input_pos_list = []
   for seq_len in prefill_seq_lens:
@@ -108,9 +153,7 @@ def convert_to_tflite(
 
   decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(
-      config if config else pytorch_model.config
-  )
+  kv = kv_utils.KVCache.from_model_config(config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
 
@@ -119,44 +162,54 @@ def convert_to_tflite(
   mod = ExportableModule(pytorch_model, export_config=export_config)
 
   converter = converter_utils.Converter()
-  for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  for lora in loras:
+    for i in range(len(prefill_seq_lens)):
+      prefill_seq_len = prefill_seq_lens[i]
+      prefill_tokens = prefill_tokens_list[i]
+      prefill_input_pos = prefill_input_pos_list[i]
+      if i == 0 and len(prefill_seq_lens) == 1:
+        prefill_signature_name = 'prefill'
+      else:
+        prefill_signature_name = f'prefill_{prefill_seq_len}'
+
+      sample_kwargs = {
+          'tokens': prefill_tokens,
+          'input_pos': prefill_input_pos,
+          'kv_cache': kv,
+      }
+      if lora is not None:
+        prefill_signature_name += f'_lora_r{lora.get_rank()}'
+        sample_kwargs['lora'] = lora
+
       converter.add_signature(
-          prefill_signature_name
+          prefill_signature_name,
           mod,
-          sample_kwargs=
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-              'pixel_values': prefill_pixel_values,
-          },
+          sample_kwargs=sample_kwargs,
       )
 
-
-
-
-
-
-
-
-
-
+      if prefill_pixel_values is not None:
+        converter.add_signature(
+            prefill_signature_name + '_pixel',
+            mod,
+            sample_kwargs={
+                **sample_kwargs,
+                'pixel_values': prefill_pixel_values,
+            },
+        )
+
+    sample_kwargs = {
+        'tokens': decode_token,
+        'input_pos': decode_input_pos,
+        'kv_cache': kv,
+    }
+    if lora is not None:
+      sample_kwargs['lora'] = lora
+
+    converter.add_signature(
+        'decode' if lora is None else f'decode_lora_r{lora.get_rank()}',
+        mod,
+        sample_kwargs=sample_kwargs,
+    )
 
   edge_model = converter.convert(quant_config=quant_config)
-  edge_model.export(
+  edge_model.export(output_file)
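For orientation, here is a hedged sketch of how the reworked entry point might be called; it is not taken from the package itself. pytorch_model stands for an already-built generative model (as produced by the convert_*_to_tflite example scripts listed above), and the path, prefix, and ranks are placeholders; only the keyword names mirror the signature in the hunks above.

from ai_edge_torch.generative.utilities import converter

# `pytorch_model` is assumed to be an already-built generative nn.Module.
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/models',       # directory; the file name is derived by the converter
    output_name_prefix='my_model',
    prefill_seq_len=[256, 512],      # two prefill signatures
    quantize=True,
    lora_ranks=[16],                 # also emits prefill_*_lora_r16 / decode_lora_r16 signatures
)
# Expected output file, per the naming logic above:
#   /tmp/models/my_model_q8_ekv<kv_cache_max_len>_lora16.tflite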

ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -22,12 +22,15 @@ from typing import Optional, Tuple
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
 
+
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
@@ -85,13 +88,6 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -103,6 +99,7 @@ class DecoderOnlyModel(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      lora: Optional[lora_utils.LoRA] = None,
      export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
@@ -113,13 +110,23 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-
-
+
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = self.config.build_rope(
+        input_pos=input_pos,
+        n_elem=n_elem,
+        base=attn_config.rotary_base,
+        head_dim=attn_config.head_dim,
+        # input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
+    )
+
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
     return self.forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
+        input_embeds, rope, mask, input_pos, kv_cache, lora, export_config
     )
 
   def forward_with_embeds(
@@ -129,6 +136,7 @@ class DecoderOnlyModel(nn.Module):
       mask: torch.Tensor,
       input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
+      lora: Optional[lora_utils.LoRA] = None,
       export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     """Forwards the model with input embeddings."""
@@ -141,13 +149,14 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
-
+      lora_adapter = lora.adapters[i] if lora else None
+      x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter)
       if kv_entry:
-
-    updated_kv_cache = kv_utils.KVCache(tuple(
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
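To make the threading concrete, a hedged sketch of calling the new forward() with an adapter; `model` is assumed to be an already-constructed DecoderOnlyModel (or subclass) whose config supplies the shapes, and the token values are placeholders. Only the argument order and the LoRA/KVCache constructors come from the hunks above.

import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.layers import lora as lora_utils

# `model` is assumed to be a DecoderOnlyModel built elsewhere.
kv = kv_utils.KVCache.from_model_config(model.config)
lora = lora_utils.LoRA.zeros(rank=16, config=model.config)   # or LoRA.from_safetensors(...)
tokens = torch.zeros((1, 8), dtype=torch.int)
input_pos = torch.arange(0, 8, dtype=torch.int)
output = model(tokens, input_pos, kv, lora=lora)             # lora=None keeps the base weights only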

ai_edge_torch/hlfb/mark_pattern/__init__.py CHANGED
@@ -17,7 +17,7 @@ from typing import Any
 import uuid
 
 from ai_edge_torch import lowertools
-from ai_edge_torch.hlfb.mark_pattern import
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
 
@@ -87,7 +87,7 @@ def mark_pattern(
     m.meta["ORIGINAL_NODE"] = n
 
   # Sanitize graph_module to match in the same way as pattern's graph_module.
-  graph_module_to_match =
+  graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)
 
   match_with_attrs = pattern.match(graph_module_to_match)
 
@@ -111,13 +111,25 @@ def mark_pattern(
          is_input=True,
      )
 
-      # Only replace input by the marker node for those nodes used in the
+      # Only replace input by the marker node for those nodes used in the
+      # pattern.
      in_pattern_nodes = set(match.nodes_map.values())
      for user in input_node.users.keys():
-        if user in in_pattern_nodes:
-
-
-
+        if user not in in_pattern_nodes:
+          continue
+
+        user.meta["ORIGINAL_NODE"].replace_input_with(
+            input_node.meta["ORIGINAL_NODE"], new_input_node
+        )
+        # Pattern matching graph sanitization may remove clone ops, which means
+        # the user's input in the original graph may be a clone op. When
+        # replacing the input with the marker node, we need to further try
+        # replacing the input of the clone op that connects to the user.
+        for original_user_input in user.meta["ORIGINAL_NODE"].all_input_nodes:
+          if fx_utils.is_clone_op(original_user_input):
+            original_user_input.replace_input_with(
+                input_node.meta["ORIGINAL_NODE"], new_input_node
+            )
 
    for i, pattern_output_node in enumerate(pattern.output_nodes):
      output_node = match.nodes_map[pattern_output_node]

ai_edge_torch/hlfb/mark_pattern/passes.py → ai_edge_torch/hlfb/mark_pattern/fx_utils.py RENAMED
@@ -12,11 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""
+"""FX graph utilities for pattern matching clean ups."""
 
 import torch
 
 
+def is_clone_op(node: torch.fx.Node) -> bool:
+  """Checks if the node is a clone op."""
+  return (
+      node.op == "call_function" and node.target == torch.ops.aten.clone.default
+  )
+
+
 def remove_clone_ops(gm: torch.fx.GraphModule):
   """Removes clone ops from the graph.
 
@@ -32,7 +39,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
     The graph module with clone ops removed.
   """
   for node in gm.graph.nodes:
-    if node
+    if is_clone_op(node):
       node.replace_all_uses_with(node.args[0])
       gm.graph.erase_node(node)
 
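A small, hedged end-to-end sketch of the helpers above on a freshly exported graph; the module and tensor shapes are arbitrary, and it assumes the explicit aten.clone survives torch.export unchanged, which the new test added further below relies on as well.

import torch
from ai_edge_torch.hlfb.mark_pattern import fx_utils


class CloneModel(torch.nn.Module):

  def forward(self, x):
    return torch.ops.aten.clone.default(x * x) + x


ep = torch.export.export(CloneModel().eval(), (torch.rand(4, 4),))
gm = ep.graph_module
print(any(fx_utils.is_clone_op(n) for n in gm.graph.nodes))  # True before sanitization
fx_utils.remove_clone_ops(gm)                                # clones replaced by their inputs
print(any(fx_utils.is_clone_op(n) for n in gm.graph.nodes))  # False afterwards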

ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED
@@ -18,13 +18,14 @@ import dataclasses
 from typing import Any, Callable, Optional, Union
 
 from ai_edge_torch import fx_pass_base
-from ai_edge_torch.hlfb.mark_pattern import
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 import torch
-
-
-
-
-
+
+Graph = torch.fx.Graph
+GraphModule = torch.fx.GraphModule
+TensorArgument = torch.export.graph_signature.TensorArgument
+InternalMatch = torch.fx.passes.utils.matcher_utils.InternalMatch
+SubgraphMatcher = torch.fx.passes.utils.matcher_utils.SubgraphMatcher
 
 
 def _are_equal(x: Any, y: Any) -> bool:
@@ -219,8 +220,8 @@ class Pattern:
     # Sanitize graph_module for more precise pattern matching.
     # The graph_module to match against this pattern should apply equivalent
     # sanitization.
-    self.graph_module =
-    self.graph_module =
+    self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
+    self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
 
     # Builds list of ordered input and output nodes.
     self.graph_nodes_map = {}

ai_edge_torch/hlfb/test/test_mark_pattern.py CHANGED
@@ -58,6 +58,32 @@ class TestMarkPattern(googletest.TestCase):
         {"stablehlo.custom_call @mark_tensor": 6},
     )
 
+  def test_mark_pattern_with_clone_inputs(self):
+
+    class TestModel(torch.nn.Module):
+
+      def forward(self, x):
+        return torch.ops.aten.clone.default(x * x) + x
+
+    pattern = pattern_module.Pattern(
+        "test.add",
+        lambda a, b: a + b,
+        export_args=(torch.rand(2, 2), torch.rand(2, 2)),
+    )
+
+    model = TestModel().eval()
+    args = (torch.rand(20, 20),)
+    exported_program = torch.export.export(model, args)
+    mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+    mlir = _export_stablehlo_mlir(exported_program)
+
+    lowertools.assert_string_count(
+        self,
+        mlir,
+        {'stablehlo.composite "test.add"': 1},
+        {"stablehlo.custom_call @mark_tensor": 3},
+    )
+
   def test_mark_pattern_with_attr_builder(self):
     class TestModel(torch.nn.Module):
 

ai_edge_torch/odml_torch/_torch_future.py CHANGED
@@ -73,3 +73,16 @@ def safe_run_decompositions(exported_program, decomp_table=None):
       node.target = lambda self, size: torch.reshape(self.contiguous(), size)
 
   return exported_program.run_decompositions(decomp_table)
+
+
+def dummy_decomp_table():
+  """Build dummy decomp table for run_decompositions without any decompositions.
+
+  Compatible for torch<=2.5.
+
+  Returns:
+    Decomp table for ExportedProgram.run_decompositions.
+  """
+  return {
+      torch._ops.OperatorBase(): lambda: None,
+  }

ai_edge_torch/odml_torch/export.py CHANGED
@@ -238,6 +238,9 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
   def in_i32(x: int):
     return -2147483648 <= x <= 2147483647
 
+  def to_int32(x: torch.Tensor):
+    return torch.ops.aten._to_copy.default(x, dtype=torch.int32)
+
   def rewrite_arange(node: torch.fx.Node):
     tensor_meta = node.meta.get("tensor_meta", None)
     if not tensor_meta:
@@ -249,7 +252,7 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
     if not (in_i32(start) and in_i32(end)):
       return
     op = node.target
-    node.target = lambda *args, **kwargs: op(*args, **kwargs)
+    node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs))
 
   graph_module = exported_program.graph_module
   for node in graph_module.graph.nodes:
@@ -305,8 +308,9 @@ def exported_program_to_mlir(
 
   _convert_i64_to_i32(exported_program)
 
+  # No decompositions but just retracing/cananicalization.
   exported_program = _torch_future.safe_run_decompositions(
-      exported_program,
+      exported_program, _torch_future.dummy_decomp_table()
   )
 
   # Passes below mutate the exported program to a state not executable by torch.
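As a side note, the node.target rewrite above works by capturing the node's original callable and casting its result. The isolated, hedged sketch below shows only that wrapping idea, using an arange overload as an illustrative stand-in for the node targets that rewrite_arange handles; it is not the full pass.

import torch


def to_int32(x: torch.Tensor):
  return torch.ops.aten._to_copy.default(x, dtype=torch.int32)


op = torch.ops.aten.arange.start_step            # illustrative target
wrapped = lambda *args, **kwargs: to_int32(op(*args, **kwargs))
print(op(0, 8, 1).dtype)       # torch.int64 by default
print(wrapped(0, 8, 1).dtype)  # torch.int32 after the cast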

ai_edge_torch/odml_torch/lowerings/decomp.py CHANGED
@@ -55,6 +55,10 @@ def decompositions():
       ],
   )
 
+  # Override noop aten op decompositions for faster run_decompositions.
+  decompositions[torch.ops.aten.alias.default] = lambda x: x
+  decompositions[torch.ops.aten.detach.default] = lambda x: x
+
   # Override _safe_softmax decompositions with regular softmax.
   # _safe_softmax introduces additional check-select ops to guard extreme
   # input values to softmax, which could make the converted model inefficient
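For context, a hedged, standalone sketch of the same kind of table entry built with torch's own decomposition helpers (torch._decomp is an internal but commonly used module); the `decompositions` dict in the hunk above is the package's own table, not the one built here.

import torch
from torch._decomp import get_decompositions

# A decomposition table maps ATen overloads to replacement callables; mapping
# the no-op ops to identity keeps run_decompositions() from expanding them.
table = get_decompositions([torch.ops.aten.alias, torch.ops.aten.detach])
table[torch.ops.aten.alias.default] = lambda x: x
table[torch.ops.aten.detach.default] = lambda x: x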

ai_edge_torch/version.py CHANGED