ai-edge-torch-nightly 0.3.0.dev20250125__py3-none-any.whl → 0.3.0.dev20250127__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/gemma/gemma2.py +31 -6
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +7 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250127.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250127.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250127.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250127.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250127.dist-info}/top_level.txt +0 -0
@@ -43,6 +43,22 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
|
43
43
|
lm_head=None,
|
44
44
|
)
|
45
45
|
|
46
|
+
ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
47
|
+
ff_up_proj="model.layers.{}.mlp.up_proj",
|
48
|
+
ff_down_proj="model.layers.{}.mlp.down_proj",
|
49
|
+
ff_gate_proj="model.layers.{}.mlp.gate_proj",
|
50
|
+
attn_query_proj="model.layers.{}.self_attn.q_proj",
|
51
|
+
attn_key_proj="model.layers.{}.self_attn.k_proj",
|
52
|
+
attn_value_proj="model.layers.{}.self_attn.v_proj",
|
53
|
+
attn_output_proj="model.layers.{}.self_attn.o_proj",
|
54
|
+
pre_attn_norm="model.layers.{}.input_layernorm",
|
55
|
+
post_attn_norm="model.layers.{}.post_attention_layernorm",
|
56
|
+
pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
|
57
|
+
post_ff_norm="model.layers.{}.post_feedforward_layernorm",
|
58
|
+
embedding="model.embed_tokens",
|
59
|
+
final_norm="model.norm",
|
60
|
+
)
|
61
|
+
|
46
62
|
|
47
63
|
class Gemma2Block(attention.TransformerBlock):
|
48
64
|
|
@@ -281,9 +297,18 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
281
297
|
|
282
298
|
|
283
299
|
def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
300
|
+
try:
|
301
|
+
return model_builder.build_decoder_only_model(
|
302
|
+
checkpoint_path=checkpoint_path,
|
303
|
+
config=get_model_config_2b(**kwargs),
|
304
|
+
tensor_names=TENSOR_NAMES,
|
305
|
+
model_class=Gemma2,
|
306
|
+
)
|
307
|
+
except KeyError as ke:
|
308
|
+
# Also attempt to load with an alternative naming scheme.
|
309
|
+
return model_builder.build_decoder_only_model(
|
310
|
+
checkpoint_path=checkpoint_path,
|
311
|
+
config=get_model_config_2b(**kwargs),
|
312
|
+
tensor_names=ALT_TENSOR_NAMES,
|
313
|
+
model_class=Gemma2,
|
314
|
+
)
|
@@ -16,11 +16,17 @@ from ai_edge_torch import fx_infra
|
|
16
16
|
from ai_edge_torch import lowertools
|
17
17
|
import torch
|
18
18
|
|
19
|
+
fx_infra.decomp.remove_pre_convert_decomp(torch.ops.aten.zeros.default)
|
20
|
+
fx_infra.decomp.remove_pre_convert_decomp(torch.ops.aten.zeros_like.default)
|
21
|
+
|
19
22
|
|
20
23
|
class RemoveSDPACompositeZeroMaskPass(fx_infra.ExportedProgramPassBase):
|
21
24
|
|
22
25
|
def is_zero_tensor_node(self, node: torch.fx.Node):
|
23
|
-
return node.target
|
26
|
+
return node.target in (
|
27
|
+
torch.ops.aten.zeros.default,
|
28
|
+
torch.ops.aten.zeros_like.default,
|
29
|
+
)
|
24
30
|
|
25
31
|
def call(self, exported_program: torch.export.ExportedProgram):
|
26
32
|
graph = exported_program.graph_module.graph
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20250127
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=t2Ud6LcvJZuaAhzUIBhYrINw1PrvLlnIB2kspVz2TXQ,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
|
|
57
57
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
|
58
58
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
|
59
59
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
|
60
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
60
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=Zqd4l7KnfDqK-0fkGKx3prtiTyRIdxSeI187Dg-bNU4,11350
|
61
61
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
62
62
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
|
63
63
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
@@ -126,7 +126,7 @@ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pg
|
|
126
126
|
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
|
127
127
|
ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
|
128
128
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
|
129
|
-
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=
|
129
|
+
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
|
130
130
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
131
131
|
ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
|
132
132
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
@@ -222,8 +222,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
222
222
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
223
223
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
224
224
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
225
|
-
ai_edge_torch_nightly-0.3.0.
|
226
|
-
ai_edge_torch_nightly-0.3.0.
|
227
|
-
ai_edge_torch_nightly-0.3.0.
|
228
|
-
ai_edge_torch_nightly-0.3.0.
|
229
|
-
ai_edge_torch_nightly-0.3.0.
|
225
|
+
ai_edge_torch_nightly-0.3.0.dev20250127.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
226
|
+
ai_edge_torch_nightly-0.3.0.dev20250127.dist-info/METADATA,sha256=YKdvtJfZZLT0aGrGYkzmZezBNXpk47o2JIzuyWCdOF8,1966
|
227
|
+
ai_edge_torch_nightly-0.3.0.dev20250127.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
228
|
+
ai_edge_torch_nightly-0.3.0.dev20250127.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
229
|
+
ai_edge_torch_nightly-0.3.0.dev20250127.dist-info/RECORD,,
|
File without changes
|
File without changes
|