ai-edge-torch-nightly 0.3.0.dev20250108__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl

Files changed (22)
  1. ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
  2. ai_edge_torch/generative/examples/llama/llama.py +29 -25
  3. ai_edge_torch/generative/examples/phi/phi3.py +26 -23
  4. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  5. ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
  6. ai_edge_torch/generative/examples/smollm/verify.py +18 -2
  7. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  8. ai_edge_torch/generative/layers/attention.py +4 -29
  9. ai_edge_torch/generative/layers/model_config.py +6 -2
  10. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
  11. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  12. ai_edge_torch/generative/utilities/model_builder.py +16 -12
  13. ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
  14. ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
  15. ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
  16. ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +22 -21
  20. {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
  21. {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
  22. {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,13 +15,14 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
    )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -140,29 +136,51 @@ class Gemma2(nn.Module):
           f"Cannot forward sequence of length {seq_len}, max seq length is only"
           f" {self.config.max_seq_len}"
       )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [
+        self.get_attention_mask(
+            self.config.block_config(i).attn_config.attn_type, input_pos
+        )
+        for i in range(self.config.num_layers)
+    ]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
 
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
@@ -228,11 +246,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +269,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
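
Note: with this change Gemma2 no longer precomputes a RoPE cache in __init__; the cos/sin tensors are built from input_pos on every forward call via rotary_pos_emb.build_rope. A minimal sketch of the new call path, assuming the nightly wheel is installed (the sizes below are illustrative, not Gemma2's real dimensions):

import torch
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

input_pos = torch.arange(0, 16)  # positions of the tokens being processed
head_dim = 256                   # illustrative attention head dimension
n_elem = int(1.0 * head_dim)     # rotary_percentage * head_dim

# cos/sin are computed only for the requested positions, so no
# kv_cache_max-sized cache has to live inside the exported model.
cos, sin = rotary_pos_emb.build_rope(input_pos, n_elem, head_dim, 10_000)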
ai_edge_torch/generative/examples/llama/llama.py CHANGED
@@ -15,6 +15,7 @@
 
 """Example of building Llama 3.2 models."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -26,8 +27,8 @@ TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def _build_llama3_rope_cache(
-    size: int,
-    dim: int,
+    input_pos: torch.Tensor,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
     factor: float,
     low_freq_factor: float,
     high_freq_factor: float,
     max_seq_len: int,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Llama 3.2 model.
+  """Computes Rotary Positional Embeddings for Llama 3.2 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
   https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value.
-    condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed.
-    dtype (torch.dtype, optional): Output tensor's data type.
-    device (torch.device, optional): Output tensor's data type.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
+    base (int): Rope base value.
+    condense_ratio (int): The ratio by which sequence indicies are condensed.
+    dtype (torch.dtype): Output tensor's data type.
+    device (torch.device): Output tensor's data type.
     factor (float): Factor to scale theta down for tokens in long range in the
       sequence.
     low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@ def _build_llama3_rope_cache(
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   low_freq_wavelen = max_seq_len / low_freq_factor
   high_freq_wavelen = max_seq_len / high_freq_factor
   wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@ def _build_llama3_rope_cache(
   is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
   theta = torch.where(is_medium, smoothed_theta, theta)
 
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_llama3_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        factor=32.0,
-        low_freq_factor=1.0,
-        high_freq_factor=4.0,
-        max_seq_len=self.config.max_seq_len,
-    )
 
 
 def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,15 +129,30 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+
+  max_seq_len = 8192
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_llama3_rope_cache,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      factor=32.0,
+      low_freq_factor=1.0,
+      high_freq_factor=4.0,
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=128256,
       num_layers=16,
-      max_seq_len=8192,
+      max_seq_len=max_seq_len,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
ai_edge_torch/generative/examples/phi/phi3.py CHANGED
@@ -15,6 +15,7 @@
 
 """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -93,40 +94,41 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def _build_rope_cache(
-    size: int,
-    dim: int,
+def _build_phi3_rope(
+    input_pos: int,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
+  """Computes Rotary Positional Embeddings for Phi-3.5 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
   Cos values with scaling factors for quick lookup during the inference.
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
     base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
      condensed.
    dtype (torch.dtype, optional): Output tensor's data type.
    device (torch.device, optional): Output tensor's data type.
-    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
     scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   theta = theta / theta_factors
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
-        scale=math.sqrt(
-            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
-        ),
-    )
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi3_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=32064,
       num_layers=32,
-      max_seq_len=4096,
+      max_seq_len=max_seq_len,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py ADDED
@@ -0,0 +1,71 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmolLM2 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = smollm.build_model_v2(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
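
For reference, the new script reduces to the following programmatic use of the converter (a condensed sketch of the flags' default behavior; the checkpoint path is a placeholder):

import os
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

pytorch_model = smollm.build_model_v2(
    '/path/to/llm_data/smollm2', kv_cache_max_len=1280
)
converter.convert_to_tflite(
    pytorch_model,
    tflite_path=os.path.join('/tmp/', 'smollm2_q8_ekv1280.tflite'),
    prefill_seq_len=(8, 64, 128, 256, 512, 1024),
    quantize=True,
    export_config=ExportConfig(),
)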
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -85,3 +85,41 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
   )
+
+
+class SmolLM2(model_builder.DecoderOnlyModel):
+  """A SmolLM2 model built from the Edge Generative API layers."""
+  pass
+
+
+def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM2 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM2 model.
+  """
+  config = get_model_config(kv_cache_max_len)
+  config.block_config(0).attn_config.rotary_base = 100000
+  return config
+
+
+def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config_v2(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_v2(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=SmolLM2,
+  )
ai_edge_torch/generative/examples/smollm/verify.py CHANGED
@@ -36,10 +36,26 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     30,
     "The maximum size of the generated tokens.",
 )
+_MODEL_VERSION = flags.DEFINE_enum(
+    "model_version",
+    "v1",
+    ["v1", "v2"],
+    "The version of SmolLm to verify.",
+)
+_CHECKPOINT = {
+    "v1": "HuggingFaceTB/SmolLM-135M",
+    "v2": "HuggingFaceTB/SmolLM2-135M",
+}
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}
 
 
 def main(_):
-  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
+  builder = _BUILDER[_MODEL_VERSION.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
@@ -49,7 +65,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = smollm.build_model(reauthored_checkpoint)
+  reauthored_model = builder(reauthored_checkpoint)
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -27,33 +27,6 @@ import torch
 from torch import nn
 
 
-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotarty positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):
 
   def __init__(
@@ -252,7 +225,8 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -404,7 +378,8 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -17,8 +17,8 @@
 
 import dataclasses
 import enum
-from typing import Optional, Sequence, Union
-
+from typing import Callable, Optional, Sequence, Union
+from ai_edge_torch.generative.layers import rotary_position_embedding
 
 @enum.unique
 class ActivationType(enum.Enum):
@@ -218,6 +218,10 @@ class ModelConfig:
   # Softcap on the model output logits.
   final_logit_softcap: Optional[float] = None
 
+  # The function to call to create the RoPE sin and cos vectors during the
+  # forward pass. Defaults to a standard implementation.
+  build_rope: Callable = rotary_position_embedding.build_rope
+
   @property
   def kv_cache_max(self) -> int:
     if self.kv_cache_max_len > 0:
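
The new build_rope field makes RoPE construction a pluggable hook: llama.py and phi3.py above bind their custom builders with functools.partial, and the builders' trailing **kwargs absorbs any arguments they do not use (DecoderOnlyModel passes input_pos, n_elem, base, and head_dim as keywords). A sketch of a config-compatible custom builder; my_build_rope is hypothetical, shown only to illustrate the expected signature:

from typing import Tuple
import torch

def my_build_rope(
    input_pos: torch.Tensor,
    n_elem: int,
    base: int = 10_000,
    **kwargs,  # absorbs head_dim and any other unused keywords
) -> Tuple[torch.Tensor, torch.Tensor]:
  # Standard RoPE angles for the requested positions only.
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  idx_theta = torch.outer(input_pos.float(), theta)
  return torch.cos(idx_theta), torch.sin(idx_theta)

# Then: config = cfg.ModelConfig(..., build_rope=my_build_rope)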
ai_edge_torch/generative/layers/rotary_position_embedding.py CHANGED
@@ -32,57 +32,63 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
-  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
-  rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
-  roped = (x * cos) + (rotated * sin)
+  x1, x2 = torch.split(x, head_size // 2, dim=-1)
+  left = x1 * cos - x2 * sin
+  right = x2 * cos + x1 * sin
+  roped = torch.cat([left, right], dim=-1)
   return roped.transpose(1, 2).type_as(x)
 
 
-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
+def build_rope(
     input_pos: torch.Tensor,
     n_elem: int,
+    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
+  """Computes rotary positional embedding cosine and sine tensors.
 
   Args:
-    q: the query tensor.
-    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
+    base: the base of the exponentiated value for RoPE.
 
   Returns:
-    output the RoPE'd query and key.
+    cos, sin tensors
   """
 
   if n_elem <= 0:
-    return q, k
+    return None, None
 
-  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      q.shape[-1] // 2, dtype=torch.float32
+      head_dim // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians).type_as(q)
-  sin = torch.sin(radians).type_as(q)
+  cos = torch.cos(radians)
+  sin = torch.sin(radians)
+  return cos, sin
+
 
-  def apply(x, sin, cos):
-    x = x.transpose(1, 2)
-    b, h, s, d = x.shape
-    ans = torch.split(x, d // 2, dim=-1)
-    x1, x2 = ans
-    left = x1 * cos - x2 * sin
-    right = x2 * cos + x1 * sin
-    res = torch.cat([left, right], dim=-1)
-    res = res.transpose(1, 2)
-    return res
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Computes rotary positional embedding inline for a query and key.
+
+  Args:
+    q: the query tensor.
+    k: the key tensor.
+    cos: the cosine tensor.
+    sin: the sine tensor.
+
+  Returns:
+    output the RoPE'd query and key.
+  """
 
-  q_roped = apply(q, sin, cos)
-  k_roped = apply(k, sin, cos)
+  q_roped = apply_rope(q, cos, sin)
+  k_roped = apply_rope(k, cos, sin)
   return q_roped, k_roped
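
The apply_rope rewrite replaces the rotate-half formulation with a direct rotation of the two halves; the two forms are algebraically identical when the old form's cos/sin are tiled to the full head size. A quick standalone sanity check (not part of the package):

import torch

def rotate_half_rope(x, cos, sin):
  # old formulation: cos/sin repeated to the full head size
  x1, x2 = x.chunk(2, dim=-1)
  rotated = torch.cat((-x2, x1), dim=-1)
  return x * torch.cat((cos, cos), dim=-1) + rotated * torch.cat((sin, sin), dim=-1)

def split_rope(x, cos, sin):
  # new formulation from the diff above: work on the halves directly
  x1, x2 = x.chunk(2, dim=-1)
  return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

x = torch.randn(1, 4, 8, 16)   # (B, nh, T, head_size)
angles = torch.rand(8, 8)      # (T, head_size // 2)
cos, sin = torch.cos(angles), torch.sin(angles)
assert torch.allclose(rotate_half_rope(x, cos, sin), split_rope(x, cos, sin), atol=1e-6)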
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -150,6 +150,16 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
+
+  def test_smollm2(self):
+    config = smollm.get_fake_model_config_v2()
+    pytorch_model = smollm.SmolLM2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -25,6 +25,7 @@ from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -87,13 +88,6 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -116,8 +110,18 @@
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
+
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = self.config.build_rope(
+        input_pos=input_pos,
+        n_elem=n_elem,
+        base=attn_config.rotary_base,
+        head_dim=attn_config.head_dim,
+        # input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
+    )
+
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
@@ -145,14 +149,14 @@
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       lora_adapter = lora.adapters[i] if lora else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
ai_edge_torch/hlfb/mark_pattern/__init__.py CHANGED
@@ -17,7 +17,7 @@ from typing import Any
 import uuid
 
 from ai_edge_torch import lowertools
-from ai_edge_torch.hlfb.mark_pattern import passes
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
 
@@ -87,7 +87,7 @@ def mark_pattern(
       m.meta["ORIGINAL_NODE"] = n
 
   # Sanitize graph_module to match in the same way as pattern's graph_module.
-  graph_module_to_match = passes.remove_clone_ops(graph_module_to_match)
+  graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)
 
   match_with_attrs = pattern.match(graph_module_to_match)
 
@@ -111,13 +111,25 @@ def mark_pattern(
         is_input=True,
     )
 
-    # Only replace input by the marker node for those nodes used in the pattern.
+    # Only replace input by the marker node for those nodes used in the
+    # pattern.
     in_pattern_nodes = set(match.nodes_map.values())
     for user in input_node.users.keys():
-      if user in in_pattern_nodes:
-        user.meta["ORIGINAL_NODE"].replace_input_with(
-            input_node.meta["ORIGINAL_NODE"], new_input_node
-        )
+      if user not in in_pattern_nodes:
+        continue
+
+      user.meta["ORIGINAL_NODE"].replace_input_with(
+          input_node.meta["ORIGINAL_NODE"], new_input_node
+      )
+      # Pattern matching graph sanitization may remove clone ops, which means
+      # the user's input in the original graph may be a clone op. When
+      # replacing the input with the marker node, we need to further try
+      # replacing the input of the clone op that connects to the user.
+      for original_user_input in user.meta["ORIGINAL_NODE"].all_input_nodes:
+        if fx_utils.is_clone_op(original_user_input):
+          original_user_input.replace_input_with(
+              input_node.meta["ORIGINAL_NODE"], new_input_node
+          )
 
   for i, pattern_output_node in enumerate(pattern.output_nodes):
     output_node = match.nodes_map[pattern_output_node]
ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} RENAMED
@@ -12,11 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Passes to clean up the model graph for pattern matching."""
+"""FX graph utilities for pattern matching clean ups."""
 
 import torch
 
 
+def is_clone_op(node: torch.fx.Node) -> bool:
+  """Checks if the node is a clone op."""
+  return (
+      node.op == "call_function" and node.target == torch.ops.aten.clone.default
+  )
+
+
 def remove_clone_ops(gm: torch.fx.GraphModule):
   """Removes clone ops from the graph.
 
@@ -32,7 +39,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
     The graph module with clone ops removed.
   """
   for node in gm.graph.nodes:
-    if node.op == "call_function" and node.name.startswith("clone"):
+    if is_clone_op(node):
       node.replace_all_uses_with(node.args[0])
       gm.graph.erase_node(node)
 
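Matching on node.target rather than the old node.name.startswith("clone") check makes the predicate robust to FX node renaming (clone_1, clone_2, or user-renamed nodes). A small standalone illustration of the property the helper keys on; tracing aten ops with symbolic_trace as done here is an assumption for demonstration purposes:

import torch

class CloneModel(torch.nn.Module):
  def forward(self, x):
    return torch.ops.aten.clone.default(x * x) + x

gm = torch.fx.symbolic_trace(CloneModel())
clone_nodes = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target == torch.ops.aten.clone.default
]
assert len(clone_nodes) == 1  # found by target, regardless of the node's name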
ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED
@@ -18,13 +18,14 @@ import dataclasses
 from typing import Any, Callable, Optional, Union
 
 from ai_edge_torch import fx_pass_base
-from ai_edge_torch.hlfb.mark_pattern import passes
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 import torch
-from torch.export.graph_signature import TensorArgument
-from torch.fx import Graph
-from torch.fx import GraphModule
-from torch.fx.passes.utils.matcher_utils import InternalMatch
-from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
+
+Graph = torch.fx.Graph
+GraphModule = torch.fx.GraphModule
+TensorArgument = torch.export.graph_signature.TensorArgument
+InternalMatch = torch.fx.passes.utils.matcher_utils.InternalMatch
+SubgraphMatcher = torch.fx.passes.utils.matcher_utils.SubgraphMatcher
 
 
 def _are_equal(x: Any, y: Any) -> bool:
@@ -219,8 +220,8 @@ class Pattern:
     # Sanitize graph_module for more precise pattern matching.
     # The graph_module to match against this pattern should apply equivalent
     # sanitization.
-    self.graph_module = passes.remove_clone_ops(self.graph_module)
-    self.graph_module = passes.remove_dangling_args(self.graph_module)
+    self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
+    self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
 
     # Builds list of ordered input and output nodes.
     self.graph_nodes_map = {}
ai_edge_torch/hlfb/test/test_mark_pattern.py CHANGED
@@ -58,6 +58,32 @@ class TestMarkPattern(googletest.TestCase):
         {"stablehlo.custom_call @mark_tensor": 6},
     )
 
+  def test_mark_pattern_with_clone_inputs(self):
+
+    class TestModel(torch.nn.Module):
+
+      def forward(self, x):
+        return torch.ops.aten.clone.default(x * x) + x
+
+    pattern = pattern_module.Pattern(
+        "test.add",
+        lambda a, b: a + b,
+        export_args=(torch.rand(2, 2), torch.rand(2, 2)),
+    )
+
+    model = TestModel().eval()
+    args = (torch.rand(20, 20),)
+    exported_program = torch.export.export(model, args)
+    mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+    mlir = _export_stablehlo_mlir(exported_program)
+
+    lowertools.assert_string_count(
+        self,
+        mlir,
+        {'stablehlo.composite "test.add"': 1},
+        {"stablehlo.custom_call @mark_tensor": 3},
+    )
+
   def test_mark_pattern_with_attr_builder(self):
     class TestModel(torch.nn.Module):
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20250108"
+__version__ = "0.3.0.dev20250109"
{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250108
+Version: 0.3.0.dev20250109
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=NOhiFx3WkuX_tsxWzAZcCmPr0n5wuIu79KHGbDtrbb8,706
+ai_edge_torch/version.py,sha256=kM89dmK5VqznvQQJTvtq94oCbRtajNvkLPCCWSJxFSY,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -47,13 +47,13 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=whQ6DEnmhmj9hd5OyaoEI-FUNJ4m302vY3Swo_IqQcA,9285
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=pXilP6DHqVdcFH1TpIAtcwAQZH2_jZ6Tz41ddlXZXMs,10177
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=tMSsqg7LU3LR-PHtKvlWtLCqlk71mfcO9hANU4vnvDM,2734
-ai_edge_torch/generative/examples/llama/llama.py,sha256=BMjpdw6oOXmtqXCAfW9o7Iewaj-Hxd57xVrvSLBuHTk,6656
+ai_edge_torch/generative/examples/llama/llama.py,sha256=kWy6-V4bFtE1yguCROLJS5XB0GOJD1-acJWp2dFjB5Q,6606
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
@@ -76,7 +76,7 @@ ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=P2K6G7bNespSJLk72qxuCLaCcR_xAPs0Mn1dBZoByhE,2518
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=g-MvEibJT_iIhkec2VGtFFA_iP54VCq9mY4KxwAYF08,2512
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=c6PYCky7yJn6MVIYOCTx8S_CH27kOPmJbRZcI95nbZs,3477
-ai_edge_torch/generative/examples/phi/phi3.py,sha256=7Y1E4XpRuZOiSbeZJ-C2uJjmlnDtWv6L0XvPRE8oEQs,7112
+ai_edge_torch/generative/examples/phi/phi3.py,sha256=SHvJjmi5eIch5cYIWORt6YFmSQx_oCiOk1UbKKGibtk,7119
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -85,8 +85,9 @@ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-L
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
-ai_edge_torch/generative/examples/smollm/smollm.py,sha256=kk3cB_qaCzbFOhHtJlLb7qvSEBQTsILnoAcSFE3AkpE,2711
-ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
+ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=3uUltb6D3Q1aHpndcYTJrsWM_RBwLAraKDniH8ZZous,3779
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
@@ -109,7 +110,7 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=WMl1iuCE8So9FDnxPV0OTMzuPngQUTO61g8rfnBLyB4,4664
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=C9dzJFK3TybxKpM1vSdLjOKftkJ72DGjr8YR4H7vCe8,4664
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pgvrUtaTboT1xuDBGjpKOM85aqtaB_hYfSBuEk,2544
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
@@ -117,15 +118,15 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=03YlYLYCD8kxkxlGwRcmw4rFEA2bI8BP6_o5gflnaXQ,14522
+ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
+ai_edge_torch/generative/layers/model_config.py,sha256=9yPEmWNw3-_2wXBmPmZ7RUKcPXHF2ZbJwksyQoXTA6M,7784
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
-ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
+ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=1L1MEGPYbDELi0zy2OKl7yXyk9FXdBjcXwRZbfiJriU,2619
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -144,25 +145,25 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NctnggTSFh0XEQbTu55diZ35rFD2QIARO-8PzLktRWg,12165
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bBcey-aD4L_TwKRrrM81bN2VQoJjPPC84Rv4o3WOc34,12491
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=MY8BK29yD-W4v45Xdl_ErbNilipsTlD-4-y9MyBxR5g,7620
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=3XhB3fJXXIJEiqw9eqtIRY86lbED1BpjVGOdt7z5kpE,6611
+ai_edge_torch/generative/utilities/model_builder.py,sha256=yAO4VcYex21fDpuApewr0cNqgmxJljxonMd6450kblg,6710
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
 ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
-ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
-ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
-ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
+ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=-BYE7MGMxr-VfBy8tAiiOaCqYv8ytJ0w5l2P8B7h3eM,5387
+ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=taWLpF5IVglxlsF9HM2dIoKDXuQREaCRAXtJeG5gKzs,2073
+ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=7bv9XqRkm1pjxiVL4Cm1cArExnolId8hQKFHtvlkCI8,10061
 ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
+ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=-5UqJyk__1YbUNGuxi4b2sn0CED0W-G337AXwxPGdEs,5567
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
@@ -205,8 +206,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/METADATA,sha256=npbaQVRzcYi1A0B6ylSeKLzLGHmmcr6ELSuOU2vXo_0,1966
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250109.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250109.dist-info/METADATA,sha256=bkCouLqAI9GXCpiduHyj21ZElW42bdt0w6K5gWw1fOE,1966
+ai_edge_torch_nightly-0.3.0.dev20250109.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250109.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250109.dist-info/RECORD,,