ai-edge-torch-nightly 0.3.0.dev20250125__py3-none-any.whl → 0.3.0.dev20250126__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/gemma/gemma2.py +31 -6
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250126.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250126.dist-info}/RECORD +7 -7
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250126.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250126.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250125.dist-info → ai_edge_torch_nightly-0.3.0.dev20250126.dist-info}/top_level.txt +0 -0
@@ -43,6 +43,22 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
|
43
43
|
lm_head=None,
|
44
44
|
)
|
45
45
|
|
46
|
+
ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
47
|
+
ff_up_proj="model.layers.{}.mlp.up_proj",
|
48
|
+
ff_down_proj="model.layers.{}.mlp.down_proj",
|
49
|
+
ff_gate_proj="model.layers.{}.mlp.gate_proj",
|
50
|
+
attn_query_proj="model.layers.{}.self_attn.q_proj",
|
51
|
+
attn_key_proj="model.layers.{}.self_attn.k_proj",
|
52
|
+
attn_value_proj="model.layers.{}.self_attn.v_proj",
|
53
|
+
attn_output_proj="model.layers.{}.self_attn.o_proj",
|
54
|
+
pre_attn_norm="model.layers.{}.input_layernorm",
|
55
|
+
post_attn_norm="model.layers.{}.post_attention_layernorm",
|
56
|
+
pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
|
57
|
+
post_ff_norm="model.layers.{}.post_feedforward_layernorm",
|
58
|
+
embedding="model.embed_tokens",
|
59
|
+
final_norm="model.norm",
|
60
|
+
)
|
61
|
+
|
46
62
|
|
47
63
|
class Gemma2Block(attention.TransformerBlock):
|
48
64
|
|
@@ -281,9 +297,18 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
281
297
|
|
282
298
|
|
283
299
|
def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
300
|
+
try:
|
301
|
+
return model_builder.build_decoder_only_model(
|
302
|
+
checkpoint_path=checkpoint_path,
|
303
|
+
config=get_model_config_2b(**kwargs),
|
304
|
+
tensor_names=TENSOR_NAMES,
|
305
|
+
model_class=Gemma2,
|
306
|
+
)
|
307
|
+
except KeyError as ke:
|
308
|
+
# Also attempt to load with an alternative naming scheme.
|
309
|
+
return model_builder.build_decoder_only_model(
|
310
|
+
checkpoint_path=checkpoint_path,
|
311
|
+
config=get_model_config_2b(**kwargs),
|
312
|
+
tensor_names=ALT_TENSOR_NAMES,
|
313
|
+
model_class=Gemma2,
|
314
|
+
)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20250126
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=AwM7aAWvx4ye9s336KLyhdsEbDQgng3z3xegxFNjSYo,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
|
|
57
57
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
|
58
58
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
|
59
59
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
|
60
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
60
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=Zqd4l7KnfDqK-0fkGKx3prtiTyRIdxSeI187Dg-bNU4,11350
|
61
61
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
62
62
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
|
63
63
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
@@ -222,8 +222,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
222
222
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
223
223
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
224
224
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
225
|
-
ai_edge_torch_nightly-0.3.0.
|
226
|
-
ai_edge_torch_nightly-0.3.0.
|
227
|
-
ai_edge_torch_nightly-0.3.0.
|
228
|
-
ai_edge_torch_nightly-0.3.0.
|
229
|
-
ai_edge_torch_nightly-0.3.0.
|
225
|
+
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
226
|
+
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/METADATA,sha256=ansrwGIKqgtoDIyw39y1VpYIkmVrqMwPqfhRTwp8N5A,1966
|
227
|
+
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
228
|
+
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
229
|
+
ai_edge_torch_nightly-0.3.0.dev20250126.dist-info/RECORD,,
|
File without changes
|
File without changes
|