ai-edge-torch-nightly 0.5.0.dev20250516__py3-none-any.whl → 0.5.0.dev20250517__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +1 -0
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +1 -4
- ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -4
- ai_edge_torch/generative/examples/gemma/gemma1.py +1 -5
- ai_edge_torch/generative/examples/gemma/gemma2.py +1 -5
- ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
- ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
- ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/hammer/hammer.py +1 -4
- ai_edge_torch/generative/examples/llama/llama.py +1 -4
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -2
- ai_edge_torch/generative/examples/paligemma/decoder.py +1 -5
- ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -5
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/phi/phi2.py +1 -4
- ai_edge_torch/generative/examples/phi/phi3.py +1 -4
- ai_edge_torch/generative/examples/phi/phi4.py +1 -4
- ai_edge_torch/generative/examples/qwen/qwen.py +1 -4
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +1 -3
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -4
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +6 -4
- ai_edge_torch/generative/examples/t5/t5.py +1 -3
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -4
- ai_edge_torch/generative/layers/model_config.py +2 -2
- ai_edge_torch/generative/utilities/converter.py +2 -1
- ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/RECORD +37 -36
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250516.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from typing import Sequence, Union
|
|
17
17
|
|
18
18
|
from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
|
19
19
|
from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
|
20
|
+
from ai_edge_torch._convert.fx_passes.eliminate_dead_code_pass import EliminateDeadCodePass
|
20
21
|
from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
|
21
22
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
|
22
23
|
from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
|
@@ -0,0 +1,40 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Pass to eliminate dead code for ai-edge-torch conversion."""
|
16
|
+
|
17
|
+
|
18
|
+
from ai_edge_torch import fx_infra
|
19
|
+
import torch
|
20
|
+
|
21
|
+
|
22
|
+
class EliminateDeadCodePass(fx_infra.PassBase):
|
23
|
+
"""Eliminates dead code with dedicated rules for ai-edge-torch conversion."""
|
24
|
+
|
25
|
+
def call(self, graph_module: torch.fx.GraphModule):
|
26
|
+
def is_impure_node(node: torch.fx.Node):
|
27
|
+
# Starting from torch 2.7.0, random torch ops with
|
28
|
+
# _nondeterministic_seeded set are no longer considered pure. However,
|
29
|
+
# for conversion, unused random ops/tensors should still be removed.
|
30
|
+
if getattr(node.target, "_nondeterministic_seeded", False):
|
31
|
+
return False
|
32
|
+
return node.is_impure()
|
33
|
+
|
34
|
+
try:
|
35
|
+
graph_module.graph.eliminate_dead_code(is_impure_node)
|
36
|
+
except TypeError:
|
37
|
+
# eliminate_dead_code has no is_impure_node input in old torch versions.
|
38
|
+
pass
|
39
|
+
|
40
|
+
return fx_infra.PassResult(graph_module, True)
|
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
51
51
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
52
52
|
intermediate_size=2048,
|
53
53
|
)
|
54
|
-
norm_config = cfg.NormalizationConfig(
|
55
|
-
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
|
56
|
-
)
|
54
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
57
55
|
block_config = cfg.TransformerBlockConfig(
|
58
56
|
attn_config=attn_config,
|
59
57
|
ff_config=ff_config,
|
@@ -69,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
69
67
|
block_configs=block_config,
|
70
68
|
final_norm_config=norm_config,
|
71
69
|
lm_head_share_weight_with_embedding=False,
|
72
|
-
enable_hlfb=True,
|
73
70
|
)
|
74
71
|
return config
|
75
72
|
|
@@ -53,9 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
53
53
|
intermediate_size=8960,
|
54
54
|
)
|
55
55
|
norm_config = cfg.NormalizationConfig(
|
56
|
-
type=cfg.NormalizationType.RMS_NORM,
|
57
|
-
epsilon=1e-06,
|
58
|
-
enable_hlfb=True,
|
56
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
|
59
57
|
)
|
60
58
|
block_config = cfg.TransformerBlockConfig(
|
61
59
|
attn_config=attn_config,
|
@@ -72,7 +70,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
72
70
|
block_configs=block_config,
|
73
71
|
final_norm_config=norm_config,
|
74
72
|
lm_head_share_weight_with_embedding=False,
|
75
|
-
enable_hlfb=True,
|
76
73
|
)
|
77
74
|
return config
|
78
75
|
|
@@ -65,10 +65,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
65
65
|
intermediate_size=16384,
|
66
66
|
)
|
67
67
|
norm_config = cfg.NormalizationConfig(
|
68
|
-
type=cfg.NormalizationType.RMS_NORM,
|
69
|
-
epsilon=1e-6,
|
70
|
-
zero_centered=True,
|
71
|
-
enable_hlfb=True,
|
68
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
|
72
69
|
)
|
73
70
|
block_config = cfg.TransformerBlockConfig(
|
74
71
|
attn_config=attn_config,
|
@@ -87,7 +84,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
87
84
|
block_configs=block_config,
|
88
85
|
final_norm_config=norm_config,
|
89
86
|
lm_head_use_bias=False,
|
90
|
-
enable_hlfb=True,
|
91
87
|
)
|
92
88
|
return config
|
93
89
|
|
@@ -233,10 +233,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
233
233
|
The model config for a Gemma 2B model.
|
234
234
|
"""
|
235
235
|
norm_config = cfg.NormalizationConfig(
|
236
|
-
type=cfg.NormalizationType.RMS_NORM,
|
237
|
-
epsilon=1e-6,
|
238
|
-
zero_centered=True,
|
239
|
-
enable_hlfb=True,
|
236
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
|
240
237
|
)
|
241
238
|
ff_config = cfg.FeedForwardConfig(
|
242
239
|
type=cfg.FeedForwardType.GATED,
|
@@ -284,7 +281,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
284
281
|
block_configs=[get_block_config(i) for i in range(num_layers)],
|
285
282
|
final_norm_config=norm_config,
|
286
283
|
lm_head_use_bias=False,
|
287
|
-
enable_hlfb=True,
|
288
284
|
final_logit_softcap=30.0,
|
289
285
|
)
|
290
286
|
return config
|
@@ -149,8 +149,12 @@ class Decoder(nn.Module):
|
|
149
149
|
cache_len=attention_mask.shape[-1],
|
150
150
|
sliding_window_size=sliding_window_size,
|
151
151
|
)
|
152
|
-
#
|
153
|
-
|
152
|
+
# Expand sliding_mask to match attention_mask's dimensions
|
153
|
+
# (e.g., [B, 1, seq_len, cache_len]).
|
154
|
+
# Assuming the head dimension is dim 1 for attention_mask.
|
155
|
+
expanded_sliding_mask = sliding_mask.unsqueeze(1)
|
156
|
+
# Combine masks using logical AND (min ensures -inf propagates).
|
157
|
+
combined_mask = torch.min(attention_mask, expanded_sliding_mask)
|
154
158
|
return combined_mask
|
155
159
|
return attention_mask
|
156
160
|
|
@@ -161,9 +165,9 @@ class Decoder(nn.Module):
|
|
161
165
|
sliding_window_size: int,
|
162
166
|
) -> torch.Tensor:
|
163
167
|
"""Creates mask for sliding window attention (PyTorch)."""
|
164
|
-
|
165
|
-
|
166
|
-
)
|
168
|
+
# Use torch.arange to create a tensor with a range of integers in a
|
169
|
+
# Dynamo-friendly way.
|
170
|
+
cache_positions = torch.arange(cache_len, dtype=torch.int32)
|
167
171
|
cache_positions = cache_positions.view(1, 1, -1) # [1, 1, cache_len]
|
168
172
|
segment_pos_expanded = segment_pos.clone().unsqueeze(-1) # [B, seq_len, 1]
|
169
173
|
|
@@ -329,10 +333,7 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
|
|
329
333
|
The model config for a Gemma 1B model.
|
330
334
|
"""
|
331
335
|
norm_config = cfg.NormalizationConfig(
|
332
|
-
type=cfg.NormalizationType.RMS_NORM,
|
333
|
-
epsilon=1e-6,
|
334
|
-
zero_centered=True,
|
335
|
-
enable_hlfb=True,
|
336
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
|
336
337
|
)
|
337
338
|
ff_config = cfg.FeedForwardConfig(
|
338
339
|
type=cfg.FeedForwardType.GATED,
|
@@ -379,7 +380,6 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
|
|
379
380
|
block_configs=[get_block_config(i) for i in range(num_layers)],
|
380
381
|
final_norm_config=norm_config,
|
381
382
|
lm_head_use_bias=False,
|
382
|
-
enable_hlfb=True,
|
383
383
|
final_logit_softcap=None,
|
384
384
|
)
|
385
385
|
return config
|
@@ -158,9 +158,7 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
|
|
158
158
|
image_projection_scale=128**0.5,
|
159
159
|
image_projection_use_bias=False,
|
160
160
|
mm_norm_config=cfg.NormalizationConfig(
|
161
|
-
type=cfg.NormalizationType.LAYER_NORM,
|
162
|
-
epsilon=1e-6,
|
163
|
-
enable_hlfb=True,
|
161
|
+
type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
|
164
162
|
),
|
165
163
|
mm_extra_tokens=32,
|
166
164
|
)
|
@@ -98,9 +98,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
|
|
98
98
|
output_proj_use_bias=True,
|
99
99
|
)
|
100
100
|
norm_config = cfg.NormalizationConfig(
|
101
|
-
type=cfg.NormalizationType.LAYER_NORM,
|
102
|
-
epsilon=1e-6,
|
103
|
-
enable_hlfb=True,
|
101
|
+
type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
|
104
102
|
)
|
105
103
|
ff_config = cfg.FeedForwardConfig(
|
106
104
|
type=cfg.FeedForwardType.SEQUENTIAL,
|
@@ -123,7 +121,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
|
|
123
121
|
image_embedding=image_embedding_config,
|
124
122
|
block_configs=block_config,
|
125
123
|
final_norm_config=norm_config,
|
126
|
-
enable_hlfb=True,
|
127
124
|
num_mm_tokens_per_image=256,
|
128
125
|
)
|
129
126
|
return config
|
@@ -45,9 +45,7 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
45
45
|
intermediate_size=8960,
|
46
46
|
)
|
47
47
|
norm_config = cfg.NormalizationConfig(
|
48
|
-
type=cfg.NormalizationType.RMS_NORM,
|
49
|
-
epsilon=1e-06,
|
50
|
-
enable_hlfb=True,
|
48
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
|
51
49
|
)
|
52
50
|
block_config = cfg.TransformerBlockConfig(
|
53
51
|
attn_config=attn_config,
|
@@ -63,7 +61,6 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
63
61
|
kv_cache_max_len=kv_cache_max_len,
|
64
62
|
block_configs=block_config,
|
65
63
|
final_norm_config=norm_config,
|
66
|
-
enable_hlfb=True,
|
67
64
|
)
|
68
65
|
return config
|
69
66
|
|
@@ -121,9 +121,7 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
121
121
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
122
122
|
intermediate_size=8192,
|
123
123
|
)
|
124
|
-
norm_config = cfg.NormalizationConfig(
|
125
|
-
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
|
126
|
-
)
|
124
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
127
125
|
block_config = cfg.TransformerBlockConfig(
|
128
126
|
attn_config=attn_config,
|
129
127
|
ff_config=ff_config,
|
@@ -152,7 +150,6 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
152
150
|
kv_cache_max_len=kv_cache_max_len,
|
153
151
|
block_configs=block_config,
|
154
152
|
final_norm_config=norm_config,
|
155
|
-
enable_hlfb=True,
|
156
153
|
build_rope=build_rope,
|
157
154
|
)
|
158
155
|
return config
|
@@ -53,7 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
53
53
|
The model config for an OpenELM model.
|
54
54
|
"""
|
55
55
|
norm_config = cfg.NormalizationConfig(
|
56
|
-
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
|
56
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
|
57
57
|
)
|
58
58
|
num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
|
59
59
|
num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
|
@@ -101,7 +101,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
101
101
|
kv_cache_max_len=kv_cache_max_len,
|
102
102
|
block_configs=[get_block_config(i) for i in range(num_layers)],
|
103
103
|
final_norm_config=norm_config,
|
104
|
-
enable_hlfb=True,
|
105
104
|
)
|
106
105
|
return config
|
107
106
|
|
@@ -110,10 +110,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
110
110
|
intermediate_size=16384,
|
111
111
|
)
|
112
112
|
norm_config = cfg.NormalizationConfig(
|
113
|
-
type=cfg.NormalizationType.RMS_NORM,
|
114
|
-
epsilon=1e-6,
|
115
|
-
zero_centered=True,
|
116
|
-
enable_hlfb=True,
|
113
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
|
117
114
|
)
|
118
115
|
block_config = cfg.TransformerBlockConfig(
|
119
116
|
attn_config=attn_config,
|
@@ -132,7 +129,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
132
129
|
block_configs=block_config,
|
133
130
|
final_norm_config=norm_config,
|
134
131
|
lm_head_use_bias=False,
|
135
|
-
enable_hlfb=True,
|
136
132
|
)
|
137
133
|
return config
|
138
134
|
|
@@ -93,10 +93,7 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
93
93
|
The model config for the decoder of a PaliGemma 3B model.
|
94
94
|
"""
|
95
95
|
norm_config = cfg.NormalizationConfig(
|
96
|
-
type=cfg.NormalizationType.RMS_NORM,
|
97
|
-
epsilon=1e-6,
|
98
|
-
zero_centered=True,
|
99
|
-
enable_hlfb=True,
|
96
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
|
100
97
|
)
|
101
98
|
ff_config = cfg.FeedForwardConfig(
|
102
99
|
type=cfg.FeedForwardType.GATED,
|
@@ -140,7 +137,6 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
140
137
|
block_configs=[get_block_config(i) for i in range(num_layers)],
|
141
138
|
final_norm_config=norm_config,
|
142
139
|
lm_head_use_bias=False,
|
143
|
-
enable_hlfb=True,
|
144
140
|
final_logit_softcap=30.0,
|
145
141
|
)
|
146
142
|
return config
|
@@ -118,9 +118,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
|
|
118
118
|
use_bias=True,
|
119
119
|
)
|
120
120
|
norm_config = cfg.NormalizationConfig(
|
121
|
-
type=cfg.NormalizationType.LAYER_NORM,
|
122
|
-
epsilon=1e-6,
|
123
|
-
enable_hlfb=True,
|
121
|
+
type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
|
124
122
|
)
|
125
123
|
block_config = cfg.TransformerBlockConfig(
|
126
124
|
attn_config=attn_config,
|
@@ -137,7 +135,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
|
|
137
135
|
image_embedding=image_embedding_config,
|
138
136
|
block_configs=block_config,
|
139
137
|
final_norm_config=norm_config,
|
140
|
-
enable_hlfb=True,
|
141
138
|
)
|
142
139
|
return config
|
143
140
|
|
@@ -66,9 +66,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
66
66
|
intermediate_size=10240,
|
67
67
|
use_bias=True,
|
68
68
|
)
|
69
|
-
norm_config = cfg.NormalizationConfig(
|
70
|
-
type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
|
71
|
-
)
|
69
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
|
72
70
|
block_config = cfg.TransformerBlockConfig(
|
73
71
|
attn_config=attn_config,
|
74
72
|
ff_config=ff_config,
|
@@ -85,7 +83,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
85
83
|
final_norm_config=norm_config,
|
86
84
|
lm_head_use_bias=True,
|
87
85
|
lm_head_share_weight_with_embedding=False,
|
88
|
-
enable_hlfb=True,
|
89
86
|
)
|
90
87
|
return config
|
91
88
|
|
@@ -162,9 +162,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
162
162
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
|
163
163
|
intermediate_size=8192,
|
164
164
|
)
|
165
|
-
norm_config = cfg.NormalizationConfig(
|
166
|
-
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
|
167
|
-
)
|
165
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
168
166
|
block_config = cfg.TransformerBlockConfig(
|
169
167
|
attn_config=attn_config,
|
170
168
|
ff_config=ff_config,
|
@@ -192,7 +190,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
192
190
|
block_configs=block_config,
|
193
191
|
final_norm_config=norm_config,
|
194
192
|
lm_head_share_weight_with_embedding=False,
|
195
|
-
enable_hlfb=True,
|
196
193
|
build_rope=build_rope,
|
197
194
|
)
|
198
195
|
return config
|
@@ -112,9 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
112
112
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
|
113
113
|
intermediate_size=8192,
|
114
114
|
)
|
115
|
-
norm_config = cfg.NormalizationConfig(
|
116
|
-
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
|
117
|
-
)
|
115
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
118
116
|
block_config = cfg.TransformerBlockConfig(
|
119
117
|
attn_config=attn_config,
|
120
118
|
ff_config=ff_config,
|
@@ -141,7 +139,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
141
139
|
embedding_dim=3072,
|
142
140
|
block_configs=block_config,
|
143
141
|
final_norm_config=norm_config,
|
144
|
-
enable_hlfb=True,
|
145
142
|
build_rope=build_rope,
|
146
143
|
)
|
147
144
|
return config
|
@@ -53,9 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
53
53
|
intermediate_size=11008,
|
54
54
|
)
|
55
55
|
norm_config = cfg.NormalizationConfig(
|
56
|
-
type=cfg.NormalizationType.RMS_NORM,
|
57
|
-
epsilon=1e-06,
|
58
|
-
enable_hlfb=True,
|
56
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
|
59
57
|
)
|
60
58
|
block_config = cfg.TransformerBlockConfig(
|
61
59
|
attn_config=attn_config,
|
@@ -71,7 +69,6 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
71
69
|
kv_cache_max_len=kv_cache_max_len,
|
72
70
|
block_configs=block_config,
|
73
71
|
final_norm_config=norm_config,
|
74
|
-
enable_hlfb=True,
|
75
72
|
)
|
76
73
|
return config
|
77
74
|
|
@@ -97,7 +97,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
97
97
|
intermediate_size=11008,
|
98
98
|
)
|
99
99
|
norm_config = cfg.NormalizationConfig(
|
100
|
-
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
|
100
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
|
101
101
|
)
|
102
102
|
block_config = cfg.TransformerBlockConfig(
|
103
103
|
attn_config=attn_config,
|
@@ -113,7 +113,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
113
113
|
kv_cache_max_len=kv_cache_max_len,
|
114
114
|
block_configs=block_config,
|
115
115
|
final_norm_config=norm_config,
|
116
|
-
enable_hlfb=True,
|
117
116
|
)
|
118
117
|
return config
|
119
118
|
|
@@ -332,8 +332,7 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
|
|
332
332
|
use_bias=True,
|
333
333
|
)
|
334
334
|
norm_config = cfg.NormalizationConfig(
|
335
|
-
type=cfg.NormalizationType.RMS_NORM,
|
336
|
-
epsilon=1e-6,
|
335
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
|
337
336
|
)
|
338
337
|
block_config = cfg.TransformerBlockConfig(
|
339
338
|
attn_config=attn_config,
|
@@ -359,7 +358,6 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
|
|
359
358
|
window_size=112,
|
360
359
|
spatial_merge_size=2,
|
361
360
|
full_atten_block_indexes=[7, 15, 23, 31],
|
362
|
-
enable_hlfb=True,
|
363
361
|
)
|
364
362
|
return config
|
365
363
|
|
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
51
51
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
52
52
|
intermediate_size=1536,
|
53
53
|
)
|
54
|
-
norm_config = cfg.NormalizationConfig(
|
55
|
-
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
|
56
|
-
)
|
54
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
57
55
|
block_config = cfg.TransformerBlockConfig(
|
58
56
|
attn_config=attn_config,
|
59
57
|
ff_config=ff_config,
|
@@ -68,7 +66,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
68
66
|
kv_cache_max_len=kv_cache_max_len,
|
69
67
|
block_configs=block_config,
|
70
68
|
final_norm_config=norm_config,
|
71
|
-
enable_hlfb=True,
|
72
69
|
)
|
73
70
|
return config
|
74
71
|
|
@@ -113,7 +113,9 @@ def get_model_config() -> cfg.ModelConfig:
|
|
113
113
|
use_bias=True,
|
114
114
|
)
|
115
115
|
|
116
|
-
norm_config = cfg.NormalizationConfig(
|
116
|
+
norm_config = cfg.NormalizationConfig(
|
117
|
+
type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
|
118
|
+
)
|
117
119
|
|
118
120
|
block_config = cfg.TransformerBlockConfig(
|
119
121
|
attn_config=attn_config,
|
@@ -129,7 +131,6 @@ def get_model_config() -> cfg.ModelConfig:
|
|
129
131
|
embedding_dim=embedding_dim,
|
130
132
|
block_configs=block_config,
|
131
133
|
final_norm_config=norm_config,
|
132
|
-
enable_hlfb=True,
|
133
134
|
)
|
134
135
|
|
135
136
|
return config
|
@@ -164,7 +165,9 @@ def get_fake_model_config() -> cfg.ModelConfig:
|
|
164
165
|
use_bias=True,
|
165
166
|
)
|
166
167
|
|
167
|
-
norm_config = cfg.NormalizationConfig(
|
168
|
+
norm_config = cfg.NormalizationConfig(
|
169
|
+
type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
|
170
|
+
)
|
168
171
|
|
169
172
|
block_config = cfg.TransformerBlockConfig(
|
170
173
|
attn_config=attn_config,
|
@@ -180,7 +183,6 @@ def get_fake_model_config() -> cfg.ModelConfig:
|
|
180
183
|
embedding_dim=embedding_dim,
|
181
184
|
block_configs=block_config,
|
182
185
|
final_norm_config=norm_config,
|
183
|
-
enable_hlfb=True,
|
184
186
|
)
|
185
187
|
|
186
188
|
return config
|
@@ -393,8 +393,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
|
|
393
393
|
)
|
394
394
|
# T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
|
395
395
|
norm_config = cfg.NormalizationConfig(
|
396
|
-
type=cfg.NormalizationType.RMS_NORM,
|
397
|
-
epsilon=1e-6,
|
396
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=False
|
398
397
|
)
|
399
398
|
block_config = cfg.TransformerBlockConfig(
|
400
399
|
attn_config=attn_config,
|
@@ -411,7 +410,6 @@ def get_model_config_t5() -> cfg.ModelConfig:
|
|
411
410
|
block_configs=block_config,
|
412
411
|
final_norm_config=norm_config,
|
413
412
|
lm_head_use_bias=False,
|
414
|
-
enable_hlfb=True,
|
415
413
|
)
|
416
414
|
return config
|
417
415
|
|
@@ -138,7 +138,9 @@ def get_model_config() -> cfg.ModelConfig:
|
|
138
138
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
139
139
|
intermediate_size=256,
|
140
140
|
)
|
141
|
-
norm_config = cfg.NormalizationConfig(
|
141
|
+
norm_config = cfg.NormalizationConfig(
|
142
|
+
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
|
143
|
+
)
|
142
144
|
block_config = cfg.TransformerBlockConfig(
|
143
145
|
attn_config=attn_config,
|
144
146
|
ff_config=ff_config,
|
@@ -152,5 +154,6 @@ def get_model_config() -> cfg.ModelConfig:
|
|
152
154
|
embedding_dim=128,
|
153
155
|
block_configs=block_config,
|
154
156
|
final_norm_config=norm_config,
|
157
|
+
enable_hlfb=False,
|
155
158
|
)
|
156
159
|
return config
|
@@ -108,7 +108,9 @@ def get_model_config() -> cfg.ModelConfig:
|
|
108
108
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
109
109
|
intermediate_size=256,
|
110
110
|
)
|
111
|
-
norm_config = cfg.NormalizationConfig(
|
111
|
+
norm_config = cfg.NormalizationConfig(
|
112
|
+
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
|
113
|
+
)
|
112
114
|
block_config = cfg.TransformerBlockConfig(
|
113
115
|
attn_config=attn_config,
|
114
116
|
ff_config=ff_config,
|
@@ -122,7 +124,6 @@ def get_model_config() -> cfg.ModelConfig:
|
|
122
124
|
embedding_dim=128,
|
123
125
|
block_configs=block_config,
|
124
126
|
final_norm_config=norm_config,
|
125
|
-
enable_hlfb=True,
|
126
127
|
)
|
127
128
|
return config
|
128
129
|
|
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
51
51
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
52
52
|
intermediate_size=5632,
|
53
53
|
)
|
54
|
-
norm_config = cfg.NormalizationConfig(
|
55
|
-
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
|
56
|
-
)
|
54
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
57
55
|
block_config = cfg.TransformerBlockConfig(
|
58
56
|
attn_config=attn_config,
|
59
57
|
ff_config=ff_config,
|
@@ -69,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
69
67
|
block_configs=block_config,
|
70
68
|
final_norm_config=norm_config,
|
71
69
|
lm_head_share_weight_with_embedding=False,
|
72
|
-
enable_hlfb=True,
|
73
70
|
)
|
74
71
|
return config
|
75
72
|
|
@@ -66,7 +66,7 @@ class NormalizationConfig:
|
|
66
66
|
"""Normalizater parameters."""
|
67
67
|
|
68
68
|
type: NormalizationType = NormalizationType.NONE
|
69
|
-
enable_hlfb: bool =
|
69
|
+
enable_hlfb: bool = True
|
70
70
|
epsilon: float = 1e-5
|
71
71
|
zero_centered: bool = False
|
72
72
|
# Number of groups used in group normalization.
|
@@ -218,7 +218,7 @@ class ModelConfig:
|
|
218
218
|
lm_head_share_weight_with_embedding: bool = True
|
219
219
|
|
220
220
|
# Whether to turn on high-level function boundary.
|
221
|
-
enable_hlfb: bool =
|
221
|
+
enable_hlfb: bool = True
|
222
222
|
|
223
223
|
# The maximum sequence length of the KV cache. Should not exceed max_seq_len.
|
224
224
|
kv_cache_max_len: int = 0
|
@@ -100,7 +100,8 @@ def define_conversion_flags(
|
|
100
100
|
flags.DEFINE_string(
|
101
101
|
'quantize',
|
102
102
|
'dynamic_int8',
|
103
|
-
'How the model should be quantized.'
|
103
|
+
'How the model should be quantized. Set to "none" to disable'
|
104
|
+
' quantization. See `QuantizationName` for supported quantization types.',
|
104
105
|
)
|
105
106
|
flags.DEFINE_multi_integer(
|
106
107
|
'lora_ranks',
|
@@ -268,3 +268,16 @@ def numpy_array_constant(x: np.ndarray | np.generic) -> IrValues:
|
|
268
268
|
x = np.ascontiguousarray(x)
|
269
269
|
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
|
270
270
|
return stablehlo.constant(attr)
|
271
|
+
|
272
|
+
|
273
|
+
def convert_to_ir_value(
|
274
|
+
value: ir.Value | int | float | np.ndarray | np.generic,
|
275
|
+
) -> ir.Value:
|
276
|
+
if isinstance(value, (np.ndarray, np.generic)):
|
277
|
+
return numpy_array_constant(value)
|
278
|
+
if isinstance(value, (int, float)):
|
279
|
+
dtype = np.float32 if isinstance(value, float) else np.int32
|
280
|
+
return numpy_array_constant(np.array([value], dtype=dtype))
|
281
|
+
if isinstance(value, ir.Value):
|
282
|
+
return value
|
283
|
+
raise TypeError(f"Unsupported type for conversion to ir.Value: {type(value)}")
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.5.0.
|
3
|
+
Version: 0.5.0.dev20250517
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,16 +2,17 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=1nWIqrcLl_lGq0WnfPfnBgtlAECzCufGTdwLlpIpp_Y,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
7
|
+
ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
9
9
|
ai_edge_torch/_convert/converter.py,sha256=6MLKELzAwFoiXv-b7KRYi7gc7Z57XOeowcz9ArIl9TM,12100
|
10
10
|
ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
|
11
11
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
12
|
-
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
|
12
|
+
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=uHek7k9KIW3kaEM_lcygbukJ69JLjm-xnYUWzAEIZII,1345
|
13
13
|
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=dgUO-lI9Id9hIOHP5XmegVlu5Fl79GR4_b-lDUehzoo,11428
|
14
14
|
ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
|
15
|
+
ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py,sha256=jMl9YHIbx08KQHbp9UgDnxviUUWiN-FSsiUgR2HCT5s,1576
|
15
16
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
|
16
17
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
|
17
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
|
@@ -52,48 +53,48 @@ ai_edge_torch/generative/custom_ops/bmm_4d.py,sha256=JmVbZCujG_wuBchma8QF3DSBfVc
|
|
52
53
|
ai_edge_torch/generative/custom_ops/dynamic_update_slice.py,sha256=ZGAq2CfWZsfef5mHulsWmyUx0dDWJX6J6xPjhBrjQdM,2097
|
53
54
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
54
55
|
ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
55
|
-
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=
|
56
|
+
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=PTKcl-CHQnzExQSfrwG9YC0KPc8zomG7WlPabXtZLx4,2910
|
56
57
|
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=s2f5TJos6rSgogqeFk0qsOpI30qsR04umk9hAAZ5918,1782
|
57
58
|
ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
|
58
59
|
ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
59
60
|
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=xTPfT3Mt_4bMfGkrqDKatLecZOuaE0WhxXs3uAsO_uU,1749
|
60
|
-
ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=
|
61
|
+
ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=9gUnK1IOifQyYpm03f64Mzg-afwbYY9kVWz6-ynq8zY,3014
|
61
62
|
ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
|
62
63
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
63
64
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=t2qZTjyM2imPenb14fzbQ-CHj5Cejw4M5xfEZpgX6Uc,1748
|
64
65
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=Yj-b4S9BNxArnGjruRIymCiWrlf7ZvwiG6keTVGldk4,1816
|
65
|
-
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=
|
66
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
66
|
+
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=wV_tr51UIwiKki8u5i-Q2YFvKYpTIIXEdUKFbdbMhRo,3621
|
67
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=b3zCFOjeU6T7K2PLUBABurpf7UjRIsGKkOym1wRuJOg,11630
|
67
68
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
68
69
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
|
69
70
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
|
70
71
|
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
71
72
|
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=wOrOV_jxCnjrhjC8X0-uIi0D-4aQjOfXw6XaxTSrM9k,2048
|
72
|
-
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
|
73
|
-
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=
|
74
|
-
ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=
|
73
|
+
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=GC22bZRTtO8IczccYpqh5nSE0FHJK3I0M9oaofrr-Ss,15344
|
74
|
+
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=CPk3VJUobga0MVVIVRyWhdsrlCBWdQgF5kdSw7Yo--Y,6543
|
75
|
+
ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=tUOI99kdny33qcDM7-z0R6F-1aU1lZ24kG5zeLVdwow,5129
|
75
76
|
ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
|
76
77
|
ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=1vfAtayH_I_qTpqhzu6n9xnCuvhgTzhS8IzZviW2dJQ,9418
|
77
78
|
ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
78
79
|
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=XLmPuJCBJjKzMTG-mRmBX92juep2zl5yYeMrEhdqQQk,1975
|
79
|
-
ai_edge_torch/generative/examples/hammer/hammer.py,sha256=
|
80
|
+
ai_edge_torch/generative/examples/hammer/hammer.py,sha256=aiGRdmJbtcePRde7l_Vte61rPh_4F-zcxNuGtg_ceTY,3649
|
80
81
|
ai_edge_torch/generative/examples/hammer/verify.py,sha256=MkzAGkbPy4LKRhyCDm1cw-9jUt4VUxLPdwK_25fCGSE,2705
|
81
82
|
ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
82
83
|
ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=4qnMyvJHqhqf9k01wEsO23BKo6tSy2KD7sHdTGimKGg,1957
|
83
|
-
ai_edge_torch/generative/examples/llama/llama.py,sha256=
|
84
|
+
ai_edge_torch/generative/examples/llama/llama.py,sha256=eWPFnuSxhjuk5XZmvtndu_Z1-e9NlZg7-uFfiOqJXfw,6952
|
84
85
|
ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
|
85
86
|
ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
86
87
|
ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=_GkaSkregS3NWN38UGXxj4pED5gtQGaaPZx5_CZ0TVM,1657
|
87
88
|
ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
|
88
89
|
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
89
90
|
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=S7OP8PJcOQbm8AHvi_Tc3qnQuVOtjMFNlwaZQ_oirUM,1747
|
90
|
-
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
|
91
|
+
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=R_E0hXsg6l8ANEgBBy0R8egz3p4ONJvBmPWs6sXx63M,4692
|
91
92
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
|
92
93
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
93
94
|
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=Fl4k-lcpiUaJS0A1E7HVVUW7iTcZAU4FbA4KcSkO5SQ,2212
|
94
|
-
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
|
95
|
-
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=
|
96
|
-
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=
|
95
|
+
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=GLlfbJr3ZIzmH643IwXyrG54qKEYMPRsvhU6gXXi7yg,5490
|
96
|
+
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=mSqB-E7JHjwhaEf1p2STxc5DWLKAGE47GTAwtM0EyDU,6039
|
97
|
+
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=V0RrkocOe-y2EDvcg8DMcSpWzzHUruQAEofHn20Jw7M,5589
|
97
98
|
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nxvcurGkFJcCjjgVkK59SJgp8mZ71D56bEnrjvGgPs4,6264
|
98
99
|
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=myHdeIAtVTOqb915h661CnvjvFkwmihy3Vp4UrKHb5I,6195
|
99
100
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
@@ -103,20 +104,20 @@ ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_
|
|
103
104
|
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=kYgZAIHXolUhOyDAYDuEK7RZ5ExL1YzpqtlcZjo622c,1736
|
104
105
|
ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=3y3vYlJjLjRmz4Vsq-B8YKyp0LnC2fj1LAACW3pQivI,1734
|
105
106
|
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=tY5uwRu-4Jxro7Z9jsDqZR9SUDWB8PR6JKfswvsUSxM,1735
|
106
|
-
ai_edge_torch/generative/examples/phi/phi2.py,sha256=
|
107
|
-
ai_edge_torch/generative/examples/phi/phi3.py,sha256=
|
108
|
-
ai_edge_torch/generative/examples/phi/phi4.py,sha256=
|
107
|
+
ai_edge_torch/generative/examples/phi/phi2.py,sha256=RprdXLbdG5lDOAgN4xolZZwp74vbRHHuf_-CzjnI5cA,3602
|
108
|
+
ai_edge_torch/generative/examples/phi/phi3.py,sha256=LW1E0C7A3IIyB5CLbVt914YJB4Fx6mbXh4xXibDHA2w,7054
|
109
|
+
ai_edge_torch/generative/examples/phi/phi4.py,sha256=ZHA0Rq7ifgxiHC_8PJf-y7WCA7i_2SlsiGibyOMBP4s,5837
|
109
110
|
ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
|
110
111
|
ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
|
111
112
|
ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
|
112
113
|
ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
113
114
|
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=TnzyARHQgmWeOdYsV9WpRj5vhKGBH0kAbp3tMj8ZCYw,1998
|
114
|
-
ai_edge_torch/generative/examples/qwen/qwen.py,sha256=
|
115
|
+
ai_edge_torch/generative/examples/qwen/qwen.py,sha256=EcIHVeBcJLc290TiPkPfE7jdG_VXZYKlVGf0XQXzqo8,4554
|
115
116
|
ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
|
116
117
|
ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
117
118
|
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=BM-ed7KrmPwzI3MvDs2R7P-kJgE1SK_cNVqIfXhtJjs,2411
|
118
|
-
ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=
|
119
|
-
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=
|
119
|
+
ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=plOi-3LltxReW_HVxhxwee_rYCQq-gsOwbGZtRsM8N8,4443
|
120
|
+
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nuAHb-RXqTffpwjwCHOd_2mCrSMwL6Q1z_yjsU64gmI,14992
|
120
121
|
ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=1Ac28olo0OJExZRyxqm7vxcf7GtXdkUwEbHvhiCHi0o,7908
|
121
122
|
ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=4WKgAFQNQzwmeJhC8ayI5vjGj9ko6VcU2HA3VAkhHug,5812
|
122
123
|
ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
|
@@ -124,11 +125,11 @@ ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nD
|
|
124
125
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
125
126
|
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=QVRX_ovqBQi8fKAG6PezaO1qoRvMGpVxNH-_sds0pf8,1997
|
126
127
|
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=rOVYSaS68_otJcGewQSconBCPD4GhDEIIyquD4dSUWc,1979
|
127
|
-
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=
|
128
|
+
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=nQRiq6phJbtl3GAEEsJ_bPP_zrpQmiPumNEWCRrECn0,4028
|
128
129
|
ai_edge_torch/generative/examples/smollm/verify.py,sha256=sH3rn1TbaCusPiUD5XlECiHY0rvoHIXALbk7ECOiinI,2720
|
129
130
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
130
131
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
131
|
-
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
|
132
|
+
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=lSCRZsoLjH_kqasRMwCy5IogkhyJdwcHKsPEfyxsXCQ,6112
|
132
133
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=_yk6wVoZm1_FRMFJF5URaPZNNdmMR89fwmKz81BEyao,5601
|
133
134
|
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=afyHXc86h-ij5zTULmZnM1h313N9VWCyIVriH6pqeSo,16368
|
134
135
|
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=ylqXOZhYc6XFCaNBKQw0jAnYrCtRFFQKzQzEsFIntvo,34890
|
@@ -143,15 +144,15 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6H
|
|
143
144
|
ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
|
144
145
|
ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
145
146
|
ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=HHtZTtUh3QgE4F74-ru_8n1pt6cqfbObw12xoaMJ7NQ,4596
|
146
|
-
ai_edge_torch/generative/examples/t5/t5.py,sha256=
|
147
|
+
ai_edge_torch/generative/examples/t5/t5.py,sha256=aL-wDJDea9qbIe__oyKbK3g1p1xHq9-_88QsF95-JUs,21251
|
147
148
|
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
|
148
149
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
149
150
|
ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
|
150
|
-
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
|
151
|
-
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256
|
151
|
+
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=s3l__g3l0DpgXEPs1ikcJqSS7OfWzgYdkOLnEwdjcUo,5257
|
152
|
+
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=rVTaNQxlf3-mv9e6s33V9CDd_cmVwPti1A3ARUAwSD4,4766
|
152
153
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
153
154
|
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=urWkWjOaGzV2gwMXoGEs1mfHNEXfEKgwuXmQ0lrWcbM,1761
|
154
|
-
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
|
155
|
+
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=cVNP_a_3UBML0j9ITtcITeVXqCdcC7U1JoYwir09Dk8,2936
|
155
156
|
ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1tO2i0nUCqe-VkRgboA10VZ7KNg,2431
|
156
157
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
|
157
158
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
|
@@ -165,7 +166,7 @@ ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-m
|
|
165
166
|
ai_edge_torch/generative/layers/feed_forward_test.py,sha256=8ZGy79BBpsyS6yKKDEKrDt249G5Mz-8VKWW7_WHx0u4,1655
|
166
167
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
|
167
168
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
168
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
169
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=H1MpjP1Ij1r4DEcE4cQ_6A8h0QvUjCkuGATXMkIMIWg,8570
|
169
170
|
ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
|
170
171
|
ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
|
171
172
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
@@ -193,7 +194,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
|
|
193
194
|
ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
|
194
195
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
195
196
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
196
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
197
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=VRI960xo86g6lGLc_II3vDovFMa2DGIxnAZgE2GfSiM,15530
|
197
198
|
ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
|
198
199
|
ai_edge_torch/generative/utilities/loader.py,sha256=y1uSkUBiR0b9U4aoCQQk9qk7ctya_vEeY28Wc0A5e2s,15504
|
199
200
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSydYN3ZgmkzPc6nW0AJrA81wI,6461
|
@@ -242,7 +243,7 @@ ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQD
|
|
242
243
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
243
244
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
244
245
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
245
|
-
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=
|
246
|
+
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=wHIhaKGWxf_x4W750_nzKoYMBDlq2Fd6b05-XOEVyVQ,9465
|
246
247
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
247
248
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
248
249
|
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
|
@@ -252,8 +253,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
252
253
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
253
254
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
254
255
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
255
|
-
ai_edge_torch_nightly-0.5.0.
|
256
|
-
ai_edge_torch_nightly-0.5.0.
|
257
|
-
ai_edge_torch_nightly-0.5.0.
|
258
|
-
ai_edge_torch_nightly-0.5.0.
|
259
|
-
ai_edge_torch_nightly-0.5.0.
|
256
|
+
ai_edge_torch_nightly-0.5.0.dev20250517.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
257
|
+
ai_edge_torch_nightly-0.5.0.dev20250517.dist-info/METADATA,sha256=L8usKutHqIHjPswWNp7b3ynJzJiS5cTC9YWTnP82Qm8,2074
|
258
|
+
ai_edge_torch_nightly-0.5.0.dev20250517.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
259
|
+
ai_edge_torch_nightly-0.5.0.dev20250517.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
260
|
+
ai_edge_torch_nightly-0.5.0.dev20250517.dist-info/RECORD,,
|
File without changes
|
File without changes
|