ai-edge-torch-nightly 0.3.0.dev20250125__py3-none-any.whl → 0.3.0.dev20250127__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -43,6 +43,22 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
43
43
  lm_head=None,
44
44
  )
45
45
 
46
+ ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
47
+ ff_up_proj="model.layers.{}.mlp.up_proj",
48
+ ff_down_proj="model.layers.{}.mlp.down_proj",
49
+ ff_gate_proj="model.layers.{}.mlp.gate_proj",
50
+ attn_query_proj="model.layers.{}.self_attn.q_proj",
51
+ attn_key_proj="model.layers.{}.self_attn.k_proj",
52
+ attn_value_proj="model.layers.{}.self_attn.v_proj",
53
+ attn_output_proj="model.layers.{}.self_attn.o_proj",
54
+ pre_attn_norm="model.layers.{}.input_layernorm",
55
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
56
+ pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
57
+ post_ff_norm="model.layers.{}.post_feedforward_layernorm",
58
+ embedding="model.embed_tokens",
59
+ final_norm="model.norm",
60
+ )
61
+
46
62
 
47
63
  class Gemma2Block(attention.TransformerBlock):
48
64
 
@@ -281,9 +297,18 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
281
297
 
282
298
 
283
299
  def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
284
- return model_builder.build_decoder_only_model(
285
- checkpoint_path=checkpoint_path,
286
- config=get_model_config_2b(**kwargs),
287
- tensor_names=TENSOR_NAMES,
288
- model_class=Gemma2,
289
- )
300
+ try:
301
+ return model_builder.build_decoder_only_model(
302
+ checkpoint_path=checkpoint_path,
303
+ config=get_model_config_2b(**kwargs),
304
+ tensor_names=TENSOR_NAMES,
305
+ model_class=Gemma2,
306
+ )
307
+ except KeyError as ke:
308
+ # Also attempt to load with an alternative naming scheme.
309
+ return model_builder.build_decoder_only_model(
310
+ checkpoint_path=checkpoint_path,
311
+ config=get_model_config_2b(**kwargs),
312
+ tensor_names=ALT_TENSOR_NAMES,
313
+ model_class=Gemma2,
314
+ )
@@ -16,11 +16,17 @@ from ai_edge_torch import fx_infra
16
16
  from ai_edge_torch import lowertools
17
17
  import torch
18
18
 
19
+ fx_infra.decomp.remove_pre_convert_decomp(torch.ops.aten.zeros.default)
20
+ fx_infra.decomp.remove_pre_convert_decomp(torch.ops.aten.zeros_like.default)
21
+
19
22
 
20
23
  class RemoveSDPACompositeZeroMaskPass(fx_infra.ExportedProgramPassBase):
21
24
 
22
25
  def is_zero_tensor_node(self, node: torch.fx.Node):
23
- return node.target == torch.ops.aten.zeros.default
26
+ return node.target in (
27
+ torch.ops.aten.zeros.default,
28
+ torch.ops.aten.zeros_like.default,
29
+ )
24
30
 
25
31
  def call(self, exported_program: torch.export.ExportedProgram):
26
32
  graph = exported_program.graph_module.graph
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250125"
16
+ __version__ = "0.3.0.dev20250127"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250125
3
+ Version: 0.3.0.dev20250127
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=yuz53SwRvngiQ41D-VX7MPmVGe-Vi-UR3v12E-o3P4I,706
5
+ ai_edge_torch/version.py,sha256=t2Ud6LcvJZuaAhzUIBhYrINw1PrvLlnIB2kspVz2TXQ,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
57
57
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
58
58
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
59
59
  ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
60
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=CMkkTd_vO_Ej1SnmXIB0xqjRoArELOkyJ9uqjilpQeI,10298
60
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=Zqd4l7KnfDqK-0fkGKx3prtiTyRIdxSeI187Dg-bNU4,11350
61
61
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
62
62
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
63
63
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -126,7 +126,7 @@ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pg
126
126
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
127
127
  ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
128
128
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
129
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=n5TbXdhBZi8jQe4j7-rox_MugMVvW8ReOhkTA3pfQkw,1919
129
+ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
130
130
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
131
131
  ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
132
132
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
@@ -222,8 +222,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
222
222
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
223
223
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
224
224
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
225
- ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
226
- ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/METADATA,sha256=BkUH2iAinJYGmBLTMdeYSpihXAHY_mBOkeprZLPaDGk,1966
227
- ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
228
- ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
229
- ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/RECORD,,
225
+ ai_edge_torch_nightly-0.3.0.dev20250127.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
226
+ ai_edge_torch_nightly-0.3.0.dev20250127.dist-info/METADATA,sha256=YKdvtJfZZLT0aGrGYkzmZezBNXpk47o2JIzuyWCdOF8,1966
227
+ ai_edge_torch_nightly-0.3.0.dev20250127.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
228
+ ai_edge_torch_nightly-0.3.0.dev20250127.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
229
+ ai_edge_torch_nightly-0.3.0.dev20250127.dist-info/RECORD,,