ai-edge-torch-nightly 0.5.0.dev20250516__py3-none-any.whl → 0.5.0.dev20250517__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. ai_edge_torch/_convert/conversion.py +1 -0
  2. ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
  3. ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
  4. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +1 -4
  5. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -4
  6. ai_edge_torch/generative/examples/gemma/gemma1.py +1 -5
  7. ai_edge_torch/generative/examples/gemma/gemma2.py +1 -5
  8. ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
  9. ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
  10. ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
  11. ai_edge_torch/generative/examples/hammer/hammer.py +1 -4
  12. ai_edge_torch/generative/examples/llama/llama.py +1 -4
  13. ai_edge_torch/generative/examples/openelm/openelm.py +1 -2
  14. ai_edge_torch/generative/examples/paligemma/decoder.py +1 -5
  15. ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -5
  16. ai_edge_torch/generative/examples/paligemma/image_encoder.py +1 -4
  17. ai_edge_torch/generative/examples/phi/phi2.py +1 -4
  18. ai_edge_torch/generative/examples/phi/phi3.py +1 -4
  19. ai_edge_torch/generative/examples/phi/phi4.py +1 -4
  20. ai_edge_torch/generative/examples/qwen/qwen.py +1 -4
  21. ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
  22. ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +1 -3
  23. ai_edge_torch/generative/examples/smollm/smollm.py +1 -4
  24. ai_edge_torch/generative/examples/stable_diffusion/clip.py +6 -4
  25. ai_edge_torch/generative/examples/t5/t5.py +1 -3
  26. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  27. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -4
  29. ai_edge_torch/generative/layers/model_config.py +2 -2
  30. ai_edge_torch/generative/utilities/converter.py +2 -1
  31. ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
  32. ai_edge_torch/version.py +1 -1
  33. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/METADATA +1 -1
  34. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/RECORD +37 -36
  35. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/LICENSE +0 -0
  36. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/WHEEL +0 -0
  37. {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/top_level.txt +0 -0
@@ -38,6 +38,7 @@ def _run_convert_passes(
38
38
  )
39
39
 
40
40
  passes = [
41
+ fx_passes.EliminateDeadCodePass(),
41
42
  fx_passes.OptimizeLayoutTransposesPass(),
42
43
  fx_passes.CanonicalizePass(),
43
44
  fx_passes.BuildAtenCompositePass(),
@@ -17,6 +17,7 @@ from typing import Sequence, Union
17
17
 
18
18
  from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
19
19
  from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
20
+ from ai_edge_torch._convert.fx_passes.eliminate_dead_code_pass import EliminateDeadCodePass
20
21
  from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
21
22
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
22
23
  from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
@@ -0,0 +1,40 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Pass to eliminate dead code for ai-edge-torch conversion."""
16
+
17
+
18
+ from ai_edge_torch import fx_infra
19
+ import torch
20
+
21
+
22
+ class EliminateDeadCodePass(fx_infra.PassBase):
23
+ """Eliminates dead code with dedicated rules for ai-edge-torch conversion."""
24
+
25
+ def call(self, graph_module: torch.fx.GraphModule):
26
+ def is_impure_node(node: torch.fx.Node):
27
+ # Starting from torch 2.7.0, random torch ops with
28
+ # _nondeterministic_seeded set are no longer considered pure. However,
29
+ # for conversion, unused random ops/tensors should still be removed.
30
+ if getattr(node.target, "_nondeterministic_seeded", False):
31
+ return False
32
+ return node.is_impure()
33
+
34
+ try:
35
+ graph_module.graph.eliminate_dead_code(is_impure_node)
36
+ except TypeError:
37
+ # eliminate_dead_code has no is_impure_node input in old torch versions.
38
+ pass
39
+
40
+ return fx_infra.PassResult(graph_module, True)
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
51
51
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
52
52
  intermediate_size=2048,
53
53
  )
54
- norm_config = cfg.NormalizationConfig(
55
- type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
56
- )
54
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
57
55
  block_config = cfg.TransformerBlockConfig(
58
56
  attn_config=attn_config,
59
57
  ff_config=ff_config,
@@ -69,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
69
67
  block_configs=block_config,
70
68
  final_norm_config=norm_config,
71
69
  lm_head_share_weight_with_embedding=False,
72
- enable_hlfb=True,
73
70
  )
74
71
  return config
75
72
 
@@ -53,9 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
53
53
  intermediate_size=8960,
54
54
  )
55
55
  norm_config = cfg.NormalizationConfig(
56
- type=cfg.NormalizationType.RMS_NORM,
57
- epsilon=1e-06,
58
- enable_hlfb=True,
56
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
59
57
  )
60
58
  block_config = cfg.TransformerBlockConfig(
61
59
  attn_config=attn_config,
@@ -72,7 +70,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
72
70
  block_configs=block_config,
73
71
  final_norm_config=norm_config,
74
72
  lm_head_share_weight_with_embedding=False,
75
- enable_hlfb=True,
76
73
  )
77
74
  return config
78
75
 
@@ -65,10 +65,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
65
65
  intermediate_size=16384,
66
66
  )
67
67
  norm_config = cfg.NormalizationConfig(
68
- type=cfg.NormalizationType.RMS_NORM,
69
- epsilon=1e-6,
70
- zero_centered=True,
71
- enable_hlfb=True,
68
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
72
69
  )
73
70
  block_config = cfg.TransformerBlockConfig(
74
71
  attn_config=attn_config,
@@ -87,7 +84,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
87
84
  block_configs=block_config,
88
85
  final_norm_config=norm_config,
89
86
  lm_head_use_bias=False,
90
- enable_hlfb=True,
91
87
  )
92
88
  return config
93
89
 
@@ -233,10 +233,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
233
233
  The model config for a Gemma 2B model.
234
234
  """
235
235
  norm_config = cfg.NormalizationConfig(
236
- type=cfg.NormalizationType.RMS_NORM,
237
- epsilon=1e-6,
238
- zero_centered=True,
239
- enable_hlfb=True,
236
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
240
237
  )
241
238
  ff_config = cfg.FeedForwardConfig(
242
239
  type=cfg.FeedForwardType.GATED,
@@ -284,7 +281,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
284
281
  block_configs=[get_block_config(i) for i in range(num_layers)],
285
282
  final_norm_config=norm_config,
286
283
  lm_head_use_bias=False,
287
- enable_hlfb=True,
288
284
  final_logit_softcap=30.0,
289
285
  )
290
286
  return config
@@ -149,8 +149,12 @@ class Decoder(nn.Module):
149
149
  cache_len=attention_mask.shape[-1],
150
150
  sliding_window_size=sliding_window_size,
151
151
  )
152
- # Combine masks using logical AND (min in this case).
153
- combined_mask = torch.min(attention_mask, sliding_mask)
152
+ # Expand sliding_mask to match attention_mask's dimensions
153
+ # (e.g., [B, 1, seq_len, cache_len]).
154
+ # Assuming the head dimension is dim 1 for attention_mask.
155
+ expanded_sliding_mask = sliding_mask.unsqueeze(1)
156
+ # Combine masks using logical AND (min ensures -inf propagates).
157
+ combined_mask = torch.min(attention_mask, expanded_sliding_mask)
154
158
  return combined_mask
155
159
  return attention_mask
156
160
 
@@ -161,9 +165,9 @@ class Decoder(nn.Module):
161
165
  sliding_window_size: int,
162
166
  ) -> torch.Tensor:
163
167
  """Creates mask for sliding window attention (PyTorch)."""
164
- cache_positions = torch.tensor(
165
- [i for i in range(cache_len)], dtype=torch.int32
166
- )
168
+ # Use torch.arange to create a tensor with a range of integers in a
169
+ # Dynamo-friendly way.
170
+ cache_positions = torch.arange(cache_len, dtype=torch.int32)
167
171
  cache_positions = cache_positions.view(1, 1, -1) # [1, 1, cache_len]
168
172
  segment_pos_expanded = segment_pos.clone().unsqueeze(-1) # [B, seq_len, 1]
169
173
 
@@ -329,10 +333,7 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
329
333
  The model config for a Gemma 1B model.
330
334
  """
331
335
  norm_config = cfg.NormalizationConfig(
332
- type=cfg.NormalizationType.RMS_NORM,
333
- epsilon=1e-6,
334
- zero_centered=True,
335
- enable_hlfb=True,
336
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
336
337
  )
337
338
  ff_config = cfg.FeedForwardConfig(
338
339
  type=cfg.FeedForwardType.GATED,
@@ -379,7 +380,6 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
379
380
  block_configs=[get_block_config(i) for i in range(num_layers)],
380
381
  final_norm_config=norm_config,
381
382
  lm_head_use_bias=False,
382
- enable_hlfb=True,
383
383
  final_logit_softcap=None,
384
384
  )
385
385
  return config
@@ -158,9 +158,7 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
158
158
  image_projection_scale=128**0.5,
159
159
  image_projection_use_bias=False,
160
160
  mm_norm_config=cfg.NormalizationConfig(
161
- type=cfg.NormalizationType.LAYER_NORM,
162
- epsilon=1e-6,
163
- enable_hlfb=True,
161
+ type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
164
162
  ),
165
163
  mm_extra_tokens=32,
166
164
  )
@@ -98,9 +98,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
98
98
  output_proj_use_bias=True,
99
99
  )
100
100
  norm_config = cfg.NormalizationConfig(
101
- type=cfg.NormalizationType.LAYER_NORM,
102
- epsilon=1e-6,
103
- enable_hlfb=True,
101
+ type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
104
102
  )
105
103
  ff_config = cfg.FeedForwardConfig(
106
104
  type=cfg.FeedForwardType.SEQUENTIAL,
@@ -123,7 +121,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
123
121
  image_embedding=image_embedding_config,
124
122
  block_configs=block_config,
125
123
  final_norm_config=norm_config,
126
- enable_hlfb=True,
127
124
  num_mm_tokens_per_image=256,
128
125
  )
129
126
  return config
@@ -45,9 +45,7 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
45
45
  intermediate_size=8960,
46
46
  )
47
47
  norm_config = cfg.NormalizationConfig(
48
- type=cfg.NormalizationType.RMS_NORM,
49
- epsilon=1e-06,
50
- enable_hlfb=True,
48
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
51
49
  )
52
50
  block_config = cfg.TransformerBlockConfig(
53
51
  attn_config=attn_config,
@@ -63,7 +61,6 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
63
61
  kv_cache_max_len=kv_cache_max_len,
64
62
  block_configs=block_config,
65
63
  final_norm_config=norm_config,
66
- enable_hlfb=True,
67
64
  )
68
65
  return config
69
66
 
@@ -121,9 +121,7 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
121
121
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
122
122
  intermediate_size=8192,
123
123
  )
124
- norm_config = cfg.NormalizationConfig(
125
- type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
126
- )
124
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
127
125
  block_config = cfg.TransformerBlockConfig(
128
126
  attn_config=attn_config,
129
127
  ff_config=ff_config,
@@ -152,7 +150,6 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
152
150
  kv_cache_max_len=kv_cache_max_len,
153
151
  block_configs=block_config,
154
152
  final_norm_config=norm_config,
155
- enable_hlfb=True,
156
153
  build_rope=build_rope,
157
154
  )
158
155
  return config
@@ -53,7 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
53
53
  The model config for an OpenELM model.
54
54
  """
55
55
  norm_config = cfg.NormalizationConfig(
56
- type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
56
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
57
57
  )
58
58
  num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
59
59
  num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
@@ -101,7 +101,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
101
101
  kv_cache_max_len=kv_cache_max_len,
102
102
  block_configs=[get_block_config(i) for i in range(num_layers)],
103
103
  final_norm_config=norm_config,
104
- enable_hlfb=True,
105
104
  )
106
105
  return config
107
106
 
@@ -110,10 +110,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
110
110
  intermediate_size=16384,
111
111
  )
112
112
  norm_config = cfg.NormalizationConfig(
113
- type=cfg.NormalizationType.RMS_NORM,
114
- epsilon=1e-6,
115
- zero_centered=True,
116
- enable_hlfb=True,
113
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
117
114
  )
118
115
  block_config = cfg.TransformerBlockConfig(
119
116
  attn_config=attn_config,
@@ -132,7 +129,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
132
129
  block_configs=block_config,
133
130
  final_norm_config=norm_config,
134
131
  lm_head_use_bias=False,
135
- enable_hlfb=True,
136
132
  )
137
133
  return config
138
134
 
@@ -93,10 +93,7 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
93
93
  The model config for the decoder of a PaliGemma 3B model.
94
94
  """
95
95
  norm_config = cfg.NormalizationConfig(
96
- type=cfg.NormalizationType.RMS_NORM,
97
- epsilon=1e-6,
98
- zero_centered=True,
99
- enable_hlfb=True,
96
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
100
97
  )
101
98
  ff_config = cfg.FeedForwardConfig(
102
99
  type=cfg.FeedForwardType.GATED,
@@ -140,7 +137,6 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
140
137
  block_configs=[get_block_config(i) for i in range(num_layers)],
141
138
  final_norm_config=norm_config,
142
139
  lm_head_use_bias=False,
143
- enable_hlfb=True,
144
140
  final_logit_softcap=30.0,
145
141
  )
146
142
  return config
@@ -118,9 +118,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
118
118
  use_bias=True,
119
119
  )
120
120
  norm_config = cfg.NormalizationConfig(
121
- type=cfg.NormalizationType.LAYER_NORM,
122
- epsilon=1e-6,
123
- enable_hlfb=True,
121
+ type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
124
122
  )
125
123
  block_config = cfg.TransformerBlockConfig(
126
124
  attn_config=attn_config,
@@ -137,7 +135,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
137
135
  image_embedding=image_embedding_config,
138
136
  block_configs=block_config,
139
137
  final_norm_config=norm_config,
140
- enable_hlfb=True,
141
138
  )
142
139
  return config
143
140
 
@@ -66,9 +66,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
66
66
  intermediate_size=10240,
67
67
  use_bias=True,
68
68
  )
69
- norm_config = cfg.NormalizationConfig(
70
- type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
71
- )
69
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
72
70
  block_config = cfg.TransformerBlockConfig(
73
71
  attn_config=attn_config,
74
72
  ff_config=ff_config,
@@ -85,7 +83,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
85
83
  final_norm_config=norm_config,
86
84
  lm_head_use_bias=True,
87
85
  lm_head_share_weight_with_embedding=False,
88
- enable_hlfb=True,
89
86
  )
90
87
  return config
91
88
 
@@ -162,9 +162,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
162
162
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
163
163
  intermediate_size=8192,
164
164
  )
165
- norm_config = cfg.NormalizationConfig(
166
- type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
167
- )
165
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
168
166
  block_config = cfg.TransformerBlockConfig(
169
167
  attn_config=attn_config,
170
168
  ff_config=ff_config,
@@ -192,7 +190,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
192
190
  block_configs=block_config,
193
191
  final_norm_config=norm_config,
194
192
  lm_head_share_weight_with_embedding=False,
195
- enable_hlfb=True,
196
193
  build_rope=build_rope,
197
194
  )
198
195
  return config
@@ -112,9 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
112
112
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
113
113
  intermediate_size=8192,
114
114
  )
115
- norm_config = cfg.NormalizationConfig(
116
- type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
117
- )
115
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
118
116
  block_config = cfg.TransformerBlockConfig(
119
117
  attn_config=attn_config,
120
118
  ff_config=ff_config,
@@ -141,7 +139,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
141
139
  embedding_dim=3072,
142
140
  block_configs=block_config,
143
141
  final_norm_config=norm_config,
144
- enable_hlfb=True,
145
142
  build_rope=build_rope,
146
143
  )
147
144
  return config
@@ -53,9 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
53
53
  intermediate_size=11008,
54
54
  )
55
55
  norm_config = cfg.NormalizationConfig(
56
- type=cfg.NormalizationType.RMS_NORM,
57
- epsilon=1e-06,
58
- enable_hlfb=True,
56
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
59
57
  )
60
58
  block_config = cfg.TransformerBlockConfig(
61
59
  attn_config=attn_config,
@@ -71,7 +69,6 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
71
69
  kv_cache_max_len=kv_cache_max_len,
72
70
  block_configs=block_config,
73
71
  final_norm_config=norm_config,
74
- enable_hlfb=True,
75
72
  )
76
73
  return config
77
74
 
@@ -97,7 +97,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
97
97
  intermediate_size=11008,
98
98
  )
99
99
  norm_config = cfg.NormalizationConfig(
100
- type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06, enable_hlfb=True
100
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
101
101
  )
102
102
  block_config = cfg.TransformerBlockConfig(
103
103
  attn_config=attn_config,
@@ -113,7 +113,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
113
113
  kv_cache_max_len=kv_cache_max_len,
114
114
  block_configs=block_config,
115
115
  final_norm_config=norm_config,
116
- enable_hlfb=True,
117
116
  )
118
117
  return config
119
118
 
@@ -332,8 +332,7 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
332
332
  use_bias=True,
333
333
  )
334
334
  norm_config = cfg.NormalizationConfig(
335
- type=cfg.NormalizationType.RMS_NORM,
336
- epsilon=1e-6,
335
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
337
336
  )
338
337
  block_config = cfg.TransformerBlockConfig(
339
338
  attn_config=attn_config,
@@ -359,7 +358,6 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
359
358
  window_size=112,
360
359
  spatial_merge_size=2,
361
360
  full_atten_block_indexes=[7, 15, 23, 31],
362
- enable_hlfb=True,
363
361
  )
364
362
  return config
365
363
 
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
51
51
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
52
52
  intermediate_size=1536,
53
53
  )
54
- norm_config = cfg.NormalizationConfig(
55
- type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
56
- )
54
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
57
55
  block_config = cfg.TransformerBlockConfig(
58
56
  attn_config=attn_config,
59
57
  ff_config=ff_config,
@@ -68,7 +66,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
68
66
  kv_cache_max_len=kv_cache_max_len,
69
67
  block_configs=block_config,
70
68
  final_norm_config=norm_config,
71
- enable_hlfb=True,
72
69
  )
73
70
  return config
74
71
 
@@ -113,7 +113,9 @@ def get_model_config() -> cfg.ModelConfig:
113
113
  use_bias=True,
114
114
  )
115
115
 
116
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
116
+ norm_config = cfg.NormalizationConfig(
117
+ type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
118
+ )
117
119
 
118
120
  block_config = cfg.TransformerBlockConfig(
119
121
  attn_config=attn_config,
@@ -129,7 +131,6 @@ def get_model_config() -> cfg.ModelConfig:
129
131
  embedding_dim=embedding_dim,
130
132
  block_configs=block_config,
131
133
  final_norm_config=norm_config,
132
- enable_hlfb=True,
133
134
  )
134
135
 
135
136
  return config
@@ -164,7 +165,9 @@ def get_fake_model_config() -> cfg.ModelConfig:
164
165
  use_bias=True,
165
166
  )
166
167
 
167
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
168
+ norm_config = cfg.NormalizationConfig(
169
+ type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
170
+ )
168
171
 
169
172
  block_config = cfg.TransformerBlockConfig(
170
173
  attn_config=attn_config,
@@ -180,7 +183,6 @@ def get_fake_model_config() -> cfg.ModelConfig:
180
183
  embedding_dim=embedding_dim,
181
184
  block_configs=block_config,
182
185
  final_norm_config=norm_config,
183
- enable_hlfb=True,
184
186
  )
185
187
 
186
188
  return config
@@ -393,8 +393,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
393
393
  )
394
394
  # T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
395
395
  norm_config = cfg.NormalizationConfig(
396
- type=cfg.NormalizationType.RMS_NORM,
397
- epsilon=1e-6,
396
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=False
398
397
  )
399
398
  block_config = cfg.TransformerBlockConfig(
400
399
  attn_config=attn_config,
@@ -411,7 +410,6 @@ def get_model_config_t5() -> cfg.ModelConfig:
411
410
  block_configs=block_config,
412
411
  final_norm_config=norm_config,
413
412
  lm_head_use_bias=False,
414
- enable_hlfb=True,
415
413
  )
416
414
  return config
417
415
 
@@ -138,7 +138,9 @@ def get_model_config() -> cfg.ModelConfig:
138
138
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
139
139
  intermediate_size=256,
140
140
  )
141
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
141
+ norm_config = cfg.NormalizationConfig(
142
+ type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
143
+ )
142
144
  block_config = cfg.TransformerBlockConfig(
143
145
  attn_config=attn_config,
144
146
  ff_config=ff_config,
@@ -152,5 +154,6 @@ def get_model_config() -> cfg.ModelConfig:
152
154
  embedding_dim=128,
153
155
  block_configs=block_config,
154
156
  final_norm_config=norm_config,
157
+ enable_hlfb=False,
155
158
  )
156
159
  return config
@@ -108,7 +108,9 @@ def get_model_config() -> cfg.ModelConfig:
108
108
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
109
109
  intermediate_size=256,
110
110
  )
111
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
111
+ norm_config = cfg.NormalizationConfig(
112
+ type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
113
+ )
112
114
  block_config = cfg.TransformerBlockConfig(
113
115
  attn_config=attn_config,
114
116
  ff_config=ff_config,
@@ -122,7 +124,6 @@ def get_model_config() -> cfg.ModelConfig:
122
124
  embedding_dim=128,
123
125
  block_configs=block_config,
124
126
  final_norm_config=norm_config,
125
- enable_hlfb=True,
126
127
  )
127
128
  return config
128
129
 
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
51
51
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
52
52
  intermediate_size=5632,
53
53
  )
54
- norm_config = cfg.NormalizationConfig(
55
- type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
56
- )
54
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
57
55
  block_config = cfg.TransformerBlockConfig(
58
56
  attn_config=attn_config,
59
57
  ff_config=ff_config,
@@ -69,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
69
67
  block_configs=block_config,
70
68
  final_norm_config=norm_config,
71
69
  lm_head_share_weight_with_embedding=False,
72
- enable_hlfb=True,
73
70
  )
74
71
  return config
75
72
 
@@ -66,7 +66,7 @@ class NormalizationConfig:
66
66
  """Normalizater parameters."""
67
67
 
68
68
  type: NormalizationType = NormalizationType.NONE
69
- enable_hlfb: bool = False
69
+ enable_hlfb: bool = True
70
70
  epsilon: float = 1e-5
71
71
  zero_centered: bool = False
72
72
  # Number of groups used in group normalization.
@@ -218,7 +218,7 @@ class ModelConfig:
218
218
  lm_head_share_weight_with_embedding: bool = True
219
219
 
220
220
  # Whether to turn on high-level function boundary.
221
- enable_hlfb: bool = False
221
+ enable_hlfb: bool = True
222
222
 
223
223
  # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
224
224
  kv_cache_max_len: int = 0
@@ -100,7 +100,8 @@ def define_conversion_flags(
100
100
  flags.DEFINE_string(
101
101
  'quantize',
102
102
  'dynamic_int8',
103
- 'How the model should be quantized.',
103
+ 'How the model should be quantized. Set to "none" to disable'
104
+ ' quantization. See `QuantizationName` for supported quantization types.',
104
105
  )
105
106
  flags.DEFINE_multi_integer(
106
107
  'lora_ranks',
@@ -268,3 +268,16 @@ def numpy_array_constant(x: np.ndarray | np.generic) -> IrValues:
268
268
  x = np.ascontiguousarray(x)
269
269
  attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
270
270
  return stablehlo.constant(attr)
271
+
272
+
273
+ def convert_to_ir_value(
274
+ value: ir.Value | int | float | np.ndarray | np.generic,
275
+ ) -> ir.Value:
276
+ if isinstance(value, (np.ndarray, np.generic)):
277
+ return numpy_array_constant(value)
278
+ if isinstance(value, (int, float)):
279
+ dtype = np.float32 if isinstance(value, float) else np.int32
280
+ return numpy_array_constant(np.array([value], dtype=dtype))
281
+ if isinstance(value, ir.Value):
282
+ return value
283
+ raise TypeError(f"Unsupported type for conversion to ir.Value: {type(value)}")
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250516"
16
+ __version__ = "0.5.0.dev20250517"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250516
3
+ Version: 0.5.0.dev20250517
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,16 +2,17 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
5
- ai_edge_torch/version.py,sha256=qsmmOMVNJ3QYndWFHn1wZqGlFpjk3G1-KHlQvjpBSFg,706
5
+ ai_edge_torch/version.py,sha256=1nWIqrcLl_lGq0WnfPfnBgtlAECzCufGTdwLlpIpp_Y,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
- ai_edge_torch/_convert/conversion.py,sha256=jidl5IOb3MhUPqhMLBNFRSzkqQyi3Y0R0ua-vOSahm0,6082
7
+ ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
9
9
  ai_edge_torch/_convert/converter.py,sha256=6MLKELzAwFoiXv-b7KRYi7gc7Z57XOeowcz9ArIl9TM,12100
10
10
  ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
11
11
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
12
- ai_edge_torch/_convert/fx_passes/__init__.py,sha256=jbRCZmSduG_1qmngaEEtbofAyL1PKZ8P1uxzzsXQhsw,1253
12
+ ai_edge_torch/_convert/fx_passes/__init__.py,sha256=uHek7k9KIW3kaEM_lcygbukJ69JLjm-xnYUWzAEIZII,1345
13
13
  ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=dgUO-lI9Id9hIOHP5XmegVlu5Fl79GR4_b-lDUehzoo,11428
14
14
  ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
15
+ ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py,sha256=jMl9YHIbx08KQHbp9UgDnxviUUWiN-FSsiUgR2HCT5s,1576
15
16
  ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
16
17
  ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
17
18
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
@@ -52,48 +53,48 @@ ai_edge_torch/generative/custom_ops/bmm_4d.py,sha256=JmVbZCujG_wuBchma8QF3DSBfVc
52
53
  ai_edge_torch/generative/custom_ops/dynamic_update_slice.py,sha256=ZGAq2CfWZsfef5mHulsWmyUx0dDWJX6J6xPjhBrjQdM,2097
53
54
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
54
55
  ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
55
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=NyBlyUUk-3ksS5M2jFPeor6_1vSa8W_CofO8-lQ_4gE,2962
56
+ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=PTKcl-CHQnzExQSfrwG9YC0KPc8zomG7WlPabXtZLx4,2910
56
57
  ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=s2f5TJos6rSgogqeFk0qsOpI30qsR04umk9hAAZ5918,1782
57
58
  ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
58
59
  ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
59
60
  ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=xTPfT3Mt_4bMfGkrqDKatLecZOuaE0WhxXs3uAsO_uU,1749
60
- ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=afKPeEjRUkLf5uhImvxtOdHrK2edfJ_R4lx92etEQpQ,3069
61
+ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=9gUnK1IOifQyYpm03f64Mzg-afwbYY9kVWz6-ynq8zY,3014
61
62
  ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
62
63
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
63
64
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=t2qZTjyM2imPenb14fzbQ-CHj5Cejw4M5xfEZpgX6Uc,1748
64
65
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=Yj-b4S9BNxArnGjruRIymCiWrlf7ZvwiG6keTVGldk4,1816
65
- ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=HqpNgJYL3X91Bpl9dAQsWEmaXJjDXGuGBVeyqK5hGTk,3682
66
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=zynxoe_9ESvTIsznpp44HUS3gVDaEltkapmjzoNOaqA,11691
66
+ ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=wV_tr51UIwiKki8u5i-Q2YFvKYpTIIXEdUKFbdbMhRo,3621
67
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=b3zCFOjeU6T7K2PLUBABurpf7UjRIsGKkOym1wRuJOg,11630
67
68
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
68
69
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
69
70
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
70
71
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
71
72
  ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=wOrOV_jxCnjrhjC8X0-uIi0D-4aQjOfXw6XaxTSrM9k,2048
72
- ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=shdgLzKDUi0vyNOAsrIVAEFb3Adltsri6Rx1-wxzVf4,15089
73
- ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=ZorRtnbElWsctcA0nEbfwjx0C578voF7fjFEvWSR5Ck,6582
74
- ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
73
+ ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=GC22bZRTtO8IczccYpqh5nSE0FHJK3I0M9oaofrr-Ss,15344
74
+ ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=CPk3VJUobga0MVVIVRyWhdsrlCBWdQgF5kdSw7Yo--Y,6543
75
+ ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=tUOI99kdny33qcDM7-z0R6F-1aU1lZ24kG5zeLVdwow,5129
75
76
  ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
76
77
  ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=1vfAtayH_I_qTpqhzu6n9xnCuvhgTzhS8IzZviW2dJQ,9418
77
78
  ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
78
79
  ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=XLmPuJCBJjKzMTG-mRmBX92juep2zl5yYeMrEhdqQQk,1975
79
- ai_edge_torch/generative/examples/hammer/hammer.py,sha256=s8arcxjETiyuERrFOvyQe_o8Lvr82gxmOIJO1hr2Dcs,3704
80
+ ai_edge_torch/generative/examples/hammer/hammer.py,sha256=aiGRdmJbtcePRde7l_Vte61rPh_4F-zcxNuGtg_ceTY,3649
80
81
  ai_edge_torch/generative/examples/hammer/verify.py,sha256=MkzAGkbPy4LKRhyCDm1cw-9jUt4VUxLPdwK_25fCGSE,2705
81
82
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
82
83
  ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=4qnMyvJHqhqf9k01wEsO23BKo6tSy2KD7sHdTGimKGg,1957
83
- ai_edge_torch/generative/examples/llama/llama.py,sha256=TJXU9yZwxPCnuT2uwlcXVLrs5pg1P-Csv4xY5WTcf8U,7005
84
+ ai_edge_torch/generative/examples/llama/llama.py,sha256=eWPFnuSxhjuk5XZmvtndu_Z1-e9NlZg7-uFfiOqJXfw,6952
84
85
  ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
85
86
  ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
86
87
  ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=_GkaSkregS3NWN38UGXxj4pED5gtQGaaPZx5_CZ0TVM,1657
87
88
  ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
88
89
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
89
90
  ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=S7OP8PJcOQbm8AHvi_Tc3qnQuVOtjMFNlwaZQ_oirUM,1747
90
- ai_edge_torch/generative/examples/openelm/openelm.py,sha256=2jkIbj_G0IuFi5nXz_yAIY4qRxgWGD5rKQDTSweRV9M,4734
91
+ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=R_E0hXsg6l8ANEgBBy0R8egz3p4ONJvBmPWs6sXx63M,4692
91
92
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
92
93
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
93
94
  ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=Fl4k-lcpiUaJS0A1E7HVVUW7iTcZAU4FbA4KcSkO5SQ,2212
94
- ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=ruUTonTErvuinWsJ3pnSbvKhCnDUlupT1MW4TUwcrMY,5551
95
- ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=C377j2ULpPvmY5SsNLUC8jskTNNHVDH8uYOLH5W7fOU,6100
96
- ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=IbneN2J-ASdUg7OHVRkrUBiZ0UXyCVRJXhnDAxjozl8,5644
95
+ ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=GLlfbJr3ZIzmH643IwXyrG54qKEYMPRsvhU6gXXi7yg,5490
96
+ ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=mSqB-E7JHjwhaEf1p2STxc5DWLKAGE47GTAwtM0EyDU,6039
97
+ ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=V0RrkocOe-y2EDvcg8DMcSpWzzHUruQAEofHn20Jw7M,5589
97
98
  ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nxvcurGkFJcCjjgVkK59SJgp8mZ71D56bEnrjvGgPs4,6264
98
99
  ai_edge_torch/generative/examples/paligemma/verify.py,sha256=myHdeIAtVTOqb915h661CnvjvFkwmihy3Vp4UrKHb5I,6195
99
100
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
@@ -103,20 +104,20 @@ ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_
103
104
  ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=kYgZAIHXolUhOyDAYDuEK7RZ5ExL1YzpqtlcZjo622c,1736
104
105
  ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=3y3vYlJjLjRmz4Vsq-B8YKyp0LnC2fj1LAACW3pQivI,1734
105
106
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=tY5uwRu-4Jxro7Z9jsDqZR9SUDWB8PR6JKfswvsUSxM,1735
106
- ai_edge_torch/generative/examples/phi/phi2.py,sha256=nu18YKF95yg9Mo7TjpkgjA_br5fSYqaHmw0o86b5hDQ,3654
107
- ai_edge_torch/generative/examples/phi/phi3.py,sha256=c2h17Gmo9zLSEEdA7BzG_Jd8p4-3JmO6ZSEWLWXDGFU,7107
108
- ai_edge_torch/generative/examples/phi/phi4.py,sha256=TgoRbaW27X2tYAUi_z2GCb3j6uze5POhKGchRf-5eZw,5889
107
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=RprdXLbdG5lDOAgN4xolZZwp74vbRHHuf_-CzjnI5cA,3602
108
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=LW1E0C7A3IIyB5CLbVt914YJB4Fx6mbXh4xXibDHA2w,7054
109
+ ai_edge_torch/generative/examples/phi/phi4.py,sha256=ZHA0Rq7ifgxiHC_8PJf-y7WCA7i_2SlsiGibyOMBP4s,5837
109
110
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
110
111
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
111
112
  ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
112
113
  ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
113
114
  ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=TnzyARHQgmWeOdYsV9WpRj5vhKGBH0kAbp3tMj8ZCYw,1998
114
- ai_edge_torch/generative/examples/qwen/qwen.py,sha256=XOLq1yTbW6nyAVrYYG3qu_8Cl0A74M2hkpjOT_UhyVs,4609
115
+ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=EcIHVeBcJLc290TiPkPfE7jdG_VXZYKlVGf0XQXzqo8,4554
115
116
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
116
117
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
117
118
  ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=BM-ed7KrmPwzI3MvDs2R7P-kJgE1SK_cNVqIfXhtJjs,2411
118
- ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=yt3pO0x9t39dS2RWCM-0NRLl2ImcyWRIfL3E06bDg8k,4485
119
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=vMZ6v6iVrps_NSFwycgG4OPG_RVQAxa80lKrbneMkaM,15023
119
+ ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=plOi-3LltxReW_HVxhxwee_rYCQq-gsOwbGZtRsM8N8,4443
120
+ ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nuAHb-RXqTffpwjwCHOd_2mCrSMwL6Q1z_yjsU64gmI,14992
120
121
  ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=1Ac28olo0OJExZRyxqm7vxcf7GtXdkUwEbHvhiCHi0o,7908
121
122
  ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=4WKgAFQNQzwmeJhC8ayI5vjGj9ko6VcU2HA3VAkhHug,5812
122
123
  ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
@@ -124,11 +125,11 @@ ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nD
124
125
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
125
126
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=QVRX_ovqBQi8fKAG6PezaO1qoRvMGpVxNH-_sds0pf8,1997
126
127
  ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=rOVYSaS68_otJcGewQSconBCPD4GhDEIIyquD4dSUWc,1979
127
- ai_edge_torch/generative/examples/smollm/smollm.py,sha256=OXSN0Vu1MXnWb_H-aW9acgjpeLIhPIXGq2fx7RaojcM,4080
128
+ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=nQRiq6phJbtl3GAEEsJ_bPP_zrpQmiPumNEWCRrECn0,4028
128
129
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=sH3rn1TbaCusPiUD5XlECiHY0rvoHIXALbk7ECOiinI,2720
129
130
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
130
131
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
131
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=fPSg217F9xBvqMZwujCAQvYq5MRZzXTYOxjiPLqD7ZU,6102
132
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=lSCRZsoLjH_kqasRMwCy5IogkhyJdwcHKsPEfyxsXCQ,6112
132
133
  ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=_yk6wVoZm1_FRMFJF5URaPZNNdmMR89fwmKz81BEyao,5601
133
134
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=afyHXc86h-ij5zTULmZnM1h313N9VWCyIVriH6pqeSo,16368
134
135
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=ylqXOZhYc6XFCaNBKQw0jAnYrCtRFFQKzQzEsFIntvo,34890
@@ -143,15 +144,15 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6H
143
144
  ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
144
145
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
145
146
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=HHtZTtUh3QgE4F74-ru_8n1pt6cqfbObw12xoaMJ7NQ,4596
146
- ai_edge_torch/generative/examples/t5/t5.py,sha256=gFTmPi-xB8pcPRgoF3DJxvH_fT-KWTb8ii77P5UbKR0,21263
147
+ ai_edge_torch/generative/examples/t5/t5.py,sha256=aL-wDJDea9qbIe__oyKbK3g1p1xHq9-_88QsF95-JUs,21251
147
148
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
148
149
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
149
150
  ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
150
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Crpj-vOwSViHpblXOrRJmsIn4DrHyuB3XZ8kHifb7LA,5203
151
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=-z5tkQzGHbo37eAl9sDAJuT1Egxm8xI9CZmYLcmqIfU,4761
151
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=s3l__g3l0DpgXEPs1ikcJqSS7OfWzgYdkOLnEwdjcUo,5257
152
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=rVTaNQxlf3-mv9e6s33V9CDd_cmVwPti1A3ARUAwSD4,4766
152
153
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
153
154
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=urWkWjOaGzV2gwMXoGEs1mfHNEXfEKgwuXmQ0lrWcbM,1761
154
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=HRyq5nzoljWEWGYw0kCHAZH-GNiNHxh7E2qNoupjA-4,2988
155
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=cVNP_a_3UBML0j9ITtcITeVXqCdcC7U1JoYwir09Dk8,2936
155
156
  ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1tO2i0nUCqe-VkRgboA10VZ7KNg,2431
156
157
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
157
158
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
@@ -165,7 +166,7 @@ ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-m
165
166
  ai_edge_torch/generative/layers/feed_forward_test.py,sha256=8ZGy79BBpsyS6yKKDEKrDt249G5Mz-8VKWW7_WHx0u4,1655
166
167
  ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
167
168
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
168
- ai_edge_torch/generative/layers/model_config.py,sha256=X_gjN5524DCDBNXsX5GrOBlkKM4UHzj_RfdCD0-VOxQ,8572
169
+ ai_edge_torch/generative/layers/model_config.py,sha256=H1MpjP1Ij1r4DEcE4cQ_6A8h0QvUjCkuGATXMkIMIWg,8570
169
170
  ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
170
171
  ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
171
172
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
@@ -193,7 +194,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
193
194
  ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
194
195
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
195
196
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
196
- ai_edge_torch/generative/utilities/converter.py,sha256=mM8Vgd6zWkOrGt4-waa8cNjJwfhhTp-VNJ306NhXrV8,15425
197
+ ai_edge_torch/generative/utilities/converter.py,sha256=VRI960xo86g6lGLc_II3vDovFMa2DGIxnAZgE2GfSiM,15530
197
198
  ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
198
199
  ai_edge_torch/generative/utilities/loader.py,sha256=y1uSkUBiR0b9U4aoCQQk9qk7ctya_vEeY28Wc0A5e2s,15504
199
200
  ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSydYN3ZgmkzPc6nW0AJrA81wI,6461
@@ -242,7 +243,7 @@ ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQD
242
243
  ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
243
244
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
244
245
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
245
- ai_edge_torch/odml_torch/lowerings/utils.py,sha256=uJaFbbgvYMI4-VFpFcMpaObNfBQl6nV0x8Yo8LaSAOE,8974
246
+ ai_edge_torch/odml_torch/lowerings/utils.py,sha256=wHIhaKGWxf_x4W750_nzKoYMBDlq2Fd6b05-XOEVyVQ,9465
246
247
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
247
248
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
248
249
  ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
@@ -252,8 +253,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
252
253
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
253
254
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
254
255
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
255
- ai_edge_torch_nightly-0.5.0.dev20250516.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
256
- ai_edge_torch_nightly-0.5.0.dev20250516.dist-info/METADATA,sha256=669y6k49WKfsyVCxQ-N-xiyLc5U2lR90qfnNDoPpedA,2074
257
- ai_edge_torch_nightly-0.5.0.dev20250516.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
258
- ai_edge_torch_nightly-0.5.0.dev20250516.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
259
- ai_edge_torch_nightly-0.5.0.dev20250516.dist-info/RECORD,,
256
+ ai_edge_torch_nightly-0.5.0.dev20250517.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
257
+ ai_edge_torch_nightly-0.5.0.dev20250517.dist-info/METADATA,sha256=L8usKutHqIHjPswWNp7b3ynJzJiS5cTC9YWTnP82Qm8,2074
258
+ ai_edge_torch_nightly-0.5.0.dev20250517.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
259
+ ai_edge_torch_nightly-0.5.0.dev20250517.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
260
+ ai_edge_torch_nightly-0.5.0.dev20250517.dist-info/RECORD,,