ai-edge-torch-nightly 0.3.0.dev20250125__py3-none-any.whl → 0.3.0.dev20250126__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma2.py +31 -6
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250126.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250126.dist-info}/RECORD +7 -7
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250126.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250126.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250126.dist-info}/top_level.txt +0 -0
@@ -43,6 +43,22 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
|
43
43
|
lm_head=None,
|
44
44
|
)
|
45
45
|
|
46
|
+
ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
47
|
+
ff_up_proj="model.layers.{}.mlp.up_proj",
|
48
|
+
ff_down_proj="model.layers.{}.mlp.down_proj",
|
49
|
+
ff_gate_proj="model.layers.{}.mlp.gate_proj",
|
50
|
+
attn_query_proj="model.layers.{}.self_attn.q_proj",
|
51
|
+
attn_key_proj="model.layers.{}.self_attn.k_proj",
|
52
|
+
attn_value_proj="model.layers.{}.self_attn.v_proj",
|
53
|
+
attn_output_proj="model.layers.{}.self_attn.o_proj",
|
54
|
+
pre_attn_norm="model.layers.{}.input_layernorm",
|
55
|
+
post_attn_norm="model.layers.{}.post_attention_layernorm",
|
56
|
+
pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
|
57
|
+
post_ff_norm="model.layers.{}.post_feedforward_layernorm",
|
58
|
+
embedding="model.embed_tokens",
|
59
|
+
final_norm="model.norm",
|
60
|
+
)
|
61
|
+
|
46
62
|
|
47
63
|
class Gemma2Block(attention.TransformerBlock):
|
48
64
|
|
@@ -281,9 +297,18 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
281
297
|
|
282
298
|
|
283
299
|
def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
300
|
+
try:
|
301
|
+
return model_builder.build_decoder_only_model(
|
302
|
+
checkpoint_path=checkpoint_path,
|
303
|
+
config=get_model_config_2b(**kwargs),
|
304
|
+
tensor_names=TENSOR_NAMES,
|
305
|
+
model_class=Gemma2,
|
306
|
+
)
|
307
|
+
except KeyError as ke:
|
308
|
+
# Also attempt to load with an alternative naming scheme.
|
309
|
+
return model_builder.build_decoder_only_model(
|
310
|
+
checkpoint_path=checkpoint_path,
|
311
|
+
config=get_model_config_2b(**kwargs),
|
312
|
+
tensor_names=ALT_TENSOR_NAMES,
|
313
|
+
model_class=Gemma2,
|
314
|
+
)
|
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20250125
|
3
|
+
Version: 0.3.0.dev20250126
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=AwM7aAWvx4ye9s336KLyhdsEbDQgng3z3xegxFNjSYo,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
|
|
57
57
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
|
58
58
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
|
59
59
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
|
60
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
60
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=Zqd4l7KnfDqK-0fkGKx3prtiTyRIdxSeI187Dg-bNU4,11350
|
61
61
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
62
62
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
|
63
63
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
@@ -222,8 +222,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
222
222
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
223
223
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
224
224
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
225
|
-
ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
226
|
-
ai_edge_torch_nightly-0.3.0.
|
227
|
-
ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
228
|
-
ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
229
|
-
ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/RECORD,,
|
225
|
+
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
226
|
+
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/METADATA,sha256=ansrwGIKqgtoDIyw39y1VpYIkmVrqMwPqfhRTwp8N5A,1966
|
227
|
+
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
228
|
+
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
229
|
+
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/RECORD,,
|
File without changes
|
File without changes
|