ai-edge-torch-nightly 0.2.0.dev20240626__py3-none-any.whl → 0.2.0.dev20240701__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +23 -5
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/gemma/gemma.py +1 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/phi2/phi2.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -2
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +12 -12
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/test_models/toy_model.py +1 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +2 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +2 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -1
- ai_edge_torch/generative/quantize/example.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +8 -8
- ai_edge_torch/generative/test/test_quantize.py +2 -2
- ai_edge_torch/hlfb/mark_pattern/pattern.py +22 -9
- {ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/RECORD +26 -24
- {ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py CHANGED

ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py CHANGED

@@ -17,10 +17,22 @@ import functools
 
 import torch
 
-from ai_edge_torch.convert.fx_passes import
-from ai_edge_torch.convert.fx_passes import
+from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
+from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
 from ai_edge_torch.hlfb import mark_pattern
 
+# For torch nightly released after mid June 2024,
+# torch.nn.functional.interpolate no longer gets exported into a decomposed graph
+# but a single aten op torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
+# This behavior would break our pattern-matching based composite builder.
+# It requires the pattern and model graph to get decomposed first for backward compatibility.
+_INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions(
+    [
+        torch.ops.aten.upsample_bilinear2d.vec,
+        torch.ops.aten.upsample_nearest2d.vec,
+    ]
+)
+
 
 @functools.cache
 def _get_upsample_bilinear2d_pattern():
@@ -30,6 +42,7 @@ def _get_upsample_bilinear2d_pattern():
           x, scale_factor=2, mode="bilinear", align_corners=False
       ),
       export_args=(torch.rand(1, 3, 100, 100),),
+      decomp_table=_INTERPOLATE_DECOMPOSITIONS,
   )
 
   @pattern.register_attr_builder
@@ -52,6 +65,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
           x, scale_factor=2, mode="bilinear", align_corners=True
      ),
       export_args=(torch.rand(1, 3, 100, 100),),
+      decomp_table=_INTERPOLATE_DECOMPOSITIONS,
   )
 
   @pattern.register_attr_builder
@@ -72,6 +86,7 @@ def _get_interpolate_nearest2d_pattern():
       "tfl.resize_nearest_neighbor",
       lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
       export_args=(torch.rand(1, 3, 100, 100),),
+      decomp_table=_INTERPOLATE_DECOMPOSITIONS,
   )
 
   @pattern.register_attr_builder
@@ -86,7 +101,7 @@ def _get_interpolate_nearest2d_pattern():
   return pattern
 
 
-class BuildInterpolateCompositePass(FxPassBase):
+class BuildInterpolateCompositePass(ExportedProgramPassBase):
 
   def __init__(self):
     super().__init__()
@@ -96,10 +111,13 @@ class BuildInterpolateCompositePass(FxPassBase):
         _get_interpolate_nearest2d_pattern(),
     ]
 
-  def call(self,
+  def call(self, exported_program: torch.export.ExportedProgram):
+    exported_program = exported_program.run_decompositions(_INTERPOLATE_DECOMPOSITIONS)
+
+    graph_module = exported_program.graph_module
     for pattern in self._patterns:
       graph_module = mark_pattern.mark_pattern(graph_module, pattern)
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return
+    return ExportedProgramPassResult(exported_program, True)
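
A minimal sketch (not part of this release, assuming a torch nightly from after mid June 2024) of the export behavior the pass works around: exporting interpolate yields a single aten.upsample_nearest2d.vec node, and run_decompositions with the table above restores the decomposed graph that mark_pattern expects.

    import torch

    class Upsample(torch.nn.Module):
      def forward(self, x):
        return torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")

    # On affected nightlies this graph holds one aten.upsample_nearest2d.vec node.
    ep = torch.export.export(Upsample().eval(), (torch.rand(1, 3, 100, 100),))
    decomps = torch._decomp.get_decompositions([torch.ops.aten.upsample_nearest2d.vec])
    ep = ep.run_decompositions(decomps)  # back to the decomposed form
    print(ep.graph_module.graph)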
ai_edge_torch/convert/test/test_to_channel_last_io.py ADDED

@@ -0,0 +1,96 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import unittest
+
+import torch
+
+import ai_edge_torch
+
+
+class Identity(torch.nn.Module):
+
+  def forward(self, x):
+    return x
+
+
+class TestToChannelLastIO(unittest.TestCase):
+  """Tests to_channel_last_io API and module wrapper."""
+
+  def test_no_transformations(self):
+    x = torch.rand(1, 3, 10, 10)
+    y = ai_edge_torch.to_channel_last_io(Identity())(x)
+    self.assertEqual(y.shape, (1, 3, 10, 10))
+
+  def test_args(self):
+    x = torch.rand(1, 10, 10, 3)
+    y = ai_edge_torch.to_channel_last_io(Identity(), args=[0])(x)
+    self.assertEqual(y.shape, (1, 3, 10, 10))
+
+  def test_outputs(self):
+    x = torch.rand(1, 3, 10, 10)
+    y = ai_edge_torch.to_channel_last_io(Identity(), outputs=[0])(x)
+    self.assertEqual(y.shape, (1, 10, 10, 3))
+
+  def test_args_outputs(self):
+    x = torch.rand(1, 10, 10, 3)
+    y = ai_edge_torch.to_channel_last_io(Identity(), args=[0], outputs=[0])(x)
+    self.assertEqual(y.shape, (1, 10, 10, 3))
+
+  def test_args_5d(self):
+    x = torch.rand(1, 10, 10, 10, 3)
+    y = ai_edge_torch.to_channel_last_io(Identity(), args=[0])(x)
+    self.assertEqual(y.shape, (1, 3, 10, 10, 10))
+
+  def test_outputs_5d(self):
+    x = torch.rand(1, 3, 10, 10, 10)
+    y = ai_edge_torch.to_channel_last_io(Identity(), outputs=[0])(x)
+    self.assertEqual(y.shape, (1, 10, 10, 10, 3))
+
+  def test_chained_wrappers(self):
+    x = torch.rand(1, 10, 10, 3)
+
+    m = Identity()
+    m = ai_edge_torch.to_channel_last_io(m, args=[0])
+    m = ai_edge_torch.to_channel_last_io(m, outputs=[0])
+
+    y = m(x)
+    self.assertEqual(y.shape, (1, 10, 10, 3))
+
+  def test_list_args(self):
+    class Add(torch.nn.Module):
+
+      def forward(self, x, y):
+        return x + y
+
+    x = (torch.rand(1, 10, 10, 3), torch.rand(1, 10, 10, 3))
+    y = ai_edge_torch.to_channel_last_io(Add(), args=[0, 1])(*x)
+    self.assertEqual(y.shape, (1, 3, 10, 10))
+
+  def test_list_outputs(self):
+    class TwoIdentity(torch.nn.Module):
+
+      def forward(self, x):
+        return x, x
+
+    x = torch.rand(1, 3, 10, 10)
+    y = ai_edge_torch.to_channel_last_io(TwoIdentity(), outputs=[0])(x)
+    self.assertIsInstance(y, tuple)
+    self.assertEqual(y[0].shape, (1, 10, 10, 3))
+    self.assertEqual(y[1].shape, (1, 3, 10, 10))
+
+
+if __name__ == "__main__":
+  unittest.main()
ai_edge_torch/convert/to_channel_last_io.py ADDED

@@ -0,0 +1,85 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+
+class ChannelLastIOWrapper(nn.Module):
+
+  def __init__(self, wrapped, *, args=None, outputs=None):
+    super().__init__()
+    self.wrapped = wrapped
+    self._args = args or []
+    self._outputs = outputs or []
+
+  def _to_channel_last(self, x):
+    if not torch.is_tensor(x):
+      raise ValueError("Input must be a torch tensor")
+    if x.ndim < 3:
+      raise ValueError("Input must be a tensor with rank >= 3 in layout (N, C, ...)")
+    dims = [0, *range(2, x.ndim), 1]
+    return torch.permute(x, dims)
+
+  def _to_channel_first(self, x):
+    if not torch.is_tensor(x):
+      raise ValueError("Input must be a torch tensor.")
+    if x.ndim < 3:
+      raise ValueError("Input must be a tensor with rank >= 3 in layout (N, ..., C)")
+    dims = [0, x.ndim - 1, *range(1, x.ndim - 1)]
+    return torch.permute(x, dims)
+
+  def forward(self, *args, **kwargs):
+    args = list(args)
+    for i in self._args:
+      args[i] = self._to_channel_first(args[i])
+
+    outputs = self.wrapped(*args, **kwargs)
+
+    if not isinstance(outputs, (list, tuple)):
+      outputs_is_list = False
+      output_list = [outputs]
+    else:
+      outputs_is_list = True
+      output_list = list(outputs)
+
+    for i in self._outputs:
+      output_list[i] = self._to_channel_last(output_list[i])
+
+    if not outputs_is_list:
+      return output_list[0]
+    else:
+      return type(outputs)(output_list)
+
+
+def to_channel_last_io(
+    module: nn.Module,
+    args: Optional[list[int]] = None,
+    outputs: Optional[list[int]] = None,
+):
+  """Wraps the module with channel first to channel last layout transformations.
+
+  Args:
+    args (list[int]): Transform args with indices in the list from channel first
+      (N, C, ...) to channel last (N, ..., C).
+    outputs (list[int]): Transform outputs with indices in the list from channel
+      first (N, C, ...) to channel last (N, ..., C).
+  Returns:
+    The wrapped nn.Module with additional layout transposes after inputs and/or before
+    outputs.
+  """
+  return ChannelLastIOWrapper(module, args=args, outputs=outputs)
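
A usage sketch distilled from the tests above (the Conv2d model is illustrative, not part of the release): the wrapper lets a channel-first (NCHW) model accept and return channel-last (NHWC) tensors, the layout TFLite image I/O commonly uses.

    import torch
    import ai_edge_torch

    # Hypothetical NCHW model; any nn.Module with (N, C, ...) inputs works.
    model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
    nhwc_model = ai_edge_torch.to_channel_last_io(model, args=[0], outputs=[0])

    y = nhwc_model(torch.rand(1, 224, 224, 3))  # NHWC in
    print(y.shape)  # torch.Size([1, 224, 224, 8]) -- NHWC out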
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py CHANGED

@@ -45,9 +45,9 @@ def convert_gemma_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
   prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.
+  decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
 
   quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
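
The recurring edit in this and the example files below is the dtype of the trace-time token tensors. For reference (not from the release), torch.int is an alias for 32-bit torch.int32, while bare integer literals default to torch.int64:

    import torch

    prefill_tokens = torch.full((1, 8), 0, dtype=torch.int)
    decode_token = torch.tensor([[0]], dtype=torch.int)
    print(prefill_tokens.dtype, decode_token.dtype)  # torch.int32 torch.int32
    print(torch.tensor([[0]]).dtype)                 # torch.int64 (default)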
ai_edge_torch/generative/examples/gemma/gemma.py CHANGED

@@ -163,7 +163,7 @@ def define_and_run_2b() -> None:
   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
   print("running an inference")
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py CHANGED

@@ -43,9 +43,9 @@ def convert_phi2_to_tflite(
   """
   pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
   prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.
+  decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
 
   quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None

ai_edge_torch/generative/examples/phi2/phi2.py CHANGED

@@ -153,7 +153,7 @@ def define_and_run() -> None:
   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
   print("running an inference")
ai_edge_torch/generative/examples/stable_diffusion/clip.py CHANGED

@@ -60,8 +60,8 @@ class CLIP(nn.Module):
     )
 
   @torch.inference_mode
-  def forward(self, tokens: torch.
-    tokens = tokens.type(torch.
+  def forward(self, tokens: torch.IntTensor) -> torch.FloatTensor:
+    tokens = tokens.type(torch.int)
 
     state = self.tok_embedding(tokens) + self.tok_embedding_position
     for layer in self.transformer_blocks:

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py CHANGED

@@ -61,7 +61,7 @@ def convert_stable_diffusion_to_tflite(
   n_tokens = 77
   timestamp = 0
   len_prompt = 1
-  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.
+  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
   input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
   noise = torch.full(
       (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
ai_edge_torch/generative/examples/t5/convert_to_tflite.py CHANGED

@@ -30,23 +30,23 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):
 
   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.
+      prompt_e_token, dtype=torch.int
   )
   prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.
+      prompt_d_token, dtype=torch.int
   )
   prefill_d_input_pos = torch.arange(0, seq_len)
 
   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.
+  decode_token = torch.tensor([[1]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  decode_d_token = torch.tensor([[1]], dtype=torch.
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
   decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
 
   # Pad mask for self attention only on "real" tokens.
@@ -78,23 +78,23 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
 
   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.
+      prompt_e_token, dtype=torch.int
   )
   prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.
+      prompt_d_token, dtype=torch.int
   )
   prefill_d_input_pos = torch.arange(0, seq_len)
 
   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.
+  decode_token = torch.tensor([[1]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  decode_d_token = torch.tensor([[1]], dtype=torch.
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
   decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
 
   # Pad mask for self attention only on "real" tokens.

ai_edge_torch/generative/examples/t5/t5.py CHANGED

@@ -562,7 +562,7 @@ def define_and_run_t5(checkpoint_path: str) -> None:
   model = build_t5_model(checkpoint_path)
 
   idx = get_sample_encoder_input_ids()
-  tokens = torch.full((1, 512), 0, dtype=torch.
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
   input_pos = torch.arange(0, 512)
 
@@ -586,7 +586,7 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
   t5_decoder_model = build_t5_decoder_model(config, embedding_layer, checkpoint_path)
   idx = get_sample_encoder_input_ids()
 
-  tokens = torch.full((1, 512), 0, dtype=torch.
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
   input_pos = torch.arange(0, 512)
 
ai_edge_torch/generative/examples/test_models/toy_model.py CHANGED

@@ -93,7 +93,7 @@ def define_and_run() -> None:
   )
 
   model = ToySingleLayerModel(config)
-  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN, dtype=torch.int), 0)
   input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
   print('running an inference')
   print(
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py CHANGED

@@ -115,13 +115,13 @@ def get_model_config() -> cfg.ModelConfig:
 
 
 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.unsqueeze(torch.arange(0, 100), 0)
+  idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
   input_pos = torch.arange(0, 100)
   return idx, input_pos
 
 
 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.tensor([[1]], dtype=torch.
+  idx = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10])
   return idx, input_pos
 
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED

@@ -103,13 +103,13 @@ def get_model_config() -> cfg.ModelConfig:
 
 
 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.unsqueeze(torch.arange(0, 100), 0)
+  idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
   input_pos = torch.arange(0, 100)
   return idx, input_pos
 
 
 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.tensor([[1]], dtype=torch.
+  idx = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10], dtype=torch.int64)
   return idx, input_pos
 
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py CHANGED

@@ -45,9 +45,9 @@ def convert_tiny_llama_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
   prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.
+  decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
 
   quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED

@@ -153,7 +153,7 @@ def define_and_run() -> None:
   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
   print("running an inference")
ai_edge_torch/generative/quantize/example.py CHANGED

@@ -26,7 +26,7 @@ def main():
   config = gemma.get_fake_model_config_2b_for_test()
   model = gemma.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, 10), 0, dtype=torch.
+  tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, 10)
 
ai_edge_torch/generative/test/test_model_conversion.py CHANGED

@@ -35,7 +35,7 @@ class TestModelConversion(unittest.TestCase):
   def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )
 
@@ -59,7 +59,7 @@ class TestModelConversion(unittest.TestCase):
     config = toy_model_with_kv_cache.get_model_config()
     config.batch_size = 2
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1], [2]], dtype=torch.
+    idx, input_pos = torch.tensor([[1], [2]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )
 
@@ -83,7 +83,7 @@ class TestModelConversion(unittest.TestCase):
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )
 
@@ -109,7 +109,7 @@ class TestModelConversion(unittest.TestCase):
     pytorch_model = tiny_llama.TinyLLamma(config)
 
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10)
 
@@ -135,13 +135,13 @@ class TestModelConversion(unittest.TestCase):
 
     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.
+    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
     prefill_input_pos = torch.arange(0, seq_len)
 
     # decode
-    decode_token = torch.tensor([[1]], dtype=torch.
+    decode_token = torch.tensor([[1]], dtype=torch.int)
     decode_input_pos = torch.tensor([5], dtype=torch.int64)
 
     edge_model = (
@@ -183,7 +183,7 @@ class TestModelConversion(unittest.TestCase):
    model = gemma.Gemma(config)
 
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10)
 
@@ -210,7 +210,7 @@ class TestModelConversion(unittest.TestCase):
     pytorch_model = phi2.Phi2(config)
 
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10)
 
ai_edge_torch/generative/test/test_quantize.py CHANGED

@@ -119,7 +119,7 @@ class TestQuantizeConvert(unittest.TestCase):
     self.skipTest("b/346896669")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )
 
@@ -137,7 +137,7 @@ class TestQuantizeConvert(unittest.TestCase):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )
 
ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED

@@ -100,13 +100,18 @@ class ScalarAttrLocation:
 
 
 def _find_scalar_attr(
-    pattern_module: torch.nn.Module,
+    pattern_module: torch.nn.Module,
+    export_args: tuple[Any],
+    tracker: ScalarAttrTracker,
+    decomp_table=None,
 ) -> ScalarAttrLocation:
   scalar_loc_intersections = None
   for source, target in tracker._source_targets:
     track_args = list(export_args)
     track_args[tracker.pattern_arg_pos] = source
     ep = torch.export.export(pattern_module, tuple(track_args))
+    if decomp_table is not None:
+      ep = ep.run_decompositions(decomp_table)
 
     scalar_locs = set()
     nodes = ep.graph_module.graph.nodes
@@ -145,6 +150,7 @@ class Pattern:
         ["Pattern", GraphModule, InternalMatch], Optional[dict[str, Any]]
     ] = None,
     scalar_attr_trackers: list[ScalarAttrTracker] = None,
+    decomp_table: Optional[dict[torch._ops.OperatorBase, Callable]] = None,
   ):
   """The PyTorch computation pattern to match against a model.
 
@@ -165,6 +171,8 @@ class Pattern:
           for scalar args in `export_args`, which are used to track
           the attr occurrence(s) and retrieve their values from the
           matched subgraph.
+        decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]):
+          The decomposition table to be run on the pattern's exported program.
     """
     if not isinstance(module, torch.nn.Module):
 
@@ -180,23 +188,28 @@ class Pattern:
     module = PatternModule(module).eval()
 
     self.name = name
-    self.exported_program = torch.export.export(module, export_args)
-    self.graph_module = self.exported_program.graph_module
     self.attr_builder = attr_builder
     self._scalar_attr_trackers = scalar_attr_trackers if scalar_attr_trackers else []
 
-
-
-
-
-    self.
+    exported_program = torch.export.export(module, export_args)
+    if decomp_table is not None:
+      exported_program = exported_program.run_decompositions(decomp_table)
+
+    self.exported_program = exported_program
+    self.graph_module = self.exported_program.graph_module
 
     self._scalar_attr_locations = []
     for tracker in self._scalar_attr_trackers:
       self._scalar_attr_locations.append(
-          _find_scalar_attr(module, export_args, tracker)
+          _find_scalar_attr(module, export_args, tracker, decomp_table=decomp_table)
      )
 
+    # Sanitize graph_module for more precise pattern matching.
+    # The graph_module to match against this pattern should apply equivalent
+    # sanitization.
+    self.graph_module = passes.remove_clone_ops(self.graph_module)
+    self.graph_module = passes.remove_dangling_args(self.graph_module)
+
     # Builds list of ordered input and output nodes.
     self.graph_nodes_map = {}
     for node in self.graph_module.graph.nodes:
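
A hedged sketch of the new decomp_table parameter, mirroring how build_interpolate_composite_pass.py above constructs its patterns (the import path for the pattern module is assumed):

    import torch
    from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module  # assumed path

    decomps = torch._decomp.get_decompositions([torch.ops.aten.upsample_nearest2d.vec])
    p = pattern_module.Pattern(
        "tfl.resize_nearest_neighbor",
        lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
        export_args=(torch.rand(1, 3, 100, 100),),
        decomp_table=decomps,  # pattern graph is decomposed before matching
    )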
{ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.2.0.dev20240626
+Version: 0.2.0.dev20240701
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/RECORD CHANGED

@@ -1,13 +1,14 @@
-ai_edge_torch/__init__.py,sha256=
+ai_edge_torch/__init__.py,sha256=CNDboRP4zQBpz2hznNCQWcQCARvNXUm3DMa1Dw_XXFg,1067
 ai_edge_torch/model.py,sha256=kmcgELjsYl8YzF8nUF6P7q4i8MWS-pLGpfsy-yTUXmE,4243
 ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/convert/conversion.py,sha256=8K8jQuaCjlUWoj7jiimxp_zpN6mYThLOcQ858UDcYnE,4159
 ai_edge_torch/convert/conversion_utils.py,sha256=9BqCL38DErv1vEVGtT3BIJVhdwZjw2EQ-_m5UpvVVYE,11294
 ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
+ai_edge_torch/convert/to_channel_last_io.py,sha256=zo5tY3yDhY_EPCkrL1XSXs2uRFS8B4_qu08dSjNsUGk,2778
 ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
 ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
 ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=2yqUwJJ2R233_X9FNMOP9oYRTTzH34TR_BIUj-wfnKw,7080
-ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=
+ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=6m_vcycd9f3OQgQLx2hhQjsKfOqdxE5EkjzqrxqyAQM,4168
 ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
 ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=VA9bekxPVhLk4MYlIRXnOzrSnbCtUmGj7OQ_fJcKQtc,795
@@ -24,6 +25,7 @@ ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
 ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
 ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
 ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
+ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=I8c4ZG3v1vo0yxQYzLK_BTId4AOL9vadHGDtfCUZ4UI,2930
 ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
 ai_edge_torch/debug/culprit.py,sha256=urtCKPXORPvn6oyDxDSCSjgvngUnjjcsUMwAOeIl15E,14236
 ai_edge_torch/debug/utils.py,sha256=hjVmQVVl1dKxEF0D6KB4a3ouQ3wBkTsebOX2YsUObZM,1430
@@ -34,15 +36,15 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=
+ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=yl36sMjVqDlf9I41DF9C5wx6ztMxYB5xukD1NltUS04,2536
+ai_edge_torch/generative/examples/gemma/gemma.py,sha256=caBQrJTK4tXFeGM-i2cNXd2Tb8GXi463MwFHc7N65WE,5934
 ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/phi2/phi2.py,sha256=
+ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=jIFnJY9BtAUJtVkFnms7byZp-jshhQIx59DmK0OjJ8M,2510
+ai_edge_torch/generative/examples/phi2/phi2.py,sha256=tYo-WWOyw6LBF5wBVL_CsuTxKhur0SroDe2sYkTPdvI,5561
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
-ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
-ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=cIOqRZ76Pb8ywuCa3LUQnKnBVmvcaAPqvA5bdHfgaWw,3718
+ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=AwGVIY_tNlea1c4Rz3lEAeAqBvb-GxGzleX2dle98DE,4181
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=RgxedILk7iNMb0mhE4VkCs6d7BnFzYhR3vspUkC0-1o,11425
 ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=sRevfsmCun7zbceJbOstLKNUsLwzQDsGm7Mi2JmlREg,26021
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
@@ -55,16 +57,16 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py
 ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=iPYX9ZSaxwSak2KI44j6TEr_g4pdxS3xpka4u0trjbo,2788
 ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5iRfU5MO6GR6K3WrdddIU_9U7ZZGEEb7zGKVY1WFl-8,1340
 ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/t5/t5.py,sha256=
+ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=vqYip3JnjFMN5s0VnIzwhlBJl3up75WMb_VtRTfMOK0,4524
+ai_edge_torch/generative/examples/t5/t5.py,sha256=TyBDb50NbtKaHyGLGOMJ8dR2_GPF43oe8WhRfb1SMZ4,21085
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rkMwi-NJGBXHm5S57Rsj1LbcoVdyRkS7GmIBuU6F_2E,8274
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
-ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=UNj7IgyCLOlaOu2xJKHciHqcd-_NXLNwMphCWwuaCes,3830
+ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=u1M38wuev-YG-exMG0_HzAw3yaikPRhEZCqGaHO7Cw0,5559
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=xy7kT1f8V0jqAfglggoeKVqR5KviKdfQun0WBYJ0jS8,4869
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=rAsrP0CnhcYYMCcEJAGoFMRMur0LPoAcYpjFlBiZd2s,2564
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=__9QjZ-7Xi8MX5M7iRvw_jpic0-Fau2IO0QatleUzL0,5650
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=aXvYiaHDvETIrh0Q9DDZA_ZBiazGk80DT6nt7lLtC1o,1172
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=IehLwFNwa0C9fnk1pmNmyfuAwwWbuwdyKy46BSqNVdI,1948
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -82,7 +84,7 @@ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=H45wsXA6iJi_Mjd66NiQrh7
 ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=FrIO-CR8aRIV2i8aFqom_4S7WCEDLMyYwo6U0oFyn7A,9097
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/example.py,sha256=
+ai_edge_torch/generative/quantize/example.py,sha256=zgBgMyZ8RlSIjRhbaeodLyt6sj_dYuM7oI6Zyx7xQIg,1542
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=Y8zahKw7b_h7ajPaJZVef4jG-MoqImRCpVSbFtV_i24,5139
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=-vd6Qp0BdXJVKg4f0_hhwbKOi3QPIAPVqyXnJ-ZnISQ,1915
@@ -92,8 +94,8 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DE
 ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=qUB4f2DoB14dLkNPWf6TZodpT81mfAJeWM-lCAmkuHY,5735
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
-ai_edge_torch/generative/test/test_quantize.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=tXES8gePl8BptETyUgpIznbIK1SvGEd8mKz1bT_a8Mw,7581
+ai_edge_torch/generative/test/test_quantize.py,sha256=Emp-8oLHyGddELSCkncuSQt8ZJIhZ2-y0-ghR92s10g,5386
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=7ChqrnthD7I-Be6vkRvYTRhbGQ3tqMbikLpjY5HpSzE,30890
@@ -101,7 +103,7 @@ ai_edge_torch/generative/utilities/t5_loader.py,sha256=h1FQzt4x8wiQMX4NzYNVIaJGL
 ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=YV2YKBkh7y7j7sd7EA81vf_1hUKUvTRiy1pfqZustXc,1539
-ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=
+ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=BEMKgkX8IrsX70h2CxwA_tpsBm6qWWe2bOeOufMYNkA,9722
 ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=qYR3PRGS9W3OG-qX7UFqL69VxXuUSfyDBUJtCXtXcOE,4262
 ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=aUAPKnH4_Jxpp3mLlD5rzdT1g_VBm7OrwwLJ9DeJlzQ,8190
@@ -112,8 +114,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
+ai_edge_torch_nightly-0.2.0.dev20240701.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240701.dist-info/METADATA,sha256=QItv3j92_LbO-UPcMqZHRhwTeSTnVQiy-cam2ciLaGM,1748
+ai_edge_torch_nightly-0.2.0.dev20240701.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240701.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240701.dist-info/RECORD,,

File without changes

File without changes
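
For reference (not from the release), each RECORD entry above has the form "path,sha256=<hash>,<size>", where the hash is the urlsafe, unpadded base64 encoding of the file's SHA-256 digest. A sketch of how such an entry is derived:

    import base64
    import hashlib

    def record_entry(path: str) -> str:
      data = open(path, "rb").read()
      digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
      return f"{path},sha256={digest.decode()},{len(data)}"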