ai-edge-torch-nightly 0.5.0.dev20250502__py3-none-any.whl → 0.5.0.dev20250503__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_edge_torch/generative/examples/gemma3/decoder.py CHANGED
@@ -199,7 +199,11 @@ class Decoder(nn.Module):
     sliding_mask = torch.where(
         sliding_mask_bool,
         torch.zeros_like(sliding_mask_bool, dtype=torch.float),
-        torch.full_like(sliding_mask_bool, float("-inf"), dtype=torch.float),
+        torch.full_like(
+            sliding_mask_bool,
+            self.config.get_causal_mask_value(),
+            dtype=torch.float,
+        ),
     )
 
     return sliding_mask
@@ -215,7 +219,7 @@ class Decoder(nn.Module):
       mask = torch.logical_and(mask, pixel_mask)
     else:
       mask = torch.logical_or(mask, pixel_mask)
-    mask = torch.where(mask, 0, float("-inf"))
+    mask = torch.where(mask, 0, self.config.get_causal_mask_value())
     return mask
 
   def build_pixel_mask(self, image_indices: torch.Tensor):
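Taken together, these two hunks swap a hard-coded float("-inf") fill for a value read from the model config. A minimal standalone sketch of the pattern (the tensor contents and the default value here are illustrative, not taken from the package):

import torch

# AttentionConfig.causal_mask_value defaults to float("-inf") in this release.
causal_mask_value = float("-inf")
sliding_mask_bool = torch.tensor([[True, False], [True, True]])
sliding_mask = torch.where(
    sliding_mask_bool,
    torch.zeros_like(sliding_mask_bool, dtype=torch.float),
    torch.full_like(sliding_mask_bool, causal_mask_value, dtype=torch.float),
)
# 0.0 where attention is allowed, the configured value where it is blocked:
# tensor([[0., -inf],
#         [0., 0.]])
print(sliding_mask)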
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py CHANGED
@@ -17,15 +17,10 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.hammer import hammer
-from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities import export_config as export_cfg
-import torch
-
+from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags('hammer')
-ExportConfig = export_cfg.ExportConfig
-
 
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -40,35 +35,6 @@ _BUILDER = {
 }
 
 
-def _create_mask(mask_len, kv_cache_max_len):
-  mask = torch.full(
-      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  return mask
-
-
-def _create_export_config(
-    prefill_seq_lens: list[int], kv_cache_max_len: int
-) -> ExportConfig:
-  """Creates the export config for the model."""
-  export_config = ExportConfig()
-  if isinstance(prefill_seq_lens, list):
-    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
-  else:
-    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
-
-  export_config.prefill_mask = prefill_mask
-
-  decode_mask = torch.full(
-      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  export_config.decode_mask = decode_mask
-  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
-  return export_config
-
-
 def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
@@ -80,11 +46,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=_create_export_config(
-          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
-      )
-      if flags.FLAGS.transpose_kv_cache
-      else ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
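With the per-script helpers gone, mask construction and KV-cache layout are driven entirely by the shared conversion flags. A minimal sketch of the resulting script shape (only the names visible in this diff are real; everything else is assumed):

from absl import app
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities import export_config

flags = converter.define_conversion_flags('hammer')

def main(_):
  # --transpose_kv_cache is honored inside get_from_flags(); mask-as-input
  # handling now lives in the converter itself (see the converter.py hunks).
  cfg = export_config.get_from_flags()
  print(cfg.kvcache_layout)

if __name__ == '__main__':
  app.run(main)  # absl parses the flags before calling main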
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -75,7 +75,7 @@ class Decoder(model_builder.DecoderOnlyModel):
     if mask is None:
       embeds_len = input_embeds.shape[1]
       mask = torch.zeros(embeds_len, self.config.kv_cache_max)
-      mask[:, embeds_len:] = float("-inf")
+      mask[:, embeds_len:] = attn_config.causal_mask_value
 
     return self._forward_with_embeds(
         input_embeds,
ai_edge_torch/generative/examples/paligemma/decoder2.py CHANGED
@@ -75,7 +75,7 @@ class Decoder2(gemma2.Gemma2):
     # By default, don't mask image embeds with a diagonal causal mask.
     embeds_len = input_embeds.shape[1]
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
-    mask[:, embeds_len:] = float("-inf")
+    mask[:, embeds_len:] = attn_config.causal_mask_value
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -116,6 +116,8 @@ class AttentionConfig:
   attn_type: Optional[AttentionType] = None
   # The size of the sliding window used for local attention.
   sliding_window_size: Optional[int] = None
+  # The default causal mask value used by attention layer.
+  causal_mask_value: float = float("-inf")
 
 
 @dataclasses.dataclass
@@ -247,3 +249,7 @@ class ModelConfig:
           f"Index {idx} is out of range for layer configs: {self.block_configs}"
       )
     return self.block_configs[idx]
+
+  @property
+  def get_causal_mask_value(self) -> float:
+    return self.block_config(0).attn_config.causal_mask_value
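A simplified, self-contained stand-in for this lookup (the dataclasses below are illustrative stand-ins, not the real classes; it is written as a plain method to match the get_causal_mask_value() call sites in the converter hunks further down):

import dataclasses

@dataclasses.dataclass
class _AttnConfig:  # stand-in for AttentionConfig
  causal_mask_value: float = float("-inf")

@dataclasses.dataclass
class _BlockConfig:  # stand-in for a per-layer block config
  attn_config: _AttnConfig

@dataclasses.dataclass
class _ModelConfig:  # stand-in for ModelConfig
  block_configs: list

  def get_causal_mask_value(self) -> float:
    # The model-wide mask value is read from the first block's attention config.
    return self.block_configs[0].attn_config.causal_mask_value

cfg = _ModelConfig([_BlockConfig(_AttnConfig())])
assert cfg.get_causal_mask_value() == float("-inf")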
ai_edge_torch/generative/layers/scaled_dot_product_attention.py CHANGED
@@ -160,7 +160,7 @@ def scaled_dot_product_attention_transposed(
   Args:
     query: Query tensor, with shape [B, T, N, H].
     key: Key tensor, with shape [B, T, KV_LEN, H].
-    value: Value tensor, with shape [B, T, KV_LEN, H].
+    value: Value tensor, with shape [B, T, H, KV_LEN].
     head_size (int): head dimension.
     mask (torch.Tensor): the optional mask tensor.
    scale (float): the optional scale factor.
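The corrected shape can be sanity-checked with a rough einsum sketch of transposed attention (dimensions match the new test below; this is a shape check only, not the library's actual kernel):

import torch

B, T, N, H, KV_LEN = 1, 16, 16, 128, 16
query = torch.randn(B, T, N, H)
key = torch.randn(B, T, KV_LEN, H)
value = torch.randn(B, T, H, KV_LEN)  # transposed layout: KV_LEN is last

scores = torch.einsum("btnh,btkh->btnk", query, key) / H**0.5
probs = scores.softmax(dim=-1)
out = torch.einsum("btnk,bthk->btnh", probs, value)
print(out.shape)  # torch.Size([1, 16, 16, 128]), back to [B, T, N, H]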
ai_edge_torch/generative/layers/scaled_dot_product_attention_test.py ADDED
@@ -0,0 +1,87 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from ai_edge_torch import odml_torch
+from ai_edge_torch.generative.layers import scaled_dot_product_attention
+import torch
+
+from absl.testing import absltest as googletest
+
+
+class ScaledDotProductAttentionTest(googletest.TestCase):
+
+  def test_scaled_dot_product_attention(self):
+    query = torch.randn(1, 16, 16, 128, dtype=torch.float32)
+    key = torch.randn(1, 16, 16, 128, dtype=torch.float32)
+    value = torch.randn(1, 16, 16, 128, dtype=torch.float32)
+    mask = torch.ones((1, 1, 1, 16), dtype=torch.float32)
+    output = scaled_dot_product_attention.scaled_dot_product_attention(
+        query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0
+    )
+    self.assertEqual(output.shape, (1, 16, 16, 128))
+
+  def test_scaled_dot_product_attention_transposed(self):
+    query = torch.randn(1, 16, 16, 128, dtype=torch.float32)
+    key = torch.randn(1, 16, 16, 128, dtype=torch.float32)
+    value = torch.randn(1, 16, 128, 16, dtype=torch.float32)
+    mask = torch.ones((1, 1, 1, 16), dtype=torch.float32)
+    output = (
+        scaled_dot_product_attention.scaled_dot_product_attention_transposed(
+            query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0
+        )
+    )
+    self.assertEqual(output.shape, (1, 16, 16, 128))
+
+  def test_scaled_dot_product_attention_with_hlfb(self):
+    query = torch.randn(1, 16, 16, 128, dtype=torch.float32)
+    key = torch.randn(1, 16, 16, 128, dtype=torch.float32)
+    value = torch.randn(1, 16, 16, 128, dtype=torch.float32)
+    mask = torch.ones((1, 1, 1, 16), dtype=torch.float32)
+    output = (
+        scaled_dot_product_attention.scaled_dot_product_attention_with_hlfb(
+            query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0
+        )
+    )
+    self.assertEqual(output.shape, (1, 16, 16, 128))
+
+    def model_to_mlir(model, args):
+      ep = torch.export.export(model, args)
+      mlir = odml_torch.export.exported_program_to_mlir(ep)
+      return mlir.get_text()
+
+    class SDPAModule(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+
+      def forward(self, query, key, value, mask):
+        return (
+            scaled_dot_product_attention.scaled_dot_product_attention_with_hlfb(
+                query,
+                key,
+                value,
+                head_size=128,
+                mask=mask,
+                scale=1.0,
+                softcap=10.0,
+            )
+        )
+
+    ir_text = model_to_mlir(SDPAModule().eval(), (query, key, value, mask))
+    self.assertEqual(ir_text.count("stablehlo.custom_call @mark_tensor"), 5)
+
+
+if __name__ == "__main__":
+  googletest.main()
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -95,6 +95,18 @@ def define_conversion_flags(model_name: str):
   return flags
 
 
+def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
+  if isinstance(mask_len, list):
+    return [
+        _build_mask(i, kv_cache_max_len, causal_mask_value) for i in mask_len
+    ]
+
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), causal_mask_value, dtype=torch.float32
+  )
+  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+
+
 def convert_to_tflite(
     pytorch_model: torch.nn.Module,
     output_path: str,
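To see what the new helper produces, here is a small worked example (sizes chosen for illustration): positions strictly above the diagonal get the causal mask value, and everything else stays 0 (attendable).

import torch

def _build_mask(mask_len, kv_cache_max_len, causal_mask_value):
  # Same logic as the hunk above, reproduced for a standalone demo.
  if isinstance(mask_len, list):
    return [_build_mask(i, kv_cache_max_len, causal_mask_value) for i in mask_len]
  mask = torch.full(
      (mask_len, kv_cache_max_len), causal_mask_value, dtype=torch.float32
  )
  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)

m = _build_mask(3, 5, float("-inf"))
print(m.shape)  # torch.Size([1, 1, 3, 5])
print(m[0, 0])
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf]])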
@@ -229,14 +241,15 @@ def _export_helper(
       torch.arange(0, seq_len + pixel_seq_len, dtype=torch.int)
   )
 
-  if export_config.prefill_mask is None:
-    prefill_masks = None
-  elif isinstance(export_config.prefill_mask, torch.Tensor):
-    prefill_masks = [export_config.prefill_mask]
-  elif isinstance(export_config.prefill_mask, list):
-    prefill_masks = export_config.prefill_mask
-  else:
-    raise ValueError('Prefill masks unrecognized.')
+  prefill_masks = None
+  if flags.FLAGS.mask_as_input:
+    prefill_masks = [
+        _build_mask(
+            flags.FLAGS.prefill_seq_lens,
+            flags.FLAGS.kv_cache_max_len,
+            config.get_causal_mask_value(),
+        )
+    ]
 
   if prefill_masks:
     assert len(prefill_masks) == len(prefill_seq_lens)
@@ -299,8 +312,17 @@ def _export_helper(
       'input_pos': decode_input_pos,
       'kv_cache': decode_kv,
   }
-  if export_config.decode_mask is not None:
-    sample_kwargs['mask'] = export_config.decode_mask
+  if flags.FLAGS.mask_as_input:
+    # Note that the decode mask is not a correct causal mask, but it is okay
+    # for the conversion purpose because only the shape matters in conversion.
+    # A correct causal mask of decode for a given token position of decode, it
+    # should be built like:
+    #
+    #   torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
+    #
+    sample_kwargs['mask'] = _build_mask(
+        1, flags.FLAGS.kv_cache_max_len, config.get_causal_mask_value()
+    )
   if lora is not None:
     sample_kwargs['lora'] = lora
 
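As the in-code comment notes, the decode-time mask only needs the right shape for conversion. A short sketch of both variants (sizes and decode_position are hypothetical):

import torch

kv_cache_max_len = 8
base = torch.full((1, kv_cache_max_len), float("-inf"), dtype=torch.float32)

# Shape-only mask actually fed to the decode signature, per the hunk above:
decode_mask = torch.triu(base, diagonal=1).unsqueeze(0).unsqueeze(0)
print(decode_mask.shape)  # torch.Size([1, 1, 1, 8])

# Position-correct variant from the comment, for a hypothetical token position:
decode_position = 4
correct_mask = torch.triu(base, diagonal=decode_position).unsqueeze(0).unsqueeze(0)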
ai_edge_torch/generative/utilities/export_config.py CHANGED
@@ -33,6 +33,8 @@ class ExportConfig:
   # When False, only decode signatures will produce output.
   output_logits_on_prefill: bool = False
   # Attention masks given as inputs to the model.
+  # Note that `prefill_mask`, `decode_mask`, and `kvcache_cls` are deprecated
+  # and will be removed in a future version.
   prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
   decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
   # The KV Cache layout for K and V buffers in attention.
@@ -43,33 +45,10 @@ class ExportConfig:
   decode_batch_size: int = 1
 
 
-def _build_mask(mask_len, kv_cache_max_len) -> torch.Tensor:
-  if isinstance(mask_len, list):
-    return [_build_mask(i, kv_cache_max_len) for i in mask_len]
-
-  mask = torch.full(
-      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-
-
 def get_from_flags() -> ExportConfig:
   """Builds an export config according to the commandline flags."""
   export_config = ExportConfig()
 
-  if flags.FLAGS.mask_as_input:
-    export_config.prefill_mask = _build_mask(
-        flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
-    )
-    # Note that the decode mask is not a correct causal mask, but it is okay
-    # for the conversion purpose because only the shape matters in conversion.
-    # A correct causal mask of decode for a given token position of decode, it
-    # should be built like:
-    #
-    #   torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
-    #
-    export_config.decode_mask = _build_mask(1, flags.FLAGS.kv_cache_max_len)
-
   if flags.FLAGS.transpose_kv_cache:
     export_config.kvcache_layout = kv_utils.KV_LAYOUT_TRANSPOSED
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.5.0.dev20250502"
+__version__ = "0.5.0.dev20250503"
ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250502
+Version: 0.5.0.dev20250503
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
-ai_edge_torch/version.py,sha256=xg9p8FPpu9FCdvOFGsZZCz9bNlzIBYx6sFmrLF_JO7s,706
+ai_edge_torch/version.py,sha256=rjhPV_Qh8FDlHQTy8wAJvuXSNGcntZerhf-8FTEjuWI,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -66,13 +66,13 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEa
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=MjkQDVynaw9C5z9ODzKfb85xW5JfxHUWBJ_Aco05FHo,1760
-ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
+ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=fzLpuJO5JseQLA38Li-i9Xdnh9I4zdBWQEOeNbUEfjI,15737
 ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
 ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
 ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=KnE9ME3mrpQkAxFlBOJLsqcQkjsdDL1ClNhJahX5K5I,8960
 ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=946mchDmvUhMsv1kzslp4LHtCIuHn4qjimHYQ-XnxMo,2962
+ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=9r8LXyaoBXYIIhhe1WQgEIjaxALQPE1dO2N6qopyWCk,1753
 ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
 ai_edge_torch/generative/examples/hammer/verify.py,sha256=MkzAGkbPy4LKRhyCDm1cw-9jUt4VUxLPdwK_25fCGSE,2705
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -88,8 +88,8 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=7HHXkC-IIu7ieBvBI4RlXs_oITz7R8a6YVYQskAs_Uk,2023
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=G1dwtWp_v77AI3uyIY-8g6qRP2tRH3CIKjJTeYNqFPU,5511
-ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=Z-SKdb0dd8uWT1d-FRwFx5-tJEqpdrQwiIZnFRhOtVo,6060
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=-EYUZp55dfRY1E-N0Pr3b9i5c7Tt1XvYxvsRixguVS8,5527
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=WB8r-e_Crog1ItBq3Zse_nUG-foFyBcJsuEG26r_Ji8,6076
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
 ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CFIjOmrn4a4Udki7l3im0JR4zTC_NttnsIr9_qWjKTY,6110
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=zrCNz_QSQU6BbaFtx-J-MqxXWcNlsAlquaHpKodsyW4,5350
@@ -162,10 +162,11 @@ ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-m
 ai_edge_torch/generative/layers/feed_forward_test.py,sha256=8ZGy79BBpsyS6yKKDEKrDt249G5Mz-8VKWW7_WHx0u4,1655
 ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
+ai_edge_torch/generative/layers/model_config.py,sha256=dRZUMa71ADaEllu7TfXUWTMHRCcMgvkFMYMzmeJi4G8,8576
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
-ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=efqqGRZPJ55hKn1MQJ-cXfrJD85uS1v7W_juyGyts58,5648
+ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=2_AgwENsaOgaxgiSqgoj0V0JzQ09dFtP_nBhX-lJK2g,5648
+ai_edge_torch/generative/layers/scaled_dot_product_attention_test.py,sha256=c6JBMQsq9XeMmR1XvGEIidNsoh-YIvichXo2LwVHgr4,3301
 ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=iw7D_46CFe9iRvU0UumbkIoqWQEhDroxm9ABcK-CLlM,3600
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
@@ -188,8 +189,8 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=8A1MvU8SbJQkn2SIhF-73TXbI_i6nrloCdkpw83P2xQ,10953
-ai_edge_torch/generative/utilities/export_config.py,sha256=yGkfdN8Qrp8b_K8e5H0qaYmDrg0Dx_eb75JLhOnlygQ,2827
+ai_edge_torch/generative/utilities/converter.py,sha256=K1gZWPq5f3Z7f9USeJ_PphctO1dyYTNrWSJQ-cztgKA,11658
+ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
 ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
 ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
@@ -247,8 +248,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.dev20250502.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.5.0.dev20250502.dist-info/METADATA,sha256=31xPCLYZskB9H70ijwBN7fjINXcq9ljgnlAjieGcUDM,2051
-ai_edge_torch_nightly-0.5.0.dev20250502.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.5.0.dev20250502.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.5.0.dev20250502.dist-info/RECORD,,
+ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/METADATA,sha256=RU3caJRTJFodq-s8HxE5j7uo74dScWYcYMMAtqJVsD4,2051
+ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/RECORD,,