ai-edge-torch-nightly 0.3.0.dev20250116__py3-none-any.whl → 0.3.0.dev20250117__py3-none-any.whl
- ai_edge_torch/_convert/conversion.py +11 -6
- ai_edge_torch/_convert/signature.py +13 -3
- ai_edge_torch/generative/examples/gemma/gemma2.py +10 -16
- ai_edge_torch/generative/examples/paligemma/decoder2.py +0 -1
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +3 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +2 -2
- ai_edge_torch/lowertools/common_utils.py +2 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250116.dist-info → ai_edge_torch_nightly-0.3.0.dev20250117.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250116.dist-info → ai_edge_torch_nightly-0.3.0.dev20250117.dist-info}/RECORD +13 -13
- {ai_edge_torch_nightly-0.3.0.dev20250116.dist-info → ai_edge_torch_nightly-0.3.0.dev20250117.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250116.dist-info → ai_edge_torch_nightly-0.3.0.dev20250117.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250116.dist-info → ai_edge_torch_nightly-0.3.0.dev20250117.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py CHANGED
@@ -109,21 +109,21 @@ def convert_signatures(
 
   _warn_training_modules(signatures)
 
-  def export(
+  def export(**kwargs):
     nonlocal strict_export
     if strict_export == "auto":
       try:
-        exported_program = torch.export.export(
+        exported_program = torch.export.export(**kwargs, strict=True)
       except Exception:
         logging.warning(
             "torch.export.export(..., strict=True) failed. Retrying with"
             " strict=False"
         )
-        exported_program = torch.export.export(
+        exported_program = torch.export.export(**kwargs, strict=False)
     elif not strict_export:
-      exported_program = torch.export.export(
+      exported_program = torch.export.export(**kwargs, strict=False)
     else:
-      exported_program = torch.export.export(
+      exported_program = torch.export.export(**kwargs, strict=True)
 
   if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
     # Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
@@ -136,7 +136,12 @@ def convert_signatures(
     return exported_program
 
   exported_programs: torch.export.ExportedProgram = [
-      export(
+      export(
+          mod=sig.module,
+          args=sig.args,
+          kwargs=sig.kwargs,
+          dynamic_shapes=sig.dynamic_shapes,
+      )
       for sig in signatures
   ]
 
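For reference, a minimal standalone sketch of the strict-export fallback that the new `export(**kwargs)` helper implements: try `torch.export.export(..., strict=True)` first and retry with `strict=False` on failure. The toy `Add` module, the tensor shapes, and the `export_with_fallback` name are illustrative assumptions, not part of the wheel.

```python
import logging

import torch


class Add(torch.nn.Module):
  """Toy module used only to exercise the export path."""

  def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y


def export_with_fallback(**kwargs) -> torch.export.ExportedProgram:
  """Tries strict export first, then falls back to non-strict export."""
  try:
    return torch.export.export(**kwargs, strict=True)
  except Exception:
    logging.warning(
        "torch.export.export(..., strict=True) failed. Retrying with"
        " strict=False"
    )
    return torch.export.export(**kwargs, strict=False)


ep = export_with_fallback(
    mod=Add(),
    args=(torch.randn(2), torch.randn(2)),
    kwargs={},
    dynamic_shapes=None,
)
print(type(ep))  # torch.export.ExportedProgram
```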
ai_edge_torch/_convert/signature.py CHANGED
@@ -25,9 +25,9 @@ import torch.utils._pytree as pytree
 class Signature:
   name: str
   module: torch.nn.Module
-  sample_args: tuple[torch.Tensor]
+  sample_args: tuple[torch.Tensor, ...]
   sample_kwargs: dict[str, torch.Tensor]
-  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
+  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...]]] = None
 
   @property
   def _normalized_sample_args_kwargs(self):
@@ -61,6 +61,16 @@ class Signature:
     return names
 
   @property
-  def flat_args(self) -> tuple[Any]:
+  def flat_args(self) -> tuple[Any, ...]:
     args, kwargs = self._normalized_sample_args_kwargs
     return tuple([*args, *kwargs.values()])
+
+  @property
+  def args(self) -> tuple[Any, ...]:
+    args, _ = self._normalized_sample_args_kwargs
+    return args
+
+  @property
+  def kwargs(self) -> dict[str, Any]:
+    _, kwargs = self._normalized_sample_args_kwargs
+    return kwargs
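A hedged sketch of how the new `Signature.args` / `Signature.kwargs` properties plug into the conversion call shown above. `SimpleSignature` below is a simplified stand-in that skips the pytree normalization done by the real `Signature`; the field names mirror the diff, everything else is illustrative.

```python
import dataclasses
from typing import Any, Optional, Union

import torch


@dataclasses.dataclass
class SimpleSignature:
  """Simplified stand-in for ai_edge_torch._convert.signature.Signature."""

  name: str
  module: torch.nn.Module
  sample_args: tuple[torch.Tensor, ...]
  sample_kwargs: dict[str, torch.Tensor]
  dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...]]] = None

  @property
  def args(self) -> tuple[Any, ...]:
    return self.sample_args

  @property
  def kwargs(self) -> dict[str, Any]:
    return self.sample_kwargs


sig = SimpleSignature(
    name="identity",
    module=torch.nn.Identity(),
    sample_args=(torch.randn(2),),
    sample_kwargs={},
)
# The convert path can now hand these straight to torch.export.export.
ep = torch.export.export(
    mod=sig.module,
    args=sig.args,
    kwargs=sig.kwargs,
    dynamic_shapes=sig.dynamic_shapes,
    strict=False,
)
```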
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -144,12 +144,13 @@ class Gemma2(nn.Module):
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(input_pos, n_elem, attn_config.rotary_base)
-    mask = [
-        self.get_attention_mask(
-            self.config.block_config(i).attn_config.attn_type, input_pos
-        )
-        for i in range(self.config.num_layers)
-    ]
+    if mask is None:
+      mask = [
+          self.get_attention_mask(
+              self.config.block_config(i).attn_config.attn_type, input_pos
+          )
+          for i in range(self.config.num_layers)
+      ]
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
@@ -159,7 +160,7 @@ class Gemma2(nn.Module):
       self,
       input_embeds: torch.Tensor,
       rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
+      mask: torch.Tensor | List[torch.Tensor],
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       export_config: Optional[model_builder.ExportConfig] = None,
@@ -174,17 +175,10 @@ class Gemma2(nn.Module):
     input_embeds = input_embeds * self.config.embedding_scale
     x = input_embeds
     updated_kv_entries = []
-    mask_input = mask is not None
     for i, block in enumerate(self.transformer_blocks):
-      mask_entry = (
-          mask
-          if mask_input
-          else self.get_attention_mask(
-              block.config.attn_config.attn_type, input_pos
-          )
-      )
+      mask_entry = mask[i] if isinstance(mask, list) else mask
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope,
+      x, kv_entry = block(x, rope, mask_entry, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
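A short sketch of the per-block mask selection this Gemma2 change introduces: `_forward_with_embeds` now accepts either one shared mask tensor or a per-layer list, and indexes into it only when a list is given. The loop below mirrors that logic with dummy tensors; `num_layers` and the mask shapes are illustrative.

```python
from typing import List, Union

import torch

num_layers = 4
seq_len, kv_len = 8, 16

# Either one mask shared by every transformer block...
shared_mask = torch.zeros(seq_len, kv_len)
# ...or one mask per block (e.g. alternating local/global attention types).
per_layer_masks: List[torch.Tensor] = [
    torch.zeros(seq_len, kv_len) for _ in range(num_layers)
]


def pick_mask(
    mask: Union[torch.Tensor, List[torch.Tensor]], i: int
) -> torch.Tensor:
  # Mirrors `mask_entry = mask[i] if isinstance(mask, list) else mask`.
  return mask[i] if isinstance(mask, list) else mask


for i in range(num_layers):
  assert pick_mask(per_layer_masks, i).shape == (seq_len, kv_len)
  assert pick_mask(shared_mask, i) is shared_mask
```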
ai_edge_torch/generative/examples/paligemma/decoder2.py CHANGED
@@ -86,7 +86,6 @@ class Decoder2(gemma2.Gemma2):
     embeds_len = input_embeds.shape[1]
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
     mask[:, embeds_len:] = float("-inf")
-    mask = [mask] * self.config.num_layers
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
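Because the base class now broadcasts a single tensor across blocks, the decoder's prefill mask no longer needs to be replicated per layer. A minimal sketch of the mask that remains, with `embeds_len` and `kv_cache_max` as illustrative values:

```python
import torch

# Illustrative sizes; kv_cache_max comes from the model config in the real code.
embeds_len, kv_cache_max = 5, 32

# Allow attention to the first `embeds_len` cache slots and block the rest.
mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = float("-inf")
print(mask.shape)  # torch.Size([5, 32])
```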
ai_edge_torch/generative/examples/paligemma/image_encoder.py CHANGED
@@ -135,7 +135,9 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
+      # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
+      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
+      # enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -232,7 +232,7 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
-  def
+  def test_paligemma1(self):
     self._test_paligemma_model(
         decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
     )
@@ -241,7 +241,7 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
-  def
+  def test_paligemma2(self):
     self._test_paligemma_model(
         decoder2.Decoder2,
         decoder2.get_fake_decoder2_config,
ai_edge_torch/lowertools/common_utils.py CHANGED
@@ -95,9 +95,9 @@ def _get_states(
     signatures: list[signature_module.Signature],
 ):
   for exported_program, signature in zip(exported_programs, signatures):
-    args,
+    args, kwargs = exported_program.example_inputs
     # Calling this to get **all** the state including model buffers.
-    _flat_input_args = exported_program._graph_module_flat_inputs(args,
+    _flat_input_args = exported_program._graph_module_flat_inputs(args, kwargs)
     for tensor, input_spec in zip(
         _flat_input_args, exported_program.graph_signature.input_specs
     ):
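For context, `torch.export.ExportedProgram.example_inputs` holds the `(args, kwargs)` pair captured at export time, which is what the fixed unpacking above relies on. A small check with a toy module (the `Scale` name and shapes are illustrative):

```python
import torch


class Scale(torch.nn.Module):
  """Toy module with one positional and one keyword input."""

  def forward(self, x: torch.Tensor, factor: torch.Tensor) -> torch.Tensor:
    return x * factor


ep = torch.export.export(
    Scale(),
    args=(torch.randn(3),),
    kwargs={"factor": torch.tensor(2.0)},
    strict=False,
)
args, kwargs = ep.example_inputs  # positional args tuple, keyword args dict
assert isinstance(args, tuple) and isinstance(kwargs, dict)
```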
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.3.0.dev20250116.dist-info → ai_edge_torch_nightly-0.3.0.dev20250117.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250116
+Version: 0.3.0.dev20250117
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250116.dist-info → ai_edge_torch_nightly-0.3.0.dev20250117.dist-info}/RECORD CHANGED
@@ -3,12 +3,12 @@ ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=xfIDXOyS2Ghdmd-YYXjVjsHuMh4G95I_J1Du3sMIue4,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=
+ai_edge_torch/_convert/conversion.py,sha256=pSDY0CzZQP_jAMjSfQ1O7Ud_AF5ZDeDF-nE3nAu_hoo,5815
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
-ai_edge_torch/_convert/signature.py,sha256
+ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
@@ -47,7 +47,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=CMkkTd_vO_Ej1SnmXIB0xqjRoArELOkyJ9uqjilpQeI,10298
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -65,8 +65,8 @@ ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3Tif
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=scLsguzzuHfKYDWUd2uZkKYVRzdAbQHLd-kPam8QwvM,3004
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=g0Fbtf9WigOzQij7W1ksUca4eZTwVdCO2RcuFO2GD3M,5439
-ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=
-ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=craPUFxlBniBz9a0Jc7VjK01jROMg5a47xJiEA1brnw,6430
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=u4hEHjvLaMu-UnRrISOFXKMEJIMSLa9CfpjjmSIrlSY,5731
 ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CEMG9gh51ev1KXPew927a6nfampiXX9bL6m-25tNYN8,6340
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
@@ -145,7 +145,7 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=yzMgXkiZxHUF_xz0UR3kD3x74ELsmJetbQnmv7-9gyQ,12473
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -166,7 +166,7 @@ ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6
 ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=-5UqJyk__1YbUNGuxi4b2sn0CED0W-G337AXwxPGdEs,5567
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
-ai_edge_torch/lowertools/common_utils.py,sha256=
+ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
 ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
 ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
 ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
@@ -206,8 +206,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/METADATA,sha256=5XJm1sJgKeIZBNGARZY0DOmuJB04moEM7GsVarmGwwU,1966
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/RECORD,,