ai-edge-torch-nightly 0.5.0.dev20250502__py3-none-any.whl → 0.5.0.dev20250503__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma3/decoder.py +6 -2
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +2 -40
- ai_edge_torch/generative/examples/paligemma/decoder.py +1 -1
- ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -1
- ai_edge_torch/generative/layers/model_config.py +6 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention_test.py +87 -0
- ai_edge_torch/generative/utilities/converter.py +32 -10
- ai_edge_torch/generative/utilities/export_config.py +2 -23
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250502.dist-info → ai_edge_torch_nightly-0.5.0.dev20250503.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250502.dist-info → ai_edge_torch_nightly-0.5.0.dev20250503.dist-info}/RECORD +15 -14
- {ai_edge_torch_nightly-0.5.0.dev20250502.dist-info → ai_edge_torch_nightly-0.5.0.dev20250503.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250502.dist-info → ai_edge_torch_nightly-0.5.0.dev20250503.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250502.dist-info → ai_edge_torch_nightly-0.5.0.dev20250503.dist-info}/top_level.txt +0 -0
@@ -199,7 +199,11 @@ class Decoder(nn.Module):
|
|
199
199
|
sliding_mask = torch.where(
|
200
200
|
sliding_mask_bool,
|
201
201
|
torch.zeros_like(sliding_mask_bool, dtype=torch.float),
|
202
|
-
torch.full_like(
|
202
|
+
torch.full_like(
|
203
|
+
sliding_mask_bool,
|
204
|
+
self.config.get_causal_mask_value(),
|
205
|
+
dtype=torch.float,
|
206
|
+
),
|
203
207
|
)
|
204
208
|
|
205
209
|
return sliding_mask
|
@@ -215,7 +219,7 @@ class Decoder(nn.Module):
|
|
215
219
|
mask = torch.logical_and(mask, pixel_mask)
|
216
220
|
else:
|
217
221
|
mask = torch.logical_or(mask, pixel_mask)
|
218
|
-
mask = torch.where(mask, 0,
|
222
|
+
mask = torch.where(mask, 0, self.config.get_causal_mask_value())
|
219
223
|
return mask
|
220
224
|
|
221
225
|
def build_pixel_mask(self, image_indices: torch.Tensor):
|
@@ -17,15 +17,10 @@
|
|
17
17
|
|
18
18
|
from absl import app
|
19
19
|
from ai_edge_torch.generative.examples.hammer import hammer
|
20
|
-
from ai_edge_torch.generative.layers import kv_cache
|
21
20
|
from ai_edge_torch.generative.utilities import converter
|
22
|
-
from ai_edge_torch.generative.utilities import export_config
|
23
|
-
import torch
|
24
|
-
|
21
|
+
from ai_edge_torch.generative.utilities import export_config
|
25
22
|
|
26
23
|
flags = converter.define_conversion_flags('hammer')
|
27
|
-
ExportConfig = export_cfg.ExportConfig
|
28
|
-
|
29
24
|
|
30
25
|
_MODEL_SIZE = flags.DEFINE_enum(
|
31
26
|
'model_size',
|
@@ -40,35 +35,6 @@ _BUILDER = {
|
|
40
35
|
}
|
41
36
|
|
42
37
|
|
43
|
-
def _create_mask(mask_len, kv_cache_max_len):
|
44
|
-
mask = torch.full(
|
45
|
-
(mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
|
46
|
-
)
|
47
|
-
mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
|
48
|
-
return mask
|
49
|
-
|
50
|
-
|
51
|
-
def _create_export_config(
|
52
|
-
prefill_seq_lens: list[int], kv_cache_max_len: int
|
53
|
-
) -> ExportConfig:
|
54
|
-
"""Creates the export config for the model."""
|
55
|
-
export_config = ExportConfig()
|
56
|
-
if isinstance(prefill_seq_lens, list):
|
57
|
-
prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
|
58
|
-
else:
|
59
|
-
prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
|
60
|
-
|
61
|
-
export_config.prefill_mask = prefill_mask
|
62
|
-
|
63
|
-
decode_mask = torch.full(
|
64
|
-
(1, kv_cache_max_len), float('-inf'), dtype=torch.float32
|
65
|
-
)
|
66
|
-
decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
|
67
|
-
export_config.decode_mask = decode_mask
|
68
|
-
export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
|
69
|
-
return export_config
|
70
|
-
|
71
|
-
|
72
38
|
def main(_):
|
73
39
|
pytorch_model = _BUILDER[_MODEL_SIZE.value](
|
74
40
|
flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
|
@@ -80,11 +46,7 @@ def main(_):
|
|
80
46
|
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
81
47
|
quantize=flags.FLAGS.quantize,
|
82
48
|
lora_ranks=flags.FLAGS.lora_ranks,
|
83
|
-
export_config=
|
84
|
-
flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
|
85
|
-
)
|
86
|
-
if flags.FLAGS.transpose_kv_cache
|
87
|
-
else ExportConfig(),
|
49
|
+
export_config=export_config.get_from_flags(),
|
88
50
|
)
|
89
51
|
|
90
52
|
|
@@ -75,7 +75,7 @@ class Decoder(model_builder.DecoderOnlyModel):
|
|
75
75
|
if mask is None:
|
76
76
|
embeds_len = input_embeds.shape[1]
|
77
77
|
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
|
78
|
-
mask[:, embeds_len:] =
|
78
|
+
mask[:, embeds_len:] = attn_config.causal_mask_value
|
79
79
|
|
80
80
|
return self._forward_with_embeds(
|
81
81
|
input_embeds,
|
@@ -75,7 +75,7 @@ class Decoder2(gemma2.Gemma2):
|
|
75
75
|
# By default, don't mask image embeds with a diagonal causal mask.
|
76
76
|
embeds_len = input_embeds.shape[1]
|
77
77
|
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
|
78
|
-
mask[:, embeds_len:] =
|
78
|
+
mask[:, embeds_len:] = attn_config.causal_mask_value
|
79
79
|
|
80
80
|
return self._forward_with_embeds(
|
81
81
|
input_embeds, rope, mask, input_pos, kv_cache, export_config
|
@@ -116,6 +116,8 @@ class AttentionConfig:
|
|
116
116
|
attn_type: Optional[AttentionType] = None
|
117
117
|
# The size of the sliding window used for local attention.
|
118
118
|
sliding_window_size: Optional[int] = None
|
119
|
+
# The default causal mask value used by attention layer.
|
120
|
+
causal_mask_value: float = float("-inf")
|
119
121
|
|
120
122
|
|
121
123
|
@dataclasses.dataclass
|
@@ -247,3 +249,7 @@ class ModelConfig:
|
|
247
249
|
f"Index {idx} is out of range for layer configs: {self.block_configs}"
|
248
250
|
)
|
249
251
|
return self.block_configs[idx]
|
252
|
+
|
253
|
+
@property
|
254
|
+
def get_causal_mask_value(self) -> float:
|
255
|
+
return self.block_config(0).attn_config.causal_mask_value
|
@@ -160,7 +160,7 @@ def scaled_dot_product_attention_transposed(
|
|
160
160
|
Args:
|
161
161
|
query: Query tensor, with shape [B, T, N, H].
|
162
162
|
key: Key tensor, with shape [B, T, KV_LEN, H].
|
163
|
-
value: Value tensor, with shape [B, T,
|
163
|
+
value: Value tensor, with shape [B, T, H, KV_LEN].
|
164
164
|
head_size (int): head dimension.
|
165
165
|
mask (torch.Tensor): the optional mask tensor.
|
166
166
|
scale (float): the optional scale factor.
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from ai_edge_torch import odml_torch
|
17
|
+
from ai_edge_torch.generative.layers import scaled_dot_product_attention
|
18
|
+
import torch
|
19
|
+
|
20
|
+
from absl.testing import absltest as googletest
|
21
|
+
|
22
|
+
|
23
|
+
class ScaledDotProductAttentionTest(googletest.TestCase):
|
24
|
+
|
25
|
+
def test_scaled_dot_product_attention(self):
|
26
|
+
query = torch.randn(1, 16, 16, 128, dtype=torch.float32)
|
27
|
+
key = torch.randn(1, 16, 16, 128, dtype=torch.float32)
|
28
|
+
value = torch.randn(1, 16, 16, 128, dtype=torch.float32)
|
29
|
+
mask = torch.ones((1, 1, 1, 16), dtype=torch.float32)
|
30
|
+
output = scaled_dot_product_attention.scaled_dot_product_attention(
|
31
|
+
query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0
|
32
|
+
)
|
33
|
+
self.assertEqual(output.shape, (1, 16, 16, 128))
|
34
|
+
|
35
|
+
def test_scaled_dot_product_attention_transposed(self):
|
36
|
+
query = torch.randn(1, 16, 16, 128, dtype=torch.float32)
|
37
|
+
key = torch.randn(1, 16, 16, 128, dtype=torch.float32)
|
38
|
+
value = torch.randn(1, 16, 128, 16, dtype=torch.float32)
|
39
|
+
mask = torch.ones((1, 1, 1, 16), dtype=torch.float32)
|
40
|
+
output = (
|
41
|
+
scaled_dot_product_attention.scaled_dot_product_attention_transposed(
|
42
|
+
query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0
|
43
|
+
)
|
44
|
+
)
|
45
|
+
self.assertEqual(output.shape, (1, 16, 16, 128))
|
46
|
+
|
47
|
+
def test_scaled_dot_product_attention_with_hlfb(self):
|
48
|
+
query = torch.randn(1, 16, 16, 128, dtype=torch.float32)
|
49
|
+
key = torch.randn(1, 16, 16, 128, dtype=torch.float32)
|
50
|
+
value = torch.randn(1, 16, 16, 128, dtype=torch.float32)
|
51
|
+
mask = torch.ones((1, 1, 1, 16), dtype=torch.float32)
|
52
|
+
output = (
|
53
|
+
scaled_dot_product_attention.scaled_dot_product_attention_with_hlfb(
|
54
|
+
query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0
|
55
|
+
)
|
56
|
+
)
|
57
|
+
self.assertEqual(output.shape, (1, 16, 16, 128))
|
58
|
+
|
59
|
+
def model_to_mlir(model, args):
|
60
|
+
ep = torch.export.export(model, args)
|
61
|
+
mlir = odml_torch.export.exported_program_to_mlir(ep)
|
62
|
+
return mlir.get_text()
|
63
|
+
|
64
|
+
class SDPAModule(torch.nn.Module):
|
65
|
+
|
66
|
+
def __init__(self):
|
67
|
+
super().__init__()
|
68
|
+
|
69
|
+
def forward(self, query, key, value, mask):
|
70
|
+
return (
|
71
|
+
scaled_dot_product_attention.scaled_dot_product_attention_with_hlfb(
|
72
|
+
query,
|
73
|
+
key,
|
74
|
+
value,
|
75
|
+
head_size=128,
|
76
|
+
mask=mask,
|
77
|
+
scale=1.0,
|
78
|
+
softcap=10.0,
|
79
|
+
)
|
80
|
+
)
|
81
|
+
|
82
|
+
ir_text = model_to_mlir(SDPAModule().eval(), (query, key, value, mask))
|
83
|
+
self.assertEqual(ir_text.count("stablehlo.custom_call @mark_tensor"), 5)
|
84
|
+
|
85
|
+
|
86
|
+
if __name__ == "__main__":
|
87
|
+
googletest.main()
|
@@ -95,6 +95,18 @@ def define_conversion_flags(model_name: str):
|
|
95
95
|
return flags
|
96
96
|
|
97
97
|
|
98
|
+
def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
|
99
|
+
if isinstance(mask_len, list):
|
100
|
+
return [
|
101
|
+
_build_mask(i, kv_cache_max_len, causal_mask_value) for i in mask_len
|
102
|
+
]
|
103
|
+
|
104
|
+
mask = torch.full(
|
105
|
+
(mask_len, kv_cache_max_len), causal_mask_value, dtype=torch.float32
|
106
|
+
)
|
107
|
+
return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
|
108
|
+
|
109
|
+
|
98
110
|
def convert_to_tflite(
|
99
111
|
pytorch_model: torch.nn.Module,
|
100
112
|
output_path: str,
|
@@ -229,14 +241,15 @@ def _export_helper(
|
|
229
241
|
torch.arange(0, seq_len + pixel_seq_len, dtype=torch.int)
|
230
242
|
)
|
231
243
|
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
244
|
+
prefill_masks = None
|
245
|
+
if flags.FLAGS.mask_as_input:
|
246
|
+
prefill_masks = [
|
247
|
+
_build_mask(
|
248
|
+
flags.FLAGS.prefill_seq_lens,
|
249
|
+
flags.FLAGS.kv_cache_max_len,
|
250
|
+
config.get_causal_mask_value(),
|
251
|
+
)
|
252
|
+
]
|
240
253
|
|
241
254
|
if prefill_masks:
|
242
255
|
assert len(prefill_masks) == len(prefill_seq_lens)
|
@@ -299,8 +312,17 @@ def _export_helper(
|
|
299
312
|
'input_pos': decode_input_pos,
|
300
313
|
'kv_cache': decode_kv,
|
301
314
|
}
|
302
|
-
if
|
303
|
-
|
315
|
+
if flags.FLAGS.mask_as_input:
|
316
|
+
# Note that the decode mask is not a correct causal mask, but it is okay
|
317
|
+
# for the conversion purpose because only the shape matters in conversion.
|
318
|
+
# A correct causal mask of decode for a given token position of decode, it
|
319
|
+
# should be built like:
|
320
|
+
#
|
321
|
+
# torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
|
322
|
+
#
|
323
|
+
sample_kwargs['mask'] = _build_mask(
|
324
|
+
1, flags.FLAGS.kv_cache_max_len, config.get_causal_mask_value()
|
325
|
+
)
|
304
326
|
if lora is not None:
|
305
327
|
sample_kwargs['lora'] = lora
|
306
328
|
|
@@ -33,6 +33,8 @@ class ExportConfig:
|
|
33
33
|
# When False, only decode signatures will produce output.
|
34
34
|
output_logits_on_prefill: bool = False
|
35
35
|
# Attention masks given as inputs to the model.
|
36
|
+
# Note that `prefill_mask`, `decode_mask`, and `kvcache_cls` are deprecated
|
37
|
+
# and will be removed in a future version.
|
36
38
|
prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
|
37
39
|
decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
|
38
40
|
# The KV Cache layout for K and V buffers in attention.
|
@@ -43,33 +45,10 @@ class ExportConfig:
|
|
43
45
|
decode_batch_size: int = 1
|
44
46
|
|
45
47
|
|
46
|
-
def _build_mask(mask_len, kv_cache_max_len) -> torch.Tensor:
|
47
|
-
if isinstance(mask_len, list):
|
48
|
-
return [_build_mask(i, kv_cache_max_len) for i in mask_len]
|
49
|
-
|
50
|
-
mask = torch.full(
|
51
|
-
(mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
|
52
|
-
)
|
53
|
-
return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
|
54
|
-
|
55
|
-
|
56
48
|
def get_from_flags() -> ExportConfig:
|
57
49
|
"""Builds an export config according to the commandline flags."""
|
58
50
|
export_config = ExportConfig()
|
59
51
|
|
60
|
-
if flags.FLAGS.mask_as_input:
|
61
|
-
export_config.prefill_mask = _build_mask(
|
62
|
-
flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
|
63
|
-
)
|
64
|
-
# Note that the decode mask is not a correct causal mask, but it is okay
|
65
|
-
# for the conversion purpose because only the shape matters in conversion.
|
66
|
-
# A correct causal mask of decode for a given token position of decode, it
|
67
|
-
# should be built like:
|
68
|
-
#
|
69
|
-
# torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
|
70
|
-
#
|
71
|
-
export_config.decode_mask = _build_mask(1, flags.FLAGS.kv_cache_max_len)
|
72
|
-
|
73
52
|
if flags.FLAGS.transpose_kv_cache:
|
74
53
|
export_config.kvcache_layout = kv_utils.KV_LAYOUT_TRANSPOSED
|
75
54
|
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.5.0.
|
3
|
+
Version: 0.5.0.dev20250503
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=rjhPV_Qh8FDlHQTy8wAJvuXSNGcntZerhf-8FTEjuWI,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -66,13 +66,13 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEa
|
|
66
66
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
|
67
67
|
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
68
68
|
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=MjkQDVynaw9C5z9ODzKfb85xW5JfxHUWBJ_Aco05FHo,1760
|
69
|
-
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
|
69
|
+
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=fzLpuJO5JseQLA38Li-i9Xdnh9I4zdBWQEOeNbUEfjI,15737
|
70
70
|
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
|
71
71
|
ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
|
72
72
|
ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
|
73
73
|
ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=KnE9ME3mrpQkAxFlBOJLsqcQkjsdDL1ClNhJahX5K5I,8960
|
74
74
|
ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
75
|
-
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=
|
75
|
+
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=9r8LXyaoBXYIIhhe1WQgEIjaxALQPE1dO2N6qopyWCk,1753
|
76
76
|
ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
|
77
77
|
ai_edge_torch/generative/examples/hammer/verify.py,sha256=MkzAGkbPy4LKRhyCDm1cw-9jUt4VUxLPdwK_25fCGSE,2705
|
78
78
|
ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -88,8 +88,8 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
|
|
88
88
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
|
89
89
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
90
90
|
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=7HHXkC-IIu7ieBvBI4RlXs_oITz7R8a6YVYQskAs_Uk,2023
|
91
|
-
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256
|
92
|
-
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=
|
91
|
+
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=-EYUZp55dfRY1E-N0Pr3b9i5c7Tt1XvYxvsRixguVS8,5527
|
92
|
+
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=WB8r-e_Crog1ItBq3Zse_nUG-foFyBcJsuEG26r_Ji8,6076
|
93
93
|
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
|
94
94
|
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CFIjOmrn4a4Udki7l3im0JR4zTC_NttnsIr9_qWjKTY,6110
|
95
95
|
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=zrCNz_QSQU6BbaFtx-J-MqxXWcNlsAlquaHpKodsyW4,5350
|
@@ -162,10 +162,11 @@ ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-m
|
|
162
162
|
ai_edge_torch/generative/layers/feed_forward_test.py,sha256=8ZGy79BBpsyS6yKKDEKrDt249G5Mz-8VKWW7_WHx0u4,1655
|
163
163
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
|
164
164
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
165
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
165
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=dRZUMa71ADaEllu7TfXUWTMHRCcMgvkFMYMzmeJi4G8,8576
|
166
166
|
ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
|
167
167
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
168
|
-
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
|
168
|
+
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=2_AgwENsaOgaxgiSqgoj0V0JzQ09dFtP_nBhX-lJK2g,5648
|
169
|
+
ai_edge_torch/generative/layers/scaled_dot_product_attention_test.py,sha256=c6JBMQsq9XeMmR1XvGEIidNsoh-YIvichXo2LwVHgr4,3301
|
169
170
|
ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=iw7D_46CFe9iRvU0UumbkIoqWQEhDroxm9ABcK-CLlM,3600
|
170
171
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
171
172
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
|
@@ -188,8 +189,8 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
|
|
188
189
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
189
190
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
190
191
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
191
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
192
|
-
ai_edge_torch/generative/utilities/export_config.py,sha256=
|
192
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=K1gZWPq5f3Z7f9USeJ_PphctO1dyYTNrWSJQ-cztgKA,11658
|
193
|
+
ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
|
193
194
|
ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
|
194
195
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
|
195
196
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
@@ -247,8 +248,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
247
248
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
248
249
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
249
250
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
250
|
-
ai_edge_torch_nightly-0.5.0.
|
251
|
-
ai_edge_torch_nightly-0.5.0.
|
252
|
-
ai_edge_torch_nightly-0.5.0.
|
253
|
-
ai_edge_torch_nightly-0.5.0.
|
254
|
-
ai_edge_torch_nightly-0.5.0.
|
251
|
+
ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
252
|
+
ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/METADATA,sha256=RU3caJRTJFodq-s8HxE5j7uo74dScWYcYMMAtqJVsD4,2051
|
253
|
+
ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
254
|
+
ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
255
|
+
ai_edge_torch_nightly-0.5.0.dev20250503.dist-info/RECORD,,
|
File without changes
|
File without changes
|