ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (32) hide show
  1. ai_edge_torch/_config.py +26 -9
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
  4. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  6. ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  8. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  9. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  10. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  12. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  13. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  16. ai_edge_torch/generative/layers/attention.py +70 -12
  17. ai_edge_torch/generative/layers/lora.py +557 -0
  18. ai_edge_torch/generative/layers/normalization.py +2 -50
  19. ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
  20. ai_edge_torch/generative/test/test_lora.py +147 -0
  21. ai_edge_torch/generative/utilities/converter.py +100 -47
  22. ai_edge_torch/generative/utilities/model_builder.py +21 -16
  23. ai_edge_torch/generative/utilities/verifier.py +4 -4
  24. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  25. ai_edge_torch/odml_torch/export.py +6 -2
  26. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
  30. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
@@ -32,64 +32,57 @@ def apply_rope(
32
32
  """
33
33
  x = x.transpose(1, 2)
34
34
  head_size = x.size(-1)
35
- x1, x2 = torch.split(x, head_size // 2, dim=-1)
36
- left = x1 * cos - x2 * sin
37
- right = x2 * cos + x1 * sin
38
- roped = torch.cat([left, right], dim=-1)
35
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
36
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
37
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
38
+ roped = (x * cos) + (rotated * sin)
39
39
  return roped.transpose(1, 2).type_as(x)
40
40
 
41
41
 
42
- def build_rope(
42
+ def apply_rope_inline(
43
+ q: torch.Tensor,
44
+ k: torch.Tensor,
43
45
  input_pos: torch.Tensor,
44
46
  n_elem: int,
45
- head_dim: int,
46
47
  base: int = 10_000,
47
48
  ) -> Tuple[torch.Tensor, torch.Tensor]:
48
- """Computes rotary positional embedding cosine and sine tensors.
49
+ """Computes rotary positional embedding inline for a query and key.
49
50
 
50
51
  Args:
52
+ q: the query tensor.
53
+ k: the key tensor.
51
54
  input_pos: the sequence indices for the query and key
52
55
  n_elem: number of elements of the head dimension for RoPE computation
53
- base: the base of the exponentiated value for RoPE.
54
56
 
55
57
  Returns:
56
- cos, sin tensors
58
+ output the RoPE'd query and key.
57
59
  """
58
60
 
59
61
  if n_elem <= 0:
60
- return None, None
62
+ return q, k
61
63
 
62
64
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
63
65
  freq_exponents = (2.0 / n_elem) * torch.arange(
64
- head_dim // 2, dtype=torch.float32
66
+ q.shape[-1] // 2, dtype=torch.float32
65
67
  )
66
68
  timescale = float(base) ** freq_exponents
67
69
  radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
68
70
  0
69
71
  ).unsqueeze(0)
70
- cos = torch.cos(radians)
71
- sin = torch.sin(radians)
72
- return cos, sin
73
-
72
+ cos = torch.cos(radians).type_as(q)
73
+ sin = torch.sin(radians).type_as(q)
74
74
 
75
- def apply_rope_inline(
76
- q: torch.Tensor,
77
- k: torch.Tensor,
78
- cos: torch.Tensor,
79
- sin: torch.Tensor,
80
- ) -> Tuple[torch.Tensor, torch.Tensor]:
81
- """Computes rotary positional embedding inline for a query and key.
82
-
83
- Args:
84
- q: the query tensor.
85
- k: the key tensor.
86
- cos: the cosine tensor.
87
- sin: the sine tensor.
88
-
89
- Returns:
90
- output the RoPE'd query and key.
91
- """
75
+ def apply(x, sin, cos):
76
+ x = x.transpose(1, 2)
77
+ b, h, s, d = x.shape
78
+ ans = torch.split(x, d // 2, dim=-1)
79
+ x1, x2 = ans
80
+ left = x1 * cos - x2 * sin
81
+ right = x2 * cos + x1 * sin
82
+ res = torch.cat([left, right], dim=-1)
83
+ res = res.transpose(1, 2)
84
+ return res
92
85
 
93
- q_roped = apply_rope(q, cos, sin)
94
- k_roped = apply_rope(k, cos, sin)
86
+ q_roped = apply(q, sin, cos)
87
+ k_roped = apply(k, sin, cos)
95
88
  return q_roped, k_roped
@@ -0,0 +1,147 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """A suite of tests to validate LoRA utilities."""
17
+
18
+ from ai_edge_torch.generative.layers import lora as lora_utils
19
+ import ai_edge_torch.generative.layers.model_config as cfg
20
+ import torch
21
+ from absl.testing import absltest as googletest
22
+ from tensorflow.python.platform import resource_loader # pylint: disable=g-direct-tensorflow-import
23
+
24
+
25
+ class TestLora(googletest.TestCase):
26
+ """Tests for LoRA utilities."""
27
+
28
+ def test_safetensors_builder(self):
29
+ """Converts a safetensors file to a LoRA module."""
30
+
31
+ tensor_names = lora_utils.LoRATensorNames(
32
+ attn_query_w_a=(
33
+ "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
34
+ ),
35
+ attn_query_w_b=(
36
+ "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight"
37
+ ),
38
+ attn_key_w_a=(
39
+ "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight"
40
+ ),
41
+ attn_key_w_b=(
42
+ "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight"
43
+ ),
44
+ attn_value_w_a=(
45
+ "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight"
46
+ ),
47
+ attn_value_w_b=(
48
+ "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight"
49
+ ),
50
+ attn_output_w_a=(
51
+ "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight"
52
+ ),
53
+ attn_output_w_b=(
54
+ "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight"
55
+ ),
56
+ )
57
+
58
+ safetensors_file = resource_loader.get_path_to_datafile(
59
+ "fixtures/test_lora_rank16.safetensors"
60
+ )
61
+ config = self._get_test_config(
62
+ num_layers=1,
63
+ head_dim=8,
64
+ num_query_groups=1,
65
+ kv_cache_max_len=16,
66
+ )
67
+ lora = lora_utils.LoRA.from_safetensors(
68
+ safetensors_file,
69
+ scale=1.0,
70
+ lora_tensor_names=tensor_names,
71
+ config=config,
72
+ )
73
+ self.assertEqual(lora.get_rank(), 16)
74
+
75
+ def test_torch_export(self):
76
+ """Tests the export of the LoRA module."""
77
+
78
+ class TestModel(torch.nn.Module):
79
+
80
+ def forward(self, x: torch.Tensor, lora: lora_utils.LoRA) -> torch.Tensor:
81
+ x += lora_utils.apply_lora(x, lora.adapters[0].attention.query)
82
+ return x
83
+
84
+ n = 1
85
+ head_dim = 2
86
+ num_query_groups = 1
87
+ key_length = 4
88
+ config = self._get_test_config(
89
+ num_layers=n,
90
+ head_dim=head_dim,
91
+ num_query_groups=num_query_groups,
92
+ kv_cache_max_len=key_length,
93
+ )
94
+ inputs = torch.zeros((n, 1, head_dim))
95
+ lora = lora_utils.LoRA.zeros(rank=16, config=config)
96
+ model = TestModel()
97
+ exported_program = torch.export.export(model, (inputs, lora))
98
+ input_specs = exported_program.graph_signature.input_specs
99
+ # 9 inputs: 1 for x, 2 for query lora, 2 for key lora, 2 for value lora,
100
+ # 2 for output lora.
101
+ self.assertLen(input_specs, 9)
102
+ self.assertEqual(input_specs[0].arg.name, "x")
103
+ self.assertEqual(input_specs[1].arg.name, "lora_atten_q_a_prime_weight_0")
104
+ self.assertEqual(input_specs[2].arg.name, "lora_atten_q_b_prime_weight_0")
105
+ self.assertEqual(input_specs[3].arg.name, "lora_atten_k_a_prime_weight_0")
106
+ self.assertEqual(input_specs[4].arg.name, "lora_atten_k_b_prime_weight_0")
107
+ self.assertEqual(input_specs[5].arg.name, "lora_atten_v_a_prime_weight_0")
108
+ self.assertEqual(input_specs[6].arg.name, "lora_atten_v_b_prime_weight_0")
109
+ self.assertEqual(input_specs[7].arg.name, "lora_atten_o_a_prime_weight_0")
110
+ self.assertEqual(input_specs[8].arg.name, "lora_atten_o_b_prime_weight_0")
111
+
112
+ def test_lora_tflite_serialization(self):
113
+ """Tests the serialization of the LoRA module."""
114
+ config = self._get_test_config(
115
+ num_layers=2,
116
+ head_dim=8,
117
+ num_query_groups=1,
118
+ kv_cache_max_len=16,
119
+ )
120
+ lora = lora_utils.LoRA.random(rank=16, config=config)
121
+ flatbuffer_model = lora.to_tflite()
122
+ recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
123
+ self.assertEqual(lora, recovered_lora)
124
+
125
+ def _get_test_config(
126
+ self, num_layers, head_dim, num_query_groups, kv_cache_max_len
127
+ ):
128
+ """Returns a test model config."""
129
+ attn_config = cfg.AttentionConfig(
130
+ num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
131
+ )
132
+ block_config = cfg.TransformerBlockConfig(
133
+ attn_config=attn_config, ff_config=None
134
+ )
135
+ config = cfg.ModelConfig(
136
+ kv_cache_max_len=kv_cache_max_len,
137
+ embedding_dim=head_dim,
138
+ block_configs=block_config,
139
+ num_layers=num_layers,
140
+ max_seq_len=None,
141
+ vocab_size=None,
142
+ )
143
+ return config
144
+
145
+
146
+ if __name__ == "__main__":
147
+ googletest.main()
@@ -15,16 +15,15 @@
15
15
 
16
16
  """Common utility functions for model conversion."""
17
17
 
18
- from functools import partial
19
- from typing import Any, Union
20
-
18
+ import os
19
+ from typing import Optional, Union
21
20
  from ai_edge_torch._convert import converter as converter_utils
21
+ from ai_edge_torch.generative.layers import lora as lora_utils
22
22
  import ai_edge_torch.generative.layers.kv_cache as kv_utils
23
23
  import ai_edge_torch.generative.layers.model_config as cfg
24
24
  from ai_edge_torch.generative.quantize import quant_recipes
25
25
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
26
26
  import torch
27
- import torch.nn as nn
28
27
 
29
28
 
30
29
  class ExportableModule(torch.nn.Module):
@@ -41,11 +40,13 @@ class ExportableModule(torch.nn.Module):
41
40
 
42
41
  def convert_to_tflite(
43
42
  pytorch_model: torch.nn.Module,
44
- tflite_path: str,
43
+ output_path: str,
44
+ output_name_prefix: str,
45
45
  prefill_seq_len: Union[int, list[int]],
46
46
  pixel_values_size: torch.Size = None,
47
47
  quantize: bool = True,
48
48
  config: cfg.ModelConfig = None,
49
+ lora_ranks: Optional[list[int]] = None,
49
50
  export_config: ExportConfig = None,
50
51
  ):
51
52
  """Converts a nn.Module model to multi-signature tflite model.
@@ -79,21 +80,65 @@ def convert_to_tflite(
79
80
 
80
81
  Args:
81
82
  pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
82
- tflite_path (str): The tflite file path to export.
83
- prefill_seq_len (Union[int, list[int]]): A list of prefill lengths to
84
- export.
83
+ output_path (str): The path to export the tflite model.
84
+ output_name_prefix (str): The prefix of the tflite model name.
85
+ prefill_seq_len (Union[int, list[int]]): The prefill sequence length to
86
+ use. If a list, the model will have multiple prefill signatures.
85
87
  pixel_values_size (torch.Size, optional): The size of pixel values to pass
86
88
  to the model. If None, the model is not expected to take pixel values.
87
89
  quantize (bool, optional): Whether the model should be quanized. Defaults
88
90
  to True.
89
91
  config (cfg.ModelConfig, optional): The model config used to configure KV
90
92
  cache. If None, it uses the config of the pytorch_model.
93
+ lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
94
+ no LoRA signatures will be added.
91
95
  """
96
+ # pylint: disable=protected-access
97
+ torch._dynamo.config.cache_size_limit = 64
98
+
99
+ config = config if config else pytorch_model.config
92
100
  prefill_seq_lens = (
93
101
  [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
94
102
  )
103
+ loras = [None]
104
+ if lora_ranks is not None:
105
+ for rank in lora_ranks:
106
+ lora = lora_utils.LoRA.zeros(rank, config)
107
+ loras.append(lora)
108
+
109
+ quant_suffix = 'q8' if quantize else 'f32'
110
+ kv_size = config.kv_cache_max_len
111
+ lora_suffix = (
112
+ '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
113
+ )
114
+ output_filename = (
115
+ f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
116
+ )
117
+ output_file = os.path.join(output_path, output_filename)
118
+
119
+ _export_helper(
120
+ pytorch_model,
121
+ output_file,
122
+ prefill_seq_lens,
123
+ pixel_values_size,
124
+ quantize,
125
+ config,
126
+ loras,
127
+ export_config,
128
+ )
95
129
 
96
- # Tensors used to trace the model graph during conversion.
130
+
131
+ def _export_helper(
132
+ pytorch_model: torch.nn.Module,
133
+ output_file: str,
134
+ prefill_seq_lens: list[int],
135
+ pixel_values_size: torch.Size,
136
+ quantize: bool,
137
+ config: cfg.ModelConfig,
138
+ loras: list[None | lora_utils.LoRA],
139
+ export_config: ExportConfig,
140
+ ):
141
+ """Helper function to export a model to tflite."""
97
142
  prefill_tokens_list = []
98
143
  prefill_input_pos_list = []
99
144
  for seq_len in prefill_seq_lens:
@@ -108,9 +153,7 @@ def convert_to_tflite(
108
153
 
109
154
  decode_token = torch.tensor([[0]], dtype=torch.int)
110
155
  decode_input_pos = torch.tensor([0], dtype=torch.int)
111
- kv = kv_utils.KVCache.from_model_config(
112
- config if config else pytorch_model.config
113
- )
156
+ kv = kv_utils.KVCache.from_model_config(config)
114
157
 
115
158
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
116
159
 
@@ -119,44 +162,54 @@ def convert_to_tflite(
119
162
  mod = ExportableModule(pytorch_model, export_config=export_config)
120
163
 
121
164
  converter = converter_utils.Converter()
122
- for i in range(len(prefill_seq_lens)):
123
- prefill_seq_len = prefill_seq_lens[i]
124
- prefill_tokens = prefill_tokens_list[i]
125
- prefill_input_pos = prefill_input_pos_list[i]
126
- if i == 0 and len(prefill_seq_lens) == 1:
127
- prefill_signature_name = 'prefill'
128
- else:
129
- prefill_signature_name = f'prefill_{prefill_seq_len}'
130
- converter.add_signature(
131
- prefill_signature_name,
132
- mod,
133
- sample_kwargs={
134
- 'tokens': prefill_tokens,
135
- 'input_pos': prefill_input_pos,
136
- 'kv_cache': kv,
137
- },
138
- )
139
- if prefill_pixel_values is not None:
165
+ for lora in loras:
166
+ for i in range(len(prefill_seq_lens)):
167
+ prefill_seq_len = prefill_seq_lens[i]
168
+ prefill_tokens = prefill_tokens_list[i]
169
+ prefill_input_pos = prefill_input_pos_list[i]
170
+ if i == 0 and len(prefill_seq_lens) == 1:
171
+ prefill_signature_name = 'prefill'
172
+ else:
173
+ prefill_signature_name = f'prefill_{prefill_seq_len}'
174
+
175
+ sample_kwargs = {
176
+ 'tokens': prefill_tokens,
177
+ 'input_pos': prefill_input_pos,
178
+ 'kv_cache': kv,
179
+ }
180
+ if lora is not None:
181
+ prefill_signature_name += f'_lora_r{lora.get_rank()}'
182
+ sample_kwargs['lora'] = lora
183
+
140
184
  converter.add_signature(
141
- prefill_signature_name + '_pixel',
185
+ prefill_signature_name,
142
186
  mod,
143
- sample_kwargs={
144
- 'tokens': prefill_tokens,
145
- 'input_pos': prefill_input_pos,
146
- 'kv_cache': kv,
147
- 'pixel_values': prefill_pixel_values,
148
- },
187
+ sample_kwargs=sample_kwargs,
149
188
  )
150
189
 
151
- converter.add_signature(
152
- 'decode',
153
- mod,
154
- sample_kwargs={
155
- 'tokens': decode_token,
156
- 'input_pos': decode_input_pos,
157
- 'kv_cache': kv,
158
- },
159
- )
190
+ if prefill_pixel_values is not None:
191
+ converter.add_signature(
192
+ prefill_signature_name + '_pixel',
193
+ mod,
194
+ sample_kwargs={
195
+ **sample_kwargs,
196
+ 'pixel_values': prefill_pixel_values,
197
+ },
198
+ )
199
+
200
+ sample_kwargs = {
201
+ 'tokens': decode_token,
202
+ 'input_pos': decode_input_pos,
203
+ 'kv_cache': kv,
204
+ }
205
+ if lora is not None:
206
+ sample_kwargs['lora'] = lora
207
+
208
+ converter.add_signature(
209
+ 'decode' if lora is None else f'decode_lora_r{lora.get_rank()}',
210
+ mod,
211
+ sample_kwargs=sample_kwargs,
212
+ )
160
213
 
161
214
  edge_model = converter.convert(quant_config=quant_config)
162
- edge_model.export(tflite_path)
215
+ edge_model.export(output_file)
@@ -22,13 +22,14 @@ from typing import Optional, Tuple
22
22
  from ai_edge_torch.generative.layers import attention
23
23
  from ai_edge_torch.generative.layers import builder
24
24
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
25
+ from ai_edge_torch.generative.layers import lora as lora_utils
25
26
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
26
27
  import ai_edge_torch.generative.layers.model_config as cfg
27
- import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
28
28
  import ai_edge_torch.generative.utilities.loader as loading_utils
29
29
  import torch
30
30
  from torch import nn
31
31
 
32
+
32
33
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
33
34
  ff_up_proj="model.layers.{}.mlp.up_proj",
34
35
  ff_down_proj="model.layers.{}.mlp.down_proj",
@@ -86,6 +87,13 @@ class DecoderOnlyModel(nn.Module):
86
87
  config.embedding_dim,
87
88
  config.final_norm_config,
88
89
  )
90
+ # ROPE parameters for all attn_configs are the same. Take the first one.
91
+ attn_config = config.block_config(0).attn_config
92
+ self.rope_cache = attn_utils.build_rope_cache(
93
+ size=config.kv_cache_max,
94
+ dim=int(attn_config.rotary_percentage * attn_config.head_dim),
95
+ base=attn_config.rotary_base,
96
+ )
89
97
  self.mask_cache = attn_utils.build_causal_mask_cache(
90
98
  size=config.kv_cache_max,
91
99
  )
@@ -97,6 +105,7 @@ class DecoderOnlyModel(nn.Module):
97
105
  tokens: torch.Tensor,
98
106
  input_pos: torch.Tensor,
99
107
  kv_cache: kv_utils.KVCache,
108
+ lora: Optional[lora_utils.LoRA] = None,
100
109
  export_config: Optional[ExportConfig] = None,
101
110
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
102
111
  _, seq_len = tokens.size()
@@ -107,28 +116,23 @@ class DecoderOnlyModel(nn.Module):
107
116
 
108
117
  # token embeddings of shape (b, t, n_embd)
109
118
  input_embeds = self.tok_embedding(tokens)
110
-
111
- # ROPE parameters for all attn_configs are the same. Take the first one.
112
- attn_config = self.config.block_config(0).attn_config
113
- n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
114
- rope = rotary_pos_emb.build_rope(
115
- input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
116
- )
117
-
119
+ cos, sin = self.rope_cache
120
+ rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
118
121
  mask = self.mask_cache.index_select(2, input_pos)
119
122
  mask = mask[:, :, :, : self.config.kv_cache_max]
120
123
 
121
- return self._forward_with_embeds(
122
- input_embeds, rope, mask, input_pos, kv_cache, export_config
124
+ return self.forward_with_embeds(
125
+ input_embeds, rope, mask, input_pos, kv_cache, lora, export_config
123
126
  )
124
127
 
125
- def _forward_with_embeds(
128
+ def forward_with_embeds(
126
129
  self,
127
130
  input_embeds: torch.Tensor,
128
131
  rope: Tuple[torch.Tensor, torch.Tensor],
129
132
  mask: torch.Tensor,
130
133
  input_pos: torch.Tensor,
131
134
  kv_cache: kv_utils.KVCache,
135
+ lora: Optional[lora_utils.LoRA] = None,
132
136
  export_config: Optional[ExportConfig] = None,
133
137
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
134
138
  """Forwards the model with input embeddings."""
@@ -141,13 +145,14 @@ class DecoderOnlyModel(nn.Module):
141
145
  if self.config.embedding_scale is not None:
142
146
  x = x * self.config.embedding_scale
143
147
 
144
- updated_kv_entries = []
148
+ updated_kv_entires = []
145
149
  for i, block in enumerate(self.transformer_blocks):
146
150
  kv_entry = kv_cache.caches[i] if kv_cache else None
147
- x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
151
+ lora_adapter = lora.adapters[i] if lora else None
152
+ x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter)
148
153
  if kv_entry:
149
- updated_kv_entries.append(kv_entry)
150
- updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
154
+ updated_kv_entires.append(kv_entry)
155
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
151
156
 
152
157
  if export_config is not None:
153
158
  if (
@@ -16,7 +16,7 @@
16
16
  """Common utility functions to verify the reauthored models."""
17
17
 
18
18
  import logging
19
- from typing import Any,List
19
+ from typing import Any, List, Optional
20
20
 
21
21
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
22
22
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -134,7 +134,7 @@ class ReauthoredModelWrapper(ModelWrapper):
134
134
  prompts: torch.Tensor,
135
135
  max_new_tokens: int,
136
136
  pixel_values: torch.Tensor = None,
137
- eos_token_id: int = 1,
137
+ eos_token_id: Optional[int] = None,
138
138
  ) -> torch.IntTensor:
139
139
  input_ids = prompts[0].int().tolist()
140
140
  tokens = torch.tensor([input_ids])
@@ -146,7 +146,7 @@ class ReauthoredModelWrapper(ModelWrapper):
146
146
  )
147
147
  generated_token = logits[0][-1].argmax().item()
148
148
  input_ids.append(generated_token)
149
- if generated_token == eos_token_id:
149
+ if eos_token_id is not None and generated_token == eos_token_id:
150
150
  break
151
151
  tokens = torch.tensor([[generated_token]])
152
152
  input_pos = torch.tensor([len(input_ids) - 1])
@@ -253,7 +253,7 @@ def verify_model_with_prompts(
253
253
  outputs_reauthored = reauthored_model.generate(
254
254
  prompt_tokens,
255
255
  max_new_tokens,
256
- eos_token_id=tokenizer.tokenizer.eos_token_id,
256
+ eos_token_id=getattr(tokenizer.tokenizer, "eos_token_id", None),
257
257
  )
258
258
  response_reauthored = tokenizer.decode(outputs_reauthored[0])
259
259
  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
@@ -73,3 +73,16 @@ def safe_run_decompositions(exported_program, decomp_table=None):
73
73
  node.target = lambda self, size: torch.reshape(self.contiguous(), size)
74
74
 
75
75
  return exported_program.run_decompositions(decomp_table)
76
+
77
+
78
+ def dummy_decomp_table():
79
+ """Build dummy decomp table for run_decompositions without any decompositions.
80
+
81
+ Compatible for torch<=2.5.
82
+
83
+ Returns:
84
+ Decomp table for ExportedProgram.run_decompositions.
85
+ """
86
+ return {
87
+ torch._ops.OperatorBase(): lambda: None,
88
+ }
@@ -238,6 +238,9 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
238
238
  def in_i32(x: int):
239
239
  return -2147483648 <= x <= 2147483647
240
240
 
241
+ def to_int32(x: torch.Tensor):
242
+ return torch.ops.aten._to_copy.default(x, dtype=torch.int32)
243
+
241
244
  def rewrite_arange(node: torch.fx.Node):
242
245
  tensor_meta = node.meta.get("tensor_meta", None)
243
246
  if not tensor_meta:
@@ -249,7 +252,7 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
249
252
  if not (in_i32(start) and in_i32(end)):
250
253
  return
251
254
  op = node.target
252
- node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
255
+ node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs))
253
256
 
254
257
  graph_module = exported_program.graph_module
255
258
  for node in graph_module.graph.nodes:
@@ -305,8 +308,9 @@ def exported_program_to_mlir(
305
308
 
306
309
  _convert_i64_to_i32(exported_program)
307
310
 
311
+ # No decompositions but just retracing/cananicalization.
308
312
  exported_program = _torch_future.safe_run_decompositions(
309
- exported_program, lowerings.decompositions()
313
+ exported_program, _torch_future.dummy_decomp_table()
310
314
  )
311
315
 
312
316
  # Passes below mutate the exported program to a state not executable by torch.
@@ -55,6 +55,10 @@ def decompositions():
55
55
  ],
56
56
  )
57
57
 
58
+ # Override noop aten op decompositions for faster run_decompositions.
59
+ decompositions[torch.ops.aten.alias.default] = lambda x: x
60
+ decompositions[torch.ops.aten.detach.default] = lambda x: x
61
+
58
62
  # Override _safe_softmax decompositions with regular softmax.
59
63
  # _safe_softmax introduces additional check-select ops to guard extreme
60
64
  # input values to softmax, which could make the converted model inefficient
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250105"
16
+ __version__ = "0.3.0.dev20250108"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250105
3
+ Version: 0.3.0.dev20250108
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI