ai-edge-torch-nightly 0.3.0.dev20250116__py3-none-any.whl → 0.3.0.dev20250117__py3-none-any.whl

ai_edge_torch/_convert/conversion.py CHANGED
@@ -109,21 +109,21 @@ def convert_signatures(
 
   _warn_training_modules(signatures)
 
-  def export(*args, **kwargs):
+  def export(**kwargs):
     nonlocal strict_export
     if strict_export == "auto":
       try:
-        exported_program = torch.export.export(*args, **kwargs, strict=True)
+        exported_program = torch.export.export(**kwargs, strict=True)
       except Exception:
         logging.warning(
            "torch.export.export(..., strict=True) failed. Retrying with"
            " strict=False"
        )
-        exported_program = torch.export.export(*args, **kwargs, strict=False)
+        exported_program = torch.export.export(**kwargs, strict=False)
     elif not strict_export:
-      exported_program = torch.export.export(*args, **kwargs, strict=False)
+      exported_program = torch.export.export(**kwargs, strict=False)
     else:
-      exported_program = torch.export.export(*args, **kwargs, strict=True)
+      exported_program = torch.export.export(**kwargs, strict=True)
 
     if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
       # Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
@@ -136,7 +136,12 @@ def convert_signatures(
     return exported_program
 
   exported_programs: torch.export.ExportedProgram = [
-      export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
+      export(
+          mod=sig.module,
+          args=sig.args,
+          kwargs=sig.kwargs,
+          dynamic_shapes=sig.dynamic_shapes,
+      )
       for sig in signatures
   ]
 
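
Illustrative note (not part of the wheel diff): the wrapper above becomes keyword-only so that module, args, kwargs and dynamic_shapes are all forwarded to torch.export.export by name. Below is a minimal standalone sketch of the same "auto" strict-export fallback, assuming a recent torch with torch.export available; export_with_fallback and TinyAdd are hypothetical names, and the decomposition handling of the real convert_signatures is omitted.

import logging

import torch


def export_with_fallback(strict_export="auto", **kwargs):
  # Try strict export first; fall back to non-strict if tracing fails.
  if strict_export == "auto":
    try:
      return torch.export.export(**kwargs, strict=True)
    except Exception:
      logging.warning("strict=True export failed; retrying with strict=False")
      return torch.export.export(**kwargs, strict=False)
  return torch.export.export(**kwargs, strict=bool(strict_export))


class TinyAdd(torch.nn.Module):

  def forward(self, x, bias):
    return x + bias


# Keyword-only call site, mirroring export(mod=..., args=..., kwargs=...).
ep = export_with_fallback(
    mod=TinyAdd(),
    args=(torch.ones(2, 2),),
    kwargs={"bias": torch.zeros(2, 2)},
)
print(type(ep).__name__)  # ExportedProgram
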
ai_edge_torch/_convert/signature.py CHANGED
@@ -25,9 +25,9 @@ import torch.utils._pytree as pytree
 class Signature:
   name: str
   module: torch.nn.Module
-  sample_args: tuple[torch.Tensor]
+  sample_args: tuple[torch.Tensor, ...]
   sample_kwargs: dict[str, torch.Tensor]
-  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
+  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...]]] = None
 
   @property
   def _normalized_sample_args_kwargs(self):
@@ -61,6 +61,16 @@ class Signature:
     return names
 
   @property
-  def flat_args(self) -> tuple[Any]:
+  def flat_args(self) -> tuple[Any, ...]:
     args, kwargs = self._normalized_sample_args_kwargs
     return tuple([*args, *kwargs.values()])
+
+  @property
+  def args(self) -> tuple[Any, ...]:
+    args, _ = self._normalized_sample_args_kwargs
+    return args
+
+  @property
+  def kwargs(self) -> dict[str, Any]:
+    _, kwargs = self._normalized_sample_args_kwargs
+    return kwargs
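
Illustrative note (not part of the wheel diff): the new args/kwargs properties let the converter hand the normalized sample inputs to torch.export.export separately instead of flattening them into positional arguments. A toy sketch of that split follows, assuming the same property layout; SimpleSignature and its inline normalization are stand-ins, not the package's pytree-based Signature class.

import dataclasses
from typing import Any, Optional

import torch


@dataclasses.dataclass
class SimpleSignature:
  """Toy stand-in for ai_edge_torch._convert.signature.Signature."""

  name: str
  module: torch.nn.Module
  sample_args: tuple[torch.Tensor, ...]
  sample_kwargs: Optional[dict[str, torch.Tensor]] = None

  @property
  def _normalized_sample_args_kwargs(self):
    # The real class canonicalizes args/kwargs via pytree; this stand-in only
    # replaces missing kwargs with an empty dict.
    return tuple(self.sample_args), dict(self.sample_kwargs or {})

  @property
  def flat_args(self) -> tuple[Any, ...]:
    args, kwargs = self._normalized_sample_args_kwargs
    return tuple([*args, *kwargs.values()])

  @property
  def args(self) -> tuple[Any, ...]:
    args, _ = self._normalized_sample_args_kwargs
    return args

  @property
  def kwargs(self) -> dict[str, Any]:
    _, kwargs = self._normalized_sample_args_kwargs
    return kwargs


sig = SimpleSignature(
    name="main",
    module=torch.nn.Identity(),
    sample_args=(torch.ones(1, 4),),
    sample_kwargs={"mask": torch.zeros(1, 4)},
)
print(len(sig.flat_args), len(sig.args), list(sig.kwargs))  # 2 1 ['mask']
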
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -144,12 +144,13 @@ class Gemma2(nn.Module):
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(input_pos, n_elem, attn_config.rotary_base)
-    mask = [
-        self.get_attention_mask(
-            self.config.block_config(i).attn_config.attn_type, input_pos
-        )
-        for i in range(self.config.num_layers)
-    ]
+    if mask is None:
+      mask = [
+          self.get_attention_mask(
+              self.config.block_config(i).attn_config.attn_type, input_pos
+          )
+          for i in range(self.config.num_layers)
+      ]
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
@@ -159,7 +160,7 @@ class Gemma2(nn.Module):
       self,
       input_embeds: torch.Tensor,
       rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
+      mask: torch.Tensor | List[torch.Tensor],
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       export_config: Optional[model_builder.ExportConfig] = None,
@@ -174,17 +175,10 @@ class Gemma2(nn.Module):
       input_embeds = input_embeds * self.config.embedding_scale
     x = input_embeds
     updated_kv_entries = []
-    mask_input = mask is not None
     for i, block in enumerate(self.transformer_blocks):
-      mask = (
-          mask
-          if mask_input
-          else self.get_attention_mask(
-              block.config.attn_config.attn_type, input_pos
-          )
-      )
+      mask_entry = mask[i] if isinstance(mask, list) else mask
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask_entry, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
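
Illustrative note (not part of the wheel diff): after this change the forward pass accepts either one shared mask tensor or a per-layer list and picks the right entry for each block. A tiny sketch of that selection, with a hypothetical helper name:

from typing import List, Union

import torch


def select_block_mask(
    mask: Union[torch.Tensor, List[torch.Tensor]], block_idx: int
) -> torch.Tensor:
  # Mirrors: mask_entry = mask[i] if isinstance(mask, list) else mask
  return mask[block_idx] if isinstance(mask, list) else mask


shared = torch.zeros(1, 1, 4, 4)
per_layer = [torch.zeros(1, 1, 4, 4) for _ in range(3)]
assert select_block_mask(shared, 2) is shared
assert select_block_mask(per_layer, 2) is per_layer[2]
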
ai_edge_torch/generative/examples/paligemma/decoder2.py CHANGED
@@ -86,7 +86,6 @@ class Decoder2(gemma2.Gemma2):
       embeds_len = input_embeds.shape[1]
       mask = torch.zeros(embeds_len, self.config.kv_cache_max)
       mask[:, embeds_len:] = float("-inf")
-      mask = [mask] * self.config.num_layers
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
ai_edge_torch/generative/examples/paligemma/image_encoder.py CHANGED
@@ -135,7 +135,9 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
+      # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
+      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
+      # enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -232,7 +232,7 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
-  def disabled_test_paligemma1(self):
+  def test_paligemma1(self):
     self._test_paligemma_model(
         decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
     )
@@ -241,7 +241,7 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
-  def disabled_test_paligemma2(self):
+  def test_paligemma2(self):
     self._test_paligemma_model(
         decoder2.Decoder2,
         decoder2.get_fake_decoder2_config,
ai_edge_torch/lowertools/common_utils.py CHANGED
@@ -95,9 +95,9 @@ def _get_states(
     signatures: list[signature_module.Signature],
 ):
   for exported_program, signature in zip(exported_programs, signatures):
-    args, _ = exported_program.example_inputs
+    args, kwargs = exported_program.example_inputs
     # Calling this to get **all** the state including model buffers.
-    _flat_input_args = exported_program._graph_module_flat_inputs(args, {})
+    _flat_input_args = exported_program._graph_module_flat_inputs(args, kwargs)
     for tensor, input_spec in zip(
         _flat_input_args, exported_program.graph_signature.input_specs
     ):
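
Illustrative note (not part of the wheel diff): ExportedProgram.example_inputs is an (args, kwargs) pair, so forwarding both instead of an empty dict keeps keyword sample inputs in the flattened state. A small sketch of that shape against plain torch.export (the private _graph_module_flat_inputs call itself is not reproduced here); AddBias is a hypothetical module.

import torch


class AddBias(torch.nn.Module):

  def forward(self, x, bias):
    return x + bias


ep = torch.export.export(
    AddBias(),
    args=(torch.ones(2),),
    kwargs={"bias": torch.full((2,), 3.0)},
)

# example_inputs is (args, kwargs); unpacking only args (the old `args, _ =`)
# would drop the "bias" keyword sample input.
args, kwargs = ep.example_inputs
print(len(args), list(kwargs))  # 1 ['bias']
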
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20250116"
+__version__ = "0.3.0.dev20250117"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250116
+Version: 0.3.0.dev20250117
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,12 +3,12 @@ ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=G-QSz8n-nFkUu7vA6xPORcsrqKhus5kptuw7NkK8Au4,706
+ai_edge_torch/version.py,sha256=xfIDXOyS2Ghdmd-YYXjVjsHuMh4G95I_J1Du3sMIue4,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
+ai_edge_torch/_convert/conversion.py,sha256=pSDY0CzZQP_jAMjSfQ1O7Ud_AF5ZDeDF-nE3nAu_hoo,5815
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
-ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
+ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
@@ -47,7 +47,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=F1EG_7qPKFXtSgq5D6FYBASyvDX7N6wSQS0qdRWJzMQ,10392
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=CMkkTd_vO_Ej1SnmXIB0xqjRoArELOkyJ9uqjilpQeI,10298
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -65,8 +65,8 @@ ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3Tif
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=scLsguzzuHfKYDWUd2uZkKYVRzdAbQHLd-kPam8QwvM,3004
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=g0Fbtf9WigOzQij7W1ksUca4eZTwVdCO2RcuFO2GD3M,5439
-ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=Z-vTX8dvW_UeAzCuiPOXtTqJiDDW3V1NYhOlqydpYDw,6477
-ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=craPUFxlBniBz9a0Jc7VjK01jROMg5a47xJiEA1brnw,6430
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=u4hEHjvLaMu-UnRrISOFXKMEJIMSLa9CfpjjmSIrlSY,5731
 ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CEMG9gh51ev1KXPew927a6nfampiXX9bL6m-25tNYN8,6340
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
@@ -145,7 +145,7 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bBcey-aD4L_TwKRrrM81bN2VQoJjPPC84Rv4o3WOc34,12491
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=yzMgXkiZxHUF_xz0UR3kD3x74ELsmJetbQnmv7-9gyQ,12473
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -166,7 +166,7 @@ ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6
 ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=-5UqJyk__1YbUNGuxi4b2sn0CED0W-G337AXwxPGdEs,5567
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
-ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
+ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
 ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
 ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
 ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
@@ -206,8 +206,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20250116.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20250116.dist-info/METADATA,sha256=BBIWyQ45p6eKiiajHicFpac_J0nPnyAeviNewQm5_pc,1966
-ai_edge_torch_nightly-0.3.0.dev20250116.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20250116.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20250116.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/METADATA,sha256=5XJm1sJgKeIZBNGARZY0DOmuJB04moEM7GsVarmGwwU,1966
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/RECORD,,