ai-edge-torch-nightly 0.3.0.dev20250125__py3-none-any.whl → 0.3.0.dev20250126__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -43,6 +43,22 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
43
43
  lm_head=None,
44
44
  )
45
45
 
46
+ ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
47
+ ff_up_proj="model.layers.{}.mlp.up_proj",
48
+ ff_down_proj="model.layers.{}.mlp.down_proj",
49
+ ff_gate_proj="model.layers.{}.mlp.gate_proj",
50
+ attn_query_proj="model.layers.{}.self_attn.q_proj",
51
+ attn_key_proj="model.layers.{}.self_attn.k_proj",
52
+ attn_value_proj="model.layers.{}.self_attn.v_proj",
53
+ attn_output_proj="model.layers.{}.self_attn.o_proj",
54
+ pre_attn_norm="model.layers.{}.input_layernorm",
55
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
56
+ pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
57
+ post_ff_norm="model.layers.{}.post_feedforward_layernorm",
58
+ embedding="model.embed_tokens",
59
+ final_norm="model.norm",
60
+ )
61
+
46
62
 
47
63
  class Gemma2Block(attention.TransformerBlock):
48
64
 
@@ -281,9 +297,18 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
281
297
 
282
298
 
283
299
  def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
284
- return model_builder.build_decoder_only_model(
285
- checkpoint_path=checkpoint_path,
286
- config=get_model_config_2b(**kwargs),
287
- tensor_names=TENSOR_NAMES,
288
- model_class=Gemma2,
289
- )
300
+ try:
301
+ return model_builder.build_decoder_only_model(
302
+ checkpoint_path=checkpoint_path,
303
+ config=get_model_config_2b(**kwargs),
304
+ tensor_names=TENSOR_NAMES,
305
+ model_class=Gemma2,
306
+ )
307
+ except KeyError as ke:
308
+ # Also attempt to load with an alternative naming scheme.
309
+ return model_builder.build_decoder_only_model(
310
+ checkpoint_path=checkpoint_path,
311
+ config=get_model_config_2b(**kwargs),
312
+ tensor_names=ALT_TENSOR_NAMES,
313
+ model_class=Gemma2,
314
+ )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250125"
16
+ __version__ = "0.3.0.dev20250126"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250125
3
+ Version: 0.3.0.dev20250126
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=yuz53SwRvngiQ41D-VX7MPmVGe-Vi-UR3v12E-o3P4I,706
5
+ ai_edge_torch/version.py,sha256=AwM7aAWvx4ye9s336KLyhdsEbDQgng3z3xegxFNjSYo,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
57
57
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
58
58
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
59
59
  ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
60
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=CMkkTd_vO_Ej1SnmXIB0xqjRoArELOkyJ9uqjilpQeI,10298
60
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=Zqd4l7KnfDqK-0fkGKx3prtiTyRIdxSeI187Dg-bNU4,11350
61
61
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
62
62
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
63
63
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -222,8 +222,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
222
222
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
223
223
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
224
224
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
225
- ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
226
- ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/METADATA,sha256=BkUH2iAinJYGmBLTMdeYSpihXAHY_mBOkeprZLPaDGk,1966
227
- ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
228
- ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
229
- ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/RECORD,,
225
+ ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
226
+ ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/METADATA,sha256=ansrwGIKqgtoDIyw39y1VpYIkmVrqMwPqfhRTwp8N5A,1966
227
+ ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
228
+ ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
229
+ ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/RECORD,,