ai-edge-torch-nightly 0.3.0.dev20250116__py3-none-any.whl → 0.3.0.dev20250117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_edge_torch/_convert/conversion.py CHANGED
@@ -109,21 +109,21 @@ def convert_signatures(
 
   _warn_training_modules(signatures)
 
-  def export(*args, **kwargs):
+  def export(**kwargs):
     nonlocal strict_export
     if strict_export == "auto":
       try:
-        exported_program = torch.export.export(*args, **kwargs, strict=True)
+        exported_program = torch.export.export(**kwargs, strict=True)
       except Exception:
         logging.warning(
             "torch.export.export(..., strict=True) failed. Retrying with"
             " strict=False"
         )
-        exported_program = torch.export.export(*args, **kwargs, strict=False)
+        exported_program = torch.export.export(**kwargs, strict=False)
     elif not strict_export:
-      exported_program = torch.export.export(*args, **kwargs, strict=False)
+      exported_program = torch.export.export(**kwargs, strict=False)
     else:
-      exported_program = torch.export.export(*args, **kwargs, strict=True)
+      exported_program = torch.export.export(**kwargs, strict=True)
 
     if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
       # Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
@@ -136,7 +136,12 @@ def convert_signatures(
     return exported_program
 
   exported_programs: torch.export.ExportedProgram = [
-      export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
+      export(
+          mod=sig.module,
+          args=sig.args,
+          kwargs=sig.kwargs,
+          dynamic_shapes=sig.dynamic_shapes,
+      )
       for sig in signatures
   ]
 
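Note: the refactor above makes the local export helper keyword-only, and each signature is now forwarded to torch.export.export by name. A minimal sketch of the resulting call, assuming a Signature instance sig with sample inputs attached (illustrative, not the package's exact code path):

    import torch

    exported_program = torch.export.export(
        mod=sig.module,                     # nn.Module to trace
        args=sig.args,                      # normalized positional sample inputs (tuple)
        kwargs=sig.kwargs,                  # normalized keyword sample inputs (dict)
        dynamic_shapes=sig.dynamic_shapes,
        strict=True,  # retried with strict=False when strict_export == "auto" and tracing fails
    )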
ai_edge_torch/_convert/signature.py CHANGED
@@ -25,9 +25,9 @@ import torch.utils._pytree as pytree
 class Signature:
   name: str
   module: torch.nn.Module
-  sample_args: tuple[torch.Tensor]
+  sample_args: tuple[torch.Tensor, ...]
   sample_kwargs: dict[str, torch.Tensor]
-  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
+  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...]]] = None
 
   @property
   def _normalized_sample_args_kwargs(self):
@@ -61,6 +61,16 @@ class Signature:
     return names
 
   @property
-  def flat_args(self) -> tuple[Any]:
+  def flat_args(self) -> tuple[Any, ...]:
     args, kwargs = self._normalized_sample_args_kwargs
     return tuple([*args, *kwargs.values()])
+
+  @property
+  def args(self) -> tuple[Any, ...]:
+    args, _ = self._normalized_sample_args_kwargs
+    return args
+
+  @property
+  def kwargs(self) -> dict[str, Any]:
+    _, kwargs = self._normalized_sample_args_kwargs
+    return kwargs
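Note: the new args and kwargs properties expose the two halves of the normalized sample inputs that flat_args flattens into a single tuple. For any Signature instance sig, the following invariant holds (a sketch using only the properties defined above):

    # Positional args come first, then the keyword values in insertion order.
    assert sig.flat_args == tuple([*sig.args, *sig.kwargs.values()])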
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -144,12 +144,13 @@ class Gemma2(nn.Module):
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(input_pos, n_elem, attn_config.rotary_base)
-    mask = [
-        self.get_attention_mask(
-            self.config.block_config(i).attn_config.attn_type, input_pos
-        )
-        for i in range(self.config.num_layers)
-    ]
+    if mask is None:
+      mask = [
+          self.get_attention_mask(
+              self.config.block_config(i).attn_config.attn_type, input_pos
+          )
+          for i in range(self.config.num_layers)
+      ]
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
@@ -159,7 +160,7 @@ class Gemma2(nn.Module):
       self,
       input_embeds: torch.Tensor,
       rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
+      mask: torch.Tensor | List[torch.Tensor],
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       export_config: Optional[model_builder.ExportConfig] = None,
@@ -174,17 +175,10 @@ class Gemma2(nn.Module):
     input_embeds = input_embeds * self.config.embedding_scale
     x = input_embeds
     updated_kv_entries = []
-    mask_input = mask is not None
     for i, block in enumerate(self.transformer_blocks):
-      mask = (
-          mask
-          if mask_input
-          else self.get_attention_mask(
-              block.config.attn_config.attn_type, input_pos
-          )
-      )
+      mask_entry = mask[i] if isinstance(mask, list) else mask
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask_entry, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
ai_edge_torch/generative/examples/paligemma/decoder2.py CHANGED
@@ -86,7 +86,6 @@ class Decoder2(gemma2.Gemma2):
     embeds_len = input_embeds.shape[1]
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
     mask[:, embeds_len:] = float("-inf")
-    mask = [mask] * self.config.num_layers
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
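Note: because Gemma2._forward_with_embeds now accepts either a per-layer list or one shared tensor, Decoder2 no longer replicates its mask num_layers times. A minimal, self-contained sketch of the single 2-D mask it builds above (sizes are illustrative):

    import torch

    embeds_len, kv_cache_max = 8, 128         # illustrative sizes
    mask = torch.zeros(embeds_len, kv_cache_max)
    mask[:, embeds_len:] = float("-inf")      # disallow attention beyond the prefill span
    # One tensor is now shared by every transformer block instead of [mask] * num_layers.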
ai_edge_torch/generative/examples/paligemma/image_encoder.py CHANGED
@@ -135,7 +135,9 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
+      # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
+      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
+      # enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -232,7 +232,7 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
-  def disabled_test_paligemma1(self):
+  def test_paligemma1(self):
     self._test_paligemma_model(
         decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
     )
@@ -241,7 +241,7 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
-  def disabled_test_paligemma2(self):
+  def test_paligemma2(self):
     self._test_paligemma_model(
         decoder2.Decoder2,
         decoder2.get_fake_decoder2_config,
ai_edge_torch/lowertools/common_utils.py CHANGED
@@ -95,9 +95,9 @@ def _get_states(
     signatures: list[signature_module.Signature],
 ):
   for exported_program, signature in zip(exported_programs, signatures):
-    args, _ = exported_program.example_inputs
+    args, kwargs = exported_program.example_inputs
     # Calling this to get **all** the state including model buffers.
-    _flat_input_args = exported_program._graph_module_flat_inputs(args, {})
+    _flat_input_args = exported_program._graph_module_flat_inputs(args, kwargs)
     for tensor, input_spec in zip(
         _flat_input_args, exported_program.graph_signature.input_specs
     ):
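Note: ExportedProgram.example_inputs holds the (args, kwargs) pair the program was exported with, so forwarding both halves keeps keyword sample inputs in the flattened state. A sketch of the data flow, assuming an ExportedProgram ep (the flattening helper is the same private torch API used above):

    args, kwargs = ep.example_inputs
    # Flattens user inputs together with parameters and buffers, per the comment above.
    flat_inputs = ep._graph_module_flat_inputs(args, kwargs)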
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20250116"
+__version__ = "0.3.0.dev20250117"
ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250116
+Version: 0.3.0.dev20250117
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/RECORD CHANGED
@@ -3,12 +3,12 @@ ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=G-QSz8n-nFkUu7vA6xPORcsrqKhus5kptuw7NkK8Au4,706
+ai_edge_torch/version.py,sha256=xfIDXOyS2Ghdmd-YYXjVjsHuMh4G95I_J1Du3sMIue4,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
+ai_edge_torch/_convert/conversion.py,sha256=pSDY0CzZQP_jAMjSfQ1O7Ud_AF5ZDeDF-nE3nAu_hoo,5815
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
-ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
+ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
@@ -47,7 +47,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=F1EG_7qPKFXtSgq5D6FYBASyvDX7N6wSQS0qdRWJzMQ,10392
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=CMkkTd_vO_Ej1SnmXIB0xqjRoArELOkyJ9uqjilpQeI,10298
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -65,8 +65,8 @@ ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3Tif
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=scLsguzzuHfKYDWUd2uZkKYVRzdAbQHLd-kPam8QwvM,3004
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=g0Fbtf9WigOzQij7W1ksUca4eZTwVdCO2RcuFO2GD3M,5439
-ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=Z-vTX8dvW_UeAzCuiPOXtTqJiDDW3V1NYhOlqydpYDw,6477
-ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=craPUFxlBniBz9a0Jc7VjK01jROMg5a47xJiEA1brnw,6430
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=u4hEHjvLaMu-UnRrISOFXKMEJIMSLa9CfpjjmSIrlSY,5731
 ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CEMG9gh51ev1KXPew927a6nfampiXX9bL6m-25tNYN8,6340
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
@@ -145,7 +145,7 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bBcey-aD4L_TwKRrrM81bN2VQoJjPPC84Rv4o3WOc34,12491
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=yzMgXkiZxHUF_xz0UR3kD3x74ELsmJetbQnmv7-9gyQ,12473
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -166,7 +166,7 @@ ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6
 ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=-5UqJyk__1YbUNGuxi4b2sn0CED0W-G337AXwxPGdEs,5567
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
-ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
+ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
 ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
 ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
 ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
@@ -206,8 +206,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20250116.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20250116.dist-info/METADATA,sha256=BBIWyQ45p6eKiiajHicFpac_J0nPnyAeviNewQm5_pc,1966
-ai_edge_torch_nightly-0.3.0.dev20250116.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20250116.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20250116.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/METADATA,sha256=5XJm1sJgKeIZBNGARZY0DOmuJB04moEM7GsVarmGwwU,1966
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250117.dist-info/RECORD,,