ai-edge-torch-nightly 0.3.0.dev20250108__py3-none-any.whl → 0.3.0.dev20250110__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (27)
  1. ai_edge_torch/generative/examples/gemma/gemma2.py +54 -24
  2. ai_edge_torch/generative/examples/llama/llama.py +29 -25
  3. ai_edge_torch/generative/examples/paligemma/decoder.py +4 -2
  4. ai_edge_torch/generative/examples/paligemma/decoder2.py +16 -11
  5. ai_edge_torch/generative/examples/paligemma/paligemma.py +3 -0
  6. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -1
  7. ai_edge_torch/generative/examples/phi/phi3.py +26 -23
  8. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  9. ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
  10. ai_edge_torch/generative/examples/smollm/verify.py +18 -2
  11. ai_edge_torch/generative/examples/test_models/toy_model.py +16 -5
  12. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -5
  13. ai_edge_torch/generative/layers/attention.py +4 -29
  14. ai_edge_torch/generative/layers/model_config.py +6 -2
  15. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
  16. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  17. ai_edge_torch/generative/utilities/model_builder.py +20 -14
  18. ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
  19. ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
  20. ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
  21. ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
  22. ai_edge_torch/version.py +1 -1
  23. {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/METADATA +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/RECORD +27 -26
  25. {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/LICENSE +0 -0
  26. {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/WHEEL +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,13 +15,14 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -133,6 +129,7 @@ class Gemma2(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
       export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
@@ -140,29 +137,59 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [
+        self.get_attention_mask(
+            self.config.block_config(i).attn_config.attn_type, input_pos
+        )
+        for i in range(self.config.num_layers)
+    ]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
 
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
+    mask_input = mask is not None
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
+      mask = (
+          mask
+          if mask_input
+          else self.get_attention_mask(
+              block.config.attn_config.attn_type, input_pos
+          )
      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
@@ -228,11 +255,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +278,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
  config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
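Note on the Gemma2 hunks above: the cos/sin tables are no longer precomputed in __init__; forward() now derives them from input_pos via rotary_pos_emb.build_rope, builds one attention mask per layer, and hands off to the new _forward_with_embeds(), while the former x * (embedding_dim**0.5) scaling moves into config.embedding_scale. A minimal sketch of the resulting calling convention, assuming an already-built Gemma2 `model` and KV cache `kv` (both names are illustrative, not from the diff):

import torch

tokens = torch.tensor([[2, 651, 9456, 576]])  # (batch, seq_len)
input_pos = torch.arange(0, tokens.shape[1])  # positions for this step

# mask is now an optional keyword; when omitted, the model builds one
# causal or sliding-window mask per layer from input_pos internally.
output = model.forward(tokens, input_pos, kv)
output = model.forward(tokens, input_pos, kv, mask=None)  # equivalent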
ai_edge_torch/generative/examples/llama/llama.py CHANGED
@@ -15,6 +15,7 @@
 
 """Example of building Llama 3.2 models."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -26,8 +27,8 @@ TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def _build_llama3_rope_cache(
-    size: int,
-    dim: int,
+    input_pos: torch.Tensor,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
     factor: float,
     low_freq_factor: float,
     high_freq_factor: float,
     max_seq_len: int,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Llama 3.2 model.
+  """Computes Rotary Positional Embeddings for Llama 3.2 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
   https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value.
-    condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed.
-    dtype (torch.dtype, optional): Output tensor's data type.
-    device (torch.device, optional): Output tensor's data type.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
+    base (int): Rope base value.
+    condense_ratio (int): The ratio by which sequence indicies are condensed.
+    dtype (torch.dtype): Output tensor's data type.
+    device (torch.device): Output tensor's data type.
     factor (float): Factor to scale theta down for tokens in long range in the
       sequence.
     low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@ def _build_llama3_rope_cache(
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   low_freq_wavelen = max_seq_len / low_freq_factor
   high_freq_wavelen = max_seq_len / high_freq_factor
   wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@ def _build_llama3_rope_cache(
   is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
   theta = torch.where(is_medium, smoothed_theta, theta)
 
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_llama3_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        factor=32.0,
-        low_freq_factor=1.0,
-        high_freq_factor=4.0,
-        max_seq_len=self.config.max_seq_len,
-    )
 
 
 def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,15 +129,30 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+
+  max_seq_len = 8192
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_llama3_rope_cache,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      factor=32.0,
+      low_freq_factor=1.0,
+      high_freq_factor=4.0,
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=128256,
       num_layers=16,
-      max_seq_len=8192,
+      max_seq_len=max_seq_len,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
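The Llama hunks apply the same refactor through configuration: _build_llama3_rope_cache now takes input_pos directly (plus **kwargs so it tolerates extra keywords), and get_1b_model_config freezes the Llama-specific scaling arguments with functools.partial before passing the callable as cfg.ModelConfig(build_rope=...). A self-contained sketch of that partial-binding pattern, with an illustrative builder that is not the package's:

from functools import partial

import torch


def my_rope(input_pos, n_elem, base, head_dim, condense_ratio=1, **kwargs):
  # Standard RoPE angles; condense_ratio stands in for the model-specific
  # knobs that partial() freezes below.
  seq_idx = input_pos / condense_ratio
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  idx_theta = torch.outer(seq_idx, theta)
  return torch.cos(idx_theta), torch.sin(idx_theta)


build_rope = partial(my_rope, condense_ratio=2)  # freeze model-specific args
cos, sin = build_rope(torch.arange(4), n_elem=64, base=10_000, head_dim=64)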
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -54,6 +54,7 @@ class Decoder(model_builder.DecoderOnlyModel):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
+      mask: Optional[torch.Tensor] = None,
       export_config: Optional[model_builder.ExportConfig] = None,
       called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
@@ -73,8 +74,9 @@ class Decoder(model_builder.DecoderOnlyModel):
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
     embeds_len = input_embeds.shape[1]
-    mask = torch.zeros(embeds_len, self.config.kv_cache_max)
-    mask[:, embeds_len:] = float("-inf")
+    if mask is None:
+      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+      mask[:, embeds_len:] = float("-inf")
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache
ai_edge_torch/generative/examples/paligemma/decoder2.py CHANGED
@@ -57,6 +57,7 @@ class Decoder2(gemma2.Gemma2):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
+      mask: Optional[torch.Tensor] = None,
       export_config: Optional[model_builder.ExportConfig] = None,
       called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
@@ -73,17 +74,21 @@ class Decoder2(gemma2.Gemma2):
         repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
 
-    if called_by_generate:
-      # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
-      mask = [self.get_attention_mask(
-          self.config.block_config(i).attn_config.attn_type, input_pos
-      ) for i in range(self.config.num_layers)]
-    else:
-      # By default, don't mask image embeds with a diagonal causal mask.
-      embeds_len = input_embeds.shape[1]
-      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
-      mask[:, embeds_len:] = float("-inf")
-      mask = [mask] * self.config.num_layers
+    if mask is None:
+      if called_by_generate:
+        # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
+        mask = [
+            self.get_attention_mask(
+                self.config.block_config(i).attn_config.attn_type, input_pos
+            )
+            for i in range(self.config.num_layers)
+        ]
+      else:
+        # By default, don't mask image embeds with a diagonal causal mask.
+        embeds_len = input_embeds.shape[1]
+        mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+        mask[:, embeds_len:] = float("-inf")
+        mask = [mask] * self.config.num_layers
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
ai_edge_torch/generative/examples/paligemma/paligemma.py CHANGED
@@ -70,6 +70,7 @@ class PaliGemma(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
       pixel_values: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
       called_by_generate: bool = True,
@@ -79,6 +80,7 @@ class PaliGemma(nn.Module):
         tokens=tokens,
         input_pos=input_pos,
         kv_cache=kv_cache,
+        mask=mask,
         input_embeds=None,
         export_config=export_config,
         called_by_generate=called_by_generate,
@@ -111,6 +113,7 @@ class PaliGemma(nn.Module):
         tokens=None,
         input_pos=input_pos,
         kv_cache=kv_cache,
+        mask=mask,
         input_embeds=input_embeds,
         export_config=export_config,
         called_by_generate=called_by_generate,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py CHANGED
@@ -26,7 +26,7 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _OUTPUT_PATH = flags.DEFINE_string(
ai_edge_torch/generative/examples/phi/phi3.py CHANGED
@@ -15,6 +15,7 @@
 
 """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -93,40 +94,41 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def _build_rope_cache(
-    size: int,
-    dim: int,
+def _build_phi3_rope(
+    input_pos: int,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
+  """Computes Rotary Positional Embeddings for Phi-3.5 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
   Cos values with scaling factors for quick lookup during the inference.
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
     base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
       condensed.
     dtype (torch.dtype, optional): Output tensor's data type.
     device (torch.device, optional): Output tensor's data type.
-    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
     scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   theta = theta / theta_factors
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
-        scale=math.sqrt(
-            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
-        ),
-    )
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi3_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=32064,
       num_layers=32,
-      max_seq_len=4096,
+      max_seq_len=max_seq_len,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
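Phi-3.5 gets the identical treatment, and the added **kwargs parameter is what keeps such builders interchangeable: DecoderOnlyModel.forward (see model_builder.py below) always passes head_dim= to config.build_rope, so a builder that does not use that argument must still accept it. A tiny illustration of the convention (names here are illustrative, not from the package):

from functools import partial


def rope_builder(input_pos, n_elem, base, scale, **kwargs):
  # **kwargs silently absorbs head_dim and any future caller keyword.
  return input_pos, n_elem  # placeholder return, for illustration only


bound = partial(rope_builder, scale=2.0)
bound([0, 1, 2, 3], n_elem=64, base=10_000, head_dim=64)  # head_dim ignored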
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py ADDED
@@ -0,0 +1,71 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmolLM2 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = smollm.build_model_v2(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
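Given the flags defined above, a typical invocation of the new script would look something like the following (the checkpoint path shown is just the flag's default):

python convert_v2_to_tflite.py \
  --checkpoint_path="$HOME/Downloads/llm_data/smollm2" \
  --tflite_path=/tmp/ \
  --kv_cache_max_len=1280 \
  --quantize=true

which would write /tmp/smollm2_q8_ekv1280.tflite.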
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -85,3 +85,41 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
   )
+
+
+class SmolLM2(model_builder.DecoderOnlyModel):
+  """A SmolLM2 model built from the Edge Generative API layers."""
+  pass
+
+
+def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM2 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM2 model.
+  """
+  config = get_model_config(kv_cache_max_len)
+  config.block_config(0).attn_config.rotary_base = 100000
+  return config
+
+
+def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config_v2(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_v2(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=SmolLM2,
+  )
ai_edge_torch/generative/examples/smollm/verify.py CHANGED
@@ -36,10 +36,26 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     30,
     "The maximum size of the generated tokens.",
 )
+_MODEL_VERSION = flags.DEFINE_enum(
+    "model_version",
+    "v1",
+    ["v1", "v2"],
+    "The version of SmolLm to verify.",
+)
+_CHECKPOINT = {
+    "v1": "HuggingFaceTB/SmolLM-135M",
+    "v2": "HuggingFaceTB/SmolLM2-135M",
+}
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}
 
 
 def main(_):
-  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
+  builder = _BUILDER[_MODEL_VERSION.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
@@ -49,7 +65,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = smollm.build_model(reauthored_checkpoint)
+  reauthored_model = builder(reauthored_checkpoint)
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
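With these lookup tables in place, a single flag switches both the HuggingFace checkpoint and the reauthoring builder: --model_version=v2 verifies HuggingFaceTB/SmolLM2-135M against smollm.build_model_v2, while the default v1 preserves the previous behavior.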
ai_edge_torch/generative/examples/test_models/toy_model.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 # A toy example which has a single-layer transformer block.
-from typing import Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers.attention import TransformerBlock
@@ -52,14 +52,20 @@ class ToySingleLayerModel(torch.nn.Module):
     self.config = config
 
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+  def forward(
+      self,
+      idx: torch.Tensor,
+      input_pos: torch.Tensor,
+      mask: Optional[torch.Tensor] = None,
+  ) -> torch.Tensor:
     x = self.tok_embedding(idx)
     cos, sin = self.rope_cache
 
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.max_seq_len]
+    if mask is None:
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.max_seq_len]
 
     x = self.transformer_block(x, (cos, sin), mask, input_pos)
     x = self.final_norm(x)
@@ -98,7 +104,12 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
     self.config = config
 
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+  def forward(
+      self,
+      idx: torch.Tensor,
+      input_pos: torch.Tensor,
+      mask: Optional[torch.Tensor] = None,
+  ) -> torch.Tensor:
     x = self.tok_embedding(idx)
     cos, sin = self.rope_cache
 
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -63,23 +63,25 @@ class ToyModelWithKVCache(torch.nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
       export_config: Optional[ExportConfig] = None,
   ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
     x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.max_seq_len]
+    if mask is None:
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.max_seq_len]
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -27,33 +27,6 @@ import torch
 from torch import nn
 
 
-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotarty positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):
 
   def __init__(
@@ -252,7 +225,8 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -404,7 +378,8 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -17,8 +17,8 @@
 
 import dataclasses
 import enum
-from typing import Optional, Sequence, Union
-
+from typing import Callable, Optional, Sequence, Union
+from ai_edge_torch.generative.layers import rotary_position_embedding
 
 @enum.unique
 class ActivationType(enum.Enum):
@@ -218,6 +218,10 @@ class ModelConfig:
   # Softcap on the model output logits.
   final_logit_softcap: Optional[float] = None
 
+  # The function to call to create the RoPE sin and cos vectors during the
+  # forward pass. Defaults to a standard implementation.
+  build_rope: Callable = rotary_position_embedding.build_rope
+
   @property
   def kv_cache_max(self) -> int:
     if self.kv_cache_max_len > 0:
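The new build_rope field is the hook the model examples above plug into: it defaults to the standard rotary_position_embedding.build_rope and is invoked by DecoderOnlyModel.forward each step. A self-contained analog of the pattern, with illustrative names that are not from the package:

import dataclasses
from typing import Callable, Tuple

import torch


def default_build_rope(
    input_pos: torch.Tensor, n_elem: int, head_dim: int, base: int = 10_000
) -> Tuple[torch.Tensor, torch.Tensor]:
  # Standard RoPE: one frequency per pair of head-dimension elements.
  freq_exponents = (2.0 / n_elem) * torch.arange(head_dim // 2).float()
  timescale = float(base) ** freq_exponents
  radians = input_pos.unsqueeze(-1).float() / timescale
  return torch.cos(radians), torch.sin(radians)


@dataclasses.dataclass
class TinyConfig:
  build_rope: Callable = default_build_rope  # swap in a partial() to customize


cos, sin = TinyConfig().build_rope(torch.arange(8), 64, 64)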
ai_edge_torch/generative/layers/rotary_position_embedding.py CHANGED
@@ -32,57 +32,63 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
-  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
-  rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
-  roped = (x * cos) + (rotated * sin)
+  x1, x2 = torch.split(x, head_size // 2, dim=-1)
+  left = x1 * cos - x2 * sin
+  right = x2 * cos + x1 * sin
+  roped = torch.cat([left, right], dim=-1)
   return roped.transpose(1, 2).type_as(x)
 
 
-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
+def build_rope(
     input_pos: torch.Tensor,
     n_elem: int,
+    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
+  """Computes rotary positional embedding cosine and sine tensors.
 
   Args:
-    q: the query tensor.
-    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
+    base: the base of the exponentiated value for RoPE.
 
   Returns:
-    output the RoPE'd query and key.
+    cos, sin tensors
   """
 
   if n_elem <= 0:
-    return q, k
+    return None, None
 
-  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      q.shape[-1] // 2, dtype=torch.float32
+      head_dim // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians).type_as(q)
-  sin = torch.sin(radians).type_as(q)
+  cos = torch.cos(radians)
+  sin = torch.sin(radians)
+  return cos, sin
+
 
-  def apply(x, sin, cos):
-    x = x.transpose(1, 2)
-    b, h, s, d = x.shape
-    ans = torch.split(x, d // 2, dim=-1)
-    x1, x2 = ans
-    left = x1 * cos - x2 * sin
-    right = x2 * cos + x1 * sin
-    res = torch.cat([left, right], dim=-1)
-    res = res.transpose(1, 2)
-    return res
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Computes rotary positional embedding inline for a query and key.
+
+  Args:
+    q: the query tensor.
+    k: the key tensor.
+    cos: the cosine tensor.
+    sin: the sine tensor.
+
+  Returns:
+    output the RoPE'd query and key.
+  """
 
-  q_roped = apply(q, sin, cos)
-  k_roped = apply(k, sin, cos)
+  q_roped = apply_rope(q, cos, sin)
+  k_roped = apply_rope(k, cos, sin)
   return q_roped, k_roped
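A quick shape check of the reworked API in this file, assuming the import path used elsewhere in this diff; the sizes are illustrative:

import torch

import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

B, T, n_heads, head_dim = 1, 4, 2, 8
q = torch.randn(B, T, n_heads, head_dim)
k = torch.randn(B, T, n_heads, head_dim)

# cos and sin come out with shape (1, T, head_dim // 2) and broadcast
# against the transposed (B, n_heads, T, head_dim // 2) halves of q and k.
cos, sin = rotary_pos_emb.build_rope(torch.arange(T), n_elem=head_dim, head_dim=head_dim)
q_roped, k_roped = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
assert q_roped.shape == q.shape and k_roped.shape == k.shape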
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -150,6 +150,16 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
+
+  def test_smollm2(self):
+    config = smollm.get_fake_model_config_v2()
+    pytorch_model = smollm.SmolLM2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -25,6 +25,7 @@ from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -87,13 +88,6 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -105,6 +99,7 @@ class DecoderOnlyModel(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
       lora: Optional[lora_utils.LoRA] = None,
       export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
@@ -116,10 +111,21 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = self.config.build_rope(
+        input_pos=input_pos,
+        n_elem=n_elem,
+        base=attn_config.rotary_base,
+        head_dim=attn_config.head_dim,
+        # input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
+    )
+
+    if mask is None:
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.kv_cache_max]
 
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, lora, export_config
@@ -145,14 +151,14 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       lora_adapter = lora.adapters[i] if lora else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
ai_edge_torch/hlfb/mark_pattern/__init__.py CHANGED
@@ -17,7 +17,7 @@ from typing import Any
 import uuid
 
 from ai_edge_torch import lowertools
-from ai_edge_torch.hlfb.mark_pattern import passes
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
 
@@ -87,7 +87,7 @@ def mark_pattern(
       m.meta["ORIGINAL_NODE"] = n
 
   # Sanitize graph_module to match in the same way as pattern's graph_module.
-  graph_module_to_match = passes.remove_clone_ops(graph_module_to_match)
+  graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)
 
   match_with_attrs = pattern.match(graph_module_to_match)
 
@@ -111,13 +111,25 @@ def mark_pattern(
           is_input=True,
       )
 
-      # Only replace input by the marker node for those nodes used in the pattern.
+      # Only replace input by the marker node for those nodes used in the
+      # pattern.
      in_pattern_nodes = set(match.nodes_map.values())
       for user in input_node.users.keys():
-        if user in in_pattern_nodes:
-          user.meta["ORIGINAL_NODE"].replace_input_with(
-              input_node.meta["ORIGINAL_NODE"], new_input_node
-          )
+        if user not in in_pattern_nodes:
+          continue
+
+        user.meta["ORIGINAL_NODE"].replace_input_with(
+            input_node.meta["ORIGINAL_NODE"], new_input_node
+        )
+        # Pattern matching graph sanitization may remove clone ops, which means
+        # the user's input in the original graph may be a clone op. When
+        # replacing the input with the marker node, we need to further try
+        # replacing the input of the clone op that connects to the user.
+        for original_user_input in user.meta["ORIGINAL_NODE"].all_input_nodes:
+          if fx_utils.is_clone_op(original_user_input):
+            original_user_input.replace_input_with(
+                input_node.meta["ORIGINAL_NODE"], new_input_node
+            )
 
     for i, pattern_output_node in enumerate(pattern.output_nodes):
       output_node = match.nodes_map[pattern_output_node]
ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} RENAMED
@@ -12,11 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Passes to clean up the model graph for pattern matching."""
+"""FX graph utilities for pattern matching clean ups."""
 
 import torch
 
 
+def is_clone_op(node: torch.fx.Node) -> bool:
+  """Checks if the node is a clone op."""
+  return (
+      node.op == "call_function" and node.target == torch.ops.aten.clone.default
+  )
+
+
 def remove_clone_ops(gm: torch.fx.GraphModule):
   """Removes clone ops from the graph.
 
@@ -32,7 +39,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
     The graph module with clone ops removed.
   """
   for node in gm.graph.nodes:
-    if node.op == "call_function" and node.name.startswith("clone"):
+    if is_clone_op(node):
       node.replace_all_uses_with(node.args[0])
       gm.graph.erase_node(node)
 
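The rename from passes.py to fx_utils.py also exposes is_clone_op for reuse by mark_pattern above, and remove_clone_ops now matches on the node's target rather than its name. A small end-to-end check, assuming the new module path from this diff; the toy model is illustrative:

import torch

from ai_edge_torch.hlfb.mark_pattern import fx_utils


class CloneModel(torch.nn.Module):

  def forward(self, x):
    return torch.ops.aten.clone.default(x * x) + x


ep = torch.export.export(CloneModel().eval(), (torch.rand(4, 4),))
gm = ep.graph_module
clones = [n for n in gm.graph.nodes if fx_utils.is_clone_op(n)]  # matched by target
gm = fx_utils.remove_clone_ops(gm)  # rewires each clone's users to its input
assert not any(fx_utils.is_clone_op(n) for n in gm.graph.nodes)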
ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED
@@ -18,13 +18,14 @@ import dataclasses
 from typing import Any, Callable, Optional, Union
 
 from ai_edge_torch import fx_pass_base
-from ai_edge_torch.hlfb.mark_pattern import passes
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 import torch
-from torch.export.graph_signature import TensorArgument
-from torch.fx import Graph
-from torch.fx import GraphModule
-from torch.fx.passes.utils.matcher_utils import InternalMatch
-from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
+
+Graph = torch.fx.Graph
+GraphModule = torch.fx.GraphModule
+TensorArgument = torch.export.graph_signature.TensorArgument
+InternalMatch = torch.fx.passes.utils.matcher_utils.InternalMatch
+SubgraphMatcher = torch.fx.passes.utils.matcher_utils.SubgraphMatcher
 
 
 def _are_equal(x: Any, y: Any) -> bool:
@@ -219,8 +220,8 @@ class Pattern:
     # Sanitize graph_module for more precise pattern matching.
     # The graph_module to match against this pattern should apply equivalent
     # sanitization.
-    self.graph_module = passes.remove_clone_ops(self.graph_module)
-    self.graph_module = passes.remove_dangling_args(self.graph_module)
+    self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
+    self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
 
     # Builds list of ordered input and output nodes.
     self.graph_nodes_map = {}
ai_edge_torch/hlfb/test/test_mark_pattern.py CHANGED
@@ -58,6 +58,32 @@ class TestMarkPattern(googletest.TestCase):
         {"stablehlo.custom_call @mark_tensor": 6},
     )
 
+  def test_mark_pattern_with_clone_inputs(self):
+
+    class TestModel(torch.nn.Module):
+
+      def forward(self, x):
+        return torch.ops.aten.clone.default(x * x) + x
+
+    pattern = pattern_module.Pattern(
+        "test.add",
+        lambda a, b: a + b,
+        export_args=(torch.rand(2, 2), torch.rand(2, 2)),
+    )
+
+    model = TestModel().eval()
+    args = (torch.rand(20, 20),)
+    exported_program = torch.export.export(model, args)
+    mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+    mlir = _export_stablehlo_mlir(exported_program)
+
+    lowertools.assert_string_count(
+        self,
+        mlir,
+        {'stablehlo.composite "test.add"': 1},
+        {"stablehlo.custom_call @mark_tensor": 3},
+    )
+
   def test_mark_pattern_with_attr_builder(self):
     class TestModel(torch.nn.Module):
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20250108"
+__version__ = "0.3.0.dev20250110"
{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250108
+Version: 0.3.0.dev20250110
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=NOhiFx3WkuX_tsxWzAZcCmPr0n5wuIu79KHGbDtrbb8,706
+ai_edge_torch/version.py,sha256=VHqAyYw4u6BgyQ6v7Xp08Jqb0cnzIVGsulfnclxgY5c,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -47,13 +47,13 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=whQ6DEnmhmj9hd5OyaoEI-FUNJ4m302vY3Swo_IqQcA,9285
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=e9HfiHr4FkQZwVBYdDUZGzOjB5TqY2LqtVTHEzwVkQY,10428
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=tMSsqg7LU3LR-PHtKvlWtLCqlk71mfcO9hANU4vnvDM,2734
-ai_edge_torch/generative/examples/llama/llama.py,sha256=BMjpdw6oOXmtqXCAfW9o7Iewaj-Hxd57xVrvSLBuHTk,6656
+ai_edge_torch/generative/examples/llama/llama.py,sha256=kWy6-V4bFtE1yguCROLJS5XB0GOJD1-acJWp2dFjB5Q,6606
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
@@ -64,19 +64,19 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=scLsguzzuHfKYDWUd2uZkKYVRzdAbQHLd-kPam8QwvM,3004
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=amN96oBMTPolOFvGa47vG92AZ-BNLm8j0bBYd-IrMvI,5407
-ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=0V_CX0Pn5Fj_-koOGjc_Av2KMSAaVjAlD-G8P6FBGyY,6385
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=NJGhfPxVQjHDqea_lYGffjihOBdIYiXftiFTM6ccrwM,5475
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=L6F6KWHqxdnGQTOp9P3c8r_K1Uxet0ZCcbdvmjWtIos,6513
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=FwGlFHl9zktGDxnoOpEtbS6NYN5RyzcOXH7lvNUCwEU,6257
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CEMG9gh51ev1KXPew927a6nfampiXX9bL6m-25tNYN8,6340
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=P2K6G7bNespSJLk72qxuCLaCcR_xAPs0Mn1dBZoByhE,2518
+ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=CaI_-Vtd0j9FoWIDd8q5z4CFsGYUhTwEWGvMGaXICuU,2514
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=g-MvEibJT_iIhkec2VGtFFA_iP54VCq9mY4KxwAYF08,2512
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=c6PYCky7yJn6MVIYOCTx8S_CH27kOPmJbRZcI95nbZs,3477
-ai_edge_torch/generative/examples/phi/phi3.py,sha256=7Y1E4XpRuZOiSbeZJ-C2uJjmlnDtWv6L0XvPRE8oEQs,7112
+ai_edge_torch/generative/examples/phi/phi3.py,sha256=SHvJjmi5eIch5cYIWORt6YFmSQx_oCiOk1UbKKGibtk,7119
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -85,8 +85,9 @@ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-L
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
-ai_edge_torch/generative/examples/smollm/smollm.py,sha256=kk3cB_qaCzbFOhHtJlLb7qvSEBQTsILnoAcSFE3AkpE,2711
-ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
+ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=3uUltb6D3Q1aHpndcYTJrsWM_RBwLAraKDniH8ZZous,3779
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
@@ -108,8 +109,8 @@ ai_edge_torch/generative/examples/t5/t5.py,sha256=gFTmPi-xB8pcPRgoF3DJxvH_fT-KWT
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=WMl1iuCE8So9FDnxPV0OTMzuPngQUTO61g8rfnBLyB4,4664
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Crpj-vOwSViHpblXOrRJmsIn4DrHyuB3XZ8kHifb7LA,5203
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=Ab_N9xc-4DImA-Pvevr-nnnslBXScXVo4Pw7L3_OlhI,4732
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pgvrUtaTboT1xuDBGjpKOM85aqtaB_hYfSBuEk,2544
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
@@ -117,15 +118,15 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=03YlYLYCD8kxkxlGwRcmw4rFEA2bI8BP6_o5gflnaXQ,14522
+ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
+ai_edge_torch/generative/layers/model_config.py,sha256=9yPEmWNw3-_2wXBmPmZ7RUKcPXHF2ZbJwksyQoXTA6M,7784
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
-ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
+ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=1L1MEGPYbDELi0zy2OKl7yXyk9FXdBjcXwRZbfiJriU,2619
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -144,25 +145,25 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NctnggTSFh0XEQbTu55diZ35rFD2QIARO-8PzLktRWg,12165
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bBcey-aD4L_TwKRrrM81bN2VQoJjPPC84Rv4o3WOc34,12491
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=MY8BK29yD-W4v45Xdl_ErbNilipsTlD-4-y9MyBxR5g,7620
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=3XhB3fJXXIJEiqw9eqtIRY86lbED1BpjVGOdt7z5kpE,6611
+ai_edge_torch/generative/utilities/model_builder.py,sha256=6OBKyOmbg5Sap_np1wnajpCQ1fh8P0eONqNls9eHAX4,6778
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
 ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
-ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
-ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
-ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
+ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=-BYE7MGMxr-VfBy8tAiiOaCqYv8ytJ0w5l2P8B7h3eM,5387
+ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=taWLpF5IVglxlsF9HM2dIoKDXuQREaCRAXtJeG5gKzs,2073
+ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=7bv9XqRkm1pjxiVL4Cm1cArExnolId8hQKFHtvlkCI8,10061
 ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
+ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=-5UqJyk__1YbUNGuxi4b2sn0CED0W-G337AXwxPGdEs,5567
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
@@ -205,8 +206,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/METADATA,sha256=npbaQVRzcYi1A0B6ylSeKLzLGHmmcr6ELSuOU2vXo_0,1966
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250110.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250110.dist-info/METADATA,sha256=D_Vexo_GTTaYsb6IqB5rLrD-mos2YWze1Xcj3IFDgKE,1966
+ai_edge_torch_nightly-0.3.0.dev20250110.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250110.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250110.dist-info/RECORD,,