ai-edge-torch-nightly 0.3.0.dev20241129__py3-none-any.whl → 0.3.0.dev20241204__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/test/test_convert.py +48 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
- ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → moonshine/convert_moonshine_to_tflite.py} +11 -29
- ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +9 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +8 -6
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +8 -6
- ai_edge_torch/generative/test/test_quantize.py +5 -0
- ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
- ai_edge_torch/odml_torch/export.py +45 -7
- ai_edge_torch/odml_torch/export_utils.py +2 -13
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +1 -3
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +1 -3
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +174 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +16 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/RECORD +28 -24
- {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/test/test_convert.py
CHANGED

@@ -21,10 +21,12 @@ from typing import Tuple
 import ai_edge_torch
 from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
+from ai_edge_torch.quantize import pt2e_quantizer
 from ai_edge_torch.testing import model_coverage
 import numpy as np
 import torch
 from torch import nn
+from torch.ao.quantization import quantize_pt2e
 import torchvision

 from absl.testing import absltest as googletest
@@ -506,6 +508,52 @@ class TestConvert(googletest.TestCase):
         model_coverage.compare_tflite_torch(edge_model, torch_module, args)
     )

+  def test_convert_resnet18_pt2e_per_layer(self):
+    # Step 1: export resnet18
+    args = (torch.randn(1, 3, 224, 224),)
+    m = torchvision.models.resnet18().eval()
+    m = torch._export.capture_pre_autograd_graph(m, args)
+
+    # Step 2: Insert observers or fake quantize modules
+    quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
+        pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=False)
+    )
+    m = quantize_pt2e.prepare_pt2e(m, quantizer)
+
+    # Step 3: Quantize the model
+    m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)
+
+    # pylint: disable=broad-except
+    try:
+      ai_edge_torch.convert(m, args)
+    except Exception as err:
+      self.fail(f"PT2E conversion failed: {err}")
+    # pylint: enable=broad-except
+
+  def test_convert_resnet18_pt2e_per_channel(self):
+    # Step 1: export resnet18
+    args = (torch.randn(1, 3, 224, 224),)
+    m = torchvision.models.resnet18().eval()
+    m = torch._export.capture_pre_autograd_graph(m, args)
+
+    # Step 2: Insert observers or fake quantize modules
+    quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
+        pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=True)
+    )
+    m = quantize_pt2e.prepare_pt2e(m, quantizer)
+    # Step 3: Run through example inputs, otherwise per-channel
+    # quant may have scalar scale/zero_point
+    m(*args)
+    # Step 4: Quantize the model
+    m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)
+
+    # pylint: disable=broad-except
+    try:
+      ai_edge_torch.convert(m, args)
+    except Exception as err:
+      self.fail(f"PT2E conversion failed: {err}")
+    # pylint: enable=broad-except
+

 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
CHANGED

@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-
-    '
-    1024,
-    '
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}
+  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
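All the converter scripts in this release swap the single `prefill_seq_len` integer flag for a `prefill_seq_lens` multi-integer flag, as in the hunk above. For readers unfamiliar with absl's multi-value flags, here is a minimal standalone sketch (illustrative only, not part of the package) of how such a flag behaves; the flag name and defaults mirror the diff:

```python
# Each occurrence of --prefill_seq_lens on the command line appends one value;
# .value is always a list (the default tuple is returned as a list if the flag
# is never passed). This is why convert_to_tflite now receives a list of
# prefill lengths rather than a single int.
from absl import app, flags

_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
    'prefill_seq_lens',
    (8, 64, 128, 256, 512, 1024),
    'List of the maximum sizes of prefill input tensors.',
)


def main(_):
  # e.g. `python demo.py --prefill_seq_lens 32 --prefill_seq_lens 256`
  # prints [32, 256]; with no flags it prints all six defaults.
  print(_PREFILL_SEQ_LENS.value)


if __name__ == '__main__':
  app.run(main)
```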
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
CHANGED

@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-
-    '
-    1024,
-    '
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}
+  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
CHANGED

@@ -39,10 +39,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-
-    '
-    1024,
-    '
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -66,11 +66,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}
+  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/moonshine/__init__.py
ADDED

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → moonshine/convert_moonshine_to_tflite.py}
RENAMED

@@ -13,19 +13,21 @@
 # limitations under the License.
 # ==============================================================================

-"""Example
+"""Example of converting a Moonshine model to multi-signature tflite model."""

 import os
 import pathlib

 from absl import app
 from absl import flags
-
+import ai_edge_torch
+from ai_edge_torch.generative.examples.moonshine import moonshine
 from ai_edge_torch.generative.utilities import converter
+import torch

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/moonshine'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -33,35 +35,15 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)


 def main(_):
-
-
-  )
-
-
-
-      pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-  )
+  p_model = moonshine.build_preprocessor(_CHECKPOINT_PATH.value)
+  output_filename = f'moonshine_preprocessor.tflite'
+  _input = torch.randn((1, 1, 159414), dtype=torch.float)
+  edge_model = ai_edge_torch.convert(p_model, (_input,), quant_config=None)
+  tflite_path = os.path.join(_TFLITE_PATH.value, output_filename)
+  edge_model.export(tflite_path)


 if __name__ == '__main__':
ai_edge_torch/generative/examples/moonshine/moonshine.py
ADDED

@@ -0,0 +1,103 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building the Moonshine model."""
+
+import os
+import pathlib
+from typing import Optional, Tuple
+from absl import app
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.normalization as normalization
+import ai_edge_torch.generative.utilities.moonshine_loader as loading_utils
+import h5py
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    conv1D_0="layers/sequential/layers/conv1d/vars",
+    conv1D_1="layers/sequential/layers/conv1d_1/vars",
+    conv1D_2="layers/sequential/layers/conv1d_2/vars",
+    group_norm="layers/sequential/layers/group_normalization/vars",
+)
+
+
+class AudioPreprocessor(nn.Module):
+
+  def __init__(self, dim):
+    super(AudioPreprocessor, self).__init__()
+    self.conv1 = nn.Conv1d(
+        in_channels=1, out_channels=dim, kernel_size=127, stride=64, bias=False
+    )
+    self.tanh = nn.Tanh()
+    self.group_norm = normalization.GroupNorm(group_num=1, dim=dim, eps=1e-5)
+    self.conv2 = nn.Conv1d(
+        in_channels=dim,
+        out_channels=2 * dim,
+        kernel_size=7,
+        stride=3,
+        padding=0,  # Equivalent to padding="valid"
+    )
+    self.gelu1 = nn.GELU()
+    self.conv3 = nn.Conv1d(
+        in_channels=2 * dim,
+        out_channels=dim,
+        kernel_size=3,
+        stride=2,
+        padding=0,  # Equivalent to padding="valid"
+    )
+    self.gelu2 = nn.GELU()
+
+  def forward(self, inputs):
+    x = self.conv1(inputs)
+    x = self.tanh(x)
+    x = self.group_norm(x)
+    x = self.conv2(x)
+    x = self.gelu1(x)
+    x = self.conv3(x)
+    x = self.gelu2(x)
+    return x
+
+
+def build_preprocessor(checkpoint_path: str, **kwargs) -> nn.Module:
+  ap = AudioPreprocessor(dim=416)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(ap, strict=True)
+  return ap
+
+
+def main(_):
+  # TODO(b/375421767) Remove golden checks once full model is implemented.
+  HF_PATH = os.path.join(pathlib.Path.home(), "Downloads/llm_data/moonshine")
+
+  test_data_path = pathlib.Path(__file__).parent.resolve()
+  INPUT_PATH = test_data_path / "data" / "pp_input.pt"
+  GOLDEN_PATH = test_data_path / "data" / "pp_output.pt"
+
+  ap = build_preprocessor(HF_PATH)
+  ap.eval()
+  inputs = torch.load(INPUT_PATH).reshape((1, 1, 159414))
+  out = ap(inputs)
+  golden = torch.load(GOLDEN_PATH).transpose(1, 2)
+  assert torch.allclose(out, golden, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
CHANGED

@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-
-    '
-    1024,
-    '
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,14 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename =
+  output_filename = (
+      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  )
+
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
CHANGED

@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-
-    '
-    1024,
-    '
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}
+  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
CHANGED

@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-
-    '
-    1024,
-    '
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}
+  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
CHANGED

@@ -39,10 +39,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-
-    '
-    1024,
-    '
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -68,11 +68,13 @@ def main(_):
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
   model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename =
+  output_filename = (
+      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  )
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
CHANGED

@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-
-    '
-    1024,
-    '
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}
+  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
CHANGED

@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-
-    '
-    1024,
-    '
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,13 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename =
+  output_filename = (
+      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  )
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/test/test_quantize.py
CHANGED

@@ -91,6 +91,11 @@ class TestVerifyRecipes(parameterized.TestCase):
 class TestQuantizeConvert(parameterized.TestCase):
   """Test conversion with quantization."""

+  def setUp(self):
+    super().setUp()
+    torch.manual_seed(0)
+    torch._dynamo.reset()
+
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
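A likely motivation for the new `setUp` (the diff itself does not say): quantization tests build randomly initialized models and compare numerics, and earlier tests in the same process can leave dynamo/compile caches behind. A standalone illustration of the seeding half:

```python
# The same module construction yields identical weights only under a fixed
# seed, which is what makes size/numeric comparisons reproducible per test.
import torch
from torch import nn

torch.manual_seed(0)
a = nn.Linear(4, 4).weight.detach().clone()
torch.manual_seed(0)
b = nn.Linear(4, 4).weight.detach().clone()
print(torch.equal(a, b))  # True

torch._dynamo.reset()  # clears torch.compile/dynamo caches between tests
```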
ai_edge_torch/generative/utilities/moonshine_loader.py
ADDED

@@ -0,0 +1,154 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Common utility functions for data loading etc.
+from dataclasses import dataclass
+import glob
+import os
+from typing import Callable, Dict
+
+import h5py
+import torch
+
+
+def transpose_if_needed(t):
+  """We assume the file is from Keras, i.e. channel last format."""
+  if len(t.shape) > 2:
+    return t.permute(2, 1, 0)
+  return t
+
+
+def load_h5_statedict(full_path: str):
+  """Loads the HDF5 DataSets into a single dictionary.
+
+  Args:
+    full_path (string): the HDF5 filename or directory that contains the HDF5
+      files.
+
+  Returns:
+    A state dictionary containing loaded tensors.
+
+  Raises:
+    ValueError: If no tensors are loaded from the provided directory or file.
+  """
+  pattern = (
+      os.path.join(full_path, "*.h5") if os.path.isdir(full_path) else full_path
+  )
+  files = []
+  for file in glob.glob(pattern):
+    files.append(file)
+
+  tensors = {}
+
+  def collect_datasets(name, obj):
+    if isinstance(obj, h5py.Dataset):
+      tensors[name] = transpose_if_needed(torch.from_numpy(obj[:]))
+
+  for file in files:
+    with h5py.File(file) as f:
+      f.visititems(collect_datasets)
+
+  if not tensors:
+    raise ValueError("Failed to load HDF5 file.")
+  return tensors
+
+
+class ModelLoader:
+  """Utility class for loading and converting checkpoints to ODML transformer layer format."""
+
+  @dataclass
+  class TensorNames:
+    conv1D_0: str = None
+    conv1D_1: str = None
+    conv1D_2: str = None
+    group_norm: str = None
+
+  def __init__(self, file_name: str, names: TensorNames) -> None:
+    """ModelLoader constructor.
+
+    Can be used to load multiple models of the same type.
+
+    Args:
+      file_name (str): Path to the checkpoint. Can be a directory or an exact
+        file.
+      names (TensorNames): An instance of `TensorNames` to determine mappings.
+    """
+    self._file_name = file_name
+    self._names = names
+    self._loader = load_h5_statedict
+
+  def load(
+      self,
+      model: torch.nn.Module,
+      strict: bool = True,
+  ):
+    """Load the model from the checkpoint.
+
+    Args:
+      model (torch.nn.Module): The pytorch model that needs to be loaded.
+      strict (bool, optional): Whether the converted keys are strictly
+        matched. Defaults to True.
+
+    Raises:
+      ValueError: If conversion results in unmapped tensors and strict mode is
+        enabled.
+    """
+    state = self._loader(self._file_name)
+
+    if isinstance(self._names, ModelLoader.TensorNames):
+      converted_state = self._do_load(model, state, self._names)
+    else:
+      raise ValueError(f"Unknown type for names: {type(self._names)}")
+
+    if strict and state:
+      raise ValueError(
+          "Failed to map all tensors. Remaining tensors are:"
+          f" {list(state.keys())}"
+      )
+    model.load_state_dict(converted_state, strict=strict)
+
+  def _do_load(self, model, state, names, additional_prefix=""):
+    """Load the model from the checkpoint.
+
+    Args:
+      model (torch.nn.Module): The pytorch model that needs to be loaded.
+      state (Dict[str, torch.Tensor]): The pytorch state dictionary.
+      names (TensorNames): The TensorNames for the model we are loading.
+
+    Returns:
+      Dict[str, torch.Tensor]: Map of name to tensor for loading.
+    """
+    converted_state = dict()
+    if names.conv1D_0 is not None:
+      converted_state["conv1.weight"] = state.pop(f"{names.conv1D_0}/0")
+      if f"{names.conv1D_0}/1" in state:
+        converted_state["conv1.bias"] = state.pop(f"{names.conv1D_0}/1")
+
+    if names.conv1D_1 is not None:
+      converted_state["conv2.weight"] = state.pop(f"{names.conv1D_1}/0")
+      if f"{names.conv1D_1}/1" in state:
+        converted_state["conv2.bias"] = state.pop(f"{names.conv1D_1}/1")
+
+    if names.conv1D_2 is not None:
+      converted_state["conv3.weight"] = state.pop(f"{names.conv1D_2}/0")
+      if f"{names.conv1D_2}/1" in state:
+        converted_state["conv3.bias"] = state.pop(f"{names.conv1D_2}/1")
+
+    if names.group_norm is not None:
+      group_norm_name = names.group_norm
+      converted_state["group_norm.weight"] = state.pop(f"{group_norm_name}/0")
+      if f"{group_norm_name}/1" in state:
+        converted_state["group_norm.bias"] = state.pop(f"{group_norm_name}/1")
+
+    return converted_state
ai_edge_torch/odml_torch/export.py
CHANGED

@@ -35,9 +35,7 @@ from . import lowerings
 LoweringContext = lowerings.context.LoweringContext


-def _build_flat_inputs(
-    ctx: ir.Context, exported_program: torch.export.ExportedProgram
-):
+def _build_flat_inputs(exported_program: torch.export.ExportedProgram):
   """Build flattened inputs and metadata from exported program's signature."""
   placeholder_nodes = [
       n for n in exported_program.graph.nodes if n.op == "placeholder"
@@ -49,9 +47,11 @@ def _build_flat_inputs(
   ir_inputs = []
   tensor_metas = []
   for node, arg in zip(placeholder_nodes, export_flat_args):
-    tensor_meta = node.meta.get("tensor_meta")
+    tensor_meta = node.meta.get("tensor_meta") or node.meta.get("val")
     if tensor_meta is None:
-      raise RuntimeError(
+      raise RuntimeError(
+          f"{type(arg)} (for {node.name}) does not have tensor meta"
+      )

     tensor_metas.append(tensor_meta)
     # Assume all dynamic dimensions are unbounded.
@@ -63,7 +63,7 @@ def _build_flat_inputs(
     ir_inputs.append(
         ir.RankedTensorType.get(
             shape,
-            export_utils.torch_dtype_to_ir_element_type(
+            export_utils.torch_dtype_to_ir_element_type(tensor_meta.dtype),
         )
     )
   return tuple(ir_inputs), tuple(export_flat_args), tuple(tensor_metas)
@@ -258,6 +258,43 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
       rewrite_arange(node)


+# TODO(b/331481564) Make this a ai_edge_torch FX pass.
+def _convert_q_dq_per_channel_args_to_list(
+    exported_program: torch.export.ExportedProgram,
+):
+  """Resolve tensor inputs to Q/DQ ops as static number list for lowering.
+
+  This pass makes the ExportedProgram in a non-executable state. This pass must
+  be run after all run_decompositions calls.
+  """
+  placeholder_nodes = [
+      n for n in exported_program.graph.nodes if n.op == "placeholder"
+  ]
+  export_flat_args = _torch_future.graph_module_flat_inputs(
+      exported_program, *exported_program.example_inputs
+  )
+
+  placeholder_tensor = {
+      n: tensor for n, tensor in zip(placeholder_nodes, export_flat_args)
+  }
+
+  graph_module = exported_program.graph_module
+  for node in graph_module.graph.nodes:
+    if node.target in (
+        torch.ops.quantized_decomposed.quantize_per_channel.default,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.dequantize_per_channel.default,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    ):
+      input, scale_node, zero_point_node = node.args[:3]
+      scale = placeholder_tensor[scale_node]
+      zero_point = placeholder_tensor[zero_point_node]
+
+      scale = scale.detach().numpy().tolist()
+      zero_point = zero_point.detach().numpy().tolist()
+      node.args = (input, scale, zero_point, *node.args[3:])
+
+
 def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
@@ -270,6 +307,7 @@ def exported_program_to_mlir(
   exported_program = _torch_future.safe_run_decompositions(
       exported_program, lowerings.decompositions()
   )
+  _convert_q_dq_per_channel_args_to_list(exported_program)

   with export_utils.create_ir_context() as context, ir.Location.unknown():

@@ -277,7 +315,7 @@ def exported_program_to_mlir(
   lctx = LoweringContext(context, module)
   interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
   ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
-
+      exported_program
   )

   # HACK: OSS MLIR pybinding could mysteriously transform func.func under
ai_edge_torch/odml_torch/export_utils.py
CHANGED

@@ -14,9 +14,9 @@
 # ==============================================================================
 """Utilities for ODML Torch export."""

-import functools
 import re
 from typing import Sequence, cast
+from ai_edge_torch.odml_torch.lowerings import utils as lowering_utils
 import jax._src.interpreters.mlir
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func
@@ -47,7 +47,6 @@ def create_ir_context():
   # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
   context = jax._src.interpreters.mlir.make_ir_context()
   context.allow_unregistered_dialects = True
-
   return context


@@ -135,17 +134,7 @@ def build_ir_attr(val):
   return ir.StringAttr.get(str(val))


-
-  ty_get = {
-      torch.double: ir.F64Type.get,
-      torch.float32: ir.F32Type.get,
-      torch.half: ir.F16Type.get,
-      torch.long: functools.partial(ir.IntegerType.get_signless, 64),
-      torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
-      torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
-      torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
-  }.get(dtype)
-  return ty_get(ctx)
+torch_dtype_to_ir_element_type = lowering_utils.torch_dtype_to_ir_element_type


 def ir_element_type_to_torch_dtype(ty):
ai_edge_torch/odml_torch/jax_bridge/_wrap.py
CHANGED

@@ -163,9 +163,7 @@ def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
     if aval is None:
       return result

-    target_elty = export_utils.torch_dtype_to_ir_element_type(
-        lctx.ir_context, aval.dtype
-    )
+    target_elty = export_utils.torch_dtype_to_ir_element_type(aval.dtype)
     if result.type.element_type == target_elty:
       return result
     return stablehlo.convert(
ai_edge_torch/odml_torch/lowerings/_basic.py
CHANGED

@@ -227,9 +227,7 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
   if not non_empty_tensors:
     return utils.splat(
         0,
-        export_utils.torch_dtype_to_ir_element_type(
-            lctx.ir_context, out_aval.dtype
-        ),
+        export_utils.torch_dtype_to_ir_element_type(out_aval.dtype),
         out_aval.shape,
     )
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py
ADDED

@@ -0,0 +1,174 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Lowerings for PT2E torch.ops.quantized_decomposed ops."""
+from typing import Union, cast
+
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import utils
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+import torch.ao.quantization.fx._decomposed
+import torch.utils._pytree as pytree
+
+from . import registry
+
+lower = registry.lower
+LoweringContext = context.LoweringContext
+
+
+def _uniform_quantized_type(
+    stored_type: str | ir.Type,
+    expressed_type: str | ir.Type,
+    *,
+    scale: float | list[float] | tuple[float],
+    zero_point: float | list[float] | tuple[float],
+    storage_type_min: int | None = None,
+    storage_type_max: int | None = None,
+    channel_axis: int | None = None,
+    channel_axis_size: int | None = None,
+):
+  """Polyfill for quant.UniformQuantizedType."""
+  if storage_type_min and storage_type_max:
+    storage_min_max = f"<{storage_type_min}:{storage_type_max}>"
+  else:
+    storage_min_max = ""
+
+  if channel_axis is not None:
+    # Per-channel quantization
+    # https://mlir.llvm.org/docs/Dialects/QuantDialect/#per-channel-quantization
+    assert isinstance(scale, (list, tuple))
+    assert isinstance(zero_point, (list, tuple))
+
+    if len(scale) == 1:
+      scale *= channel_axis_size
+    if len(zero_point) == 1:
+      zero_point *= channel_axis_size
+
+    assert len(scale) == len(zero_point) == channel_axis_size
+    scale_zp_strs = []
+    for s, zp in zip(scale, zero_point):
+      scale_zp_strs.append(f"{s}:{zp}")
+    scale_zp = "{" + ",".join(scale_zp_strs) + "}"
+    return ir.Type.parse(
+        f"!quant.uniform<{stored_type}{storage_min_max}:{expressed_type}:{channel_axis},{scale_zp}>"
+    )
+  else:
+    # Per-layer quantization
+    # https://mlir.llvm.org/docs/Dialects/QuantDialect/#per-layer-quantization
+    scale = pytree.tree_flatten([scale])[0][-1]
+    zero_point = pytree.tree_flatten([zero_point])[0][-1]
+    scale_zp = f"{scale}:{zero_point}"
+    return ir.Type.parse(
+        f"!quant.uniform<{stored_type}{storage_min_max}:{expressed_type},{scale_zp}>"
+    )
+
+
+# Quant dialect is not registered in the Python MLIR pybinding used by
+# odml-torch. Therefore, stablehlo.uniform_quantize/uniform_dequantize ops and
+# quant types are represented in stablehlo.custom_call to pass MLIR verification
+# and VHLO serialization before converter.
+# TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+
+
+# Schema:
+# - quantized_decomposed::quantize_per_tensor(Tensor input, float scale,
+#     int zero_point, int quant_min, int quant_max,
+#     ScalarType dtype) -> Tensor
+# - quantized_decomposed::quantize_per_tensor.tensor(Tensor input,
+#     Tensor scale, Tensor zero_point, int quant_min, int quant_max,
+#     ScalarType dtype) -> Tensor
+#
+# Scale and zero_point in tensors are automatically converted to list before
+# lowering.
+@lower(torch.ops.quantized_decomposed.quantize_per_tensor)
+def _quantize_per_tensor(
+    lctx: LoweringContext,
+    input: ir.Value,
+    scale: Union[float, list[float]],
+    zero_point: Union[float, list[float]],
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+):
+  input_ty = cast(ir.RankedTensorType, input.type)
+  qty = _uniform_quantized_type(
+      utils.torch_dtype_to_ir_element_type(dtype),
+      input_ty.element_type,
+      scale=scale,
+      zero_point=zero_point,
+      storage_type_min=quant_min,
+      storage_type_max=quant_max,
+  )
+  return stablehlo.custom_call(
+      call_target_name="odml_torch.uniform_quantize",
+      inputs=[input],
+      result=[input_ty],
+      backend_config=ir.StringAttr.get(
+          str(ir.RankedTensorType.get(input_ty.shape, qty))
+      ),
+  )
+
+
+# Schema:
+# - quantized_decomposed::quantize_per_channel(Tensor input, Tensor scales,
+#     Tensor zero_points, int axis, int quant_min, int quant_max,
+#     ScalarType dtype) -> Tensor
+#
+# Scale and zero_point in tensors are automatically converted to list before
+# lowering.
+@lower(torch.ops.quantized_decomposed.quantize_per_channel)
+def _quantize_per_channel(
+    lctx: LoweringContext,
+    input: ir.Value,
+    scale: list[float],
+    zero_point: list[float],
+    axis: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+):
+  input_ty = cast(ir.RankedTensorType, input.type)
+  qty = _uniform_quantized_type(
+      utils.torch_dtype_to_ir_element_type(dtype),
+      input_ty.element_type,
+      scale=scale,
+      zero_point=zero_point,
+      channel_axis=axis,
+      channel_axis_size=input_ty.shape[axis],
+      storage_type_min=quant_min,
+      storage_type_max=quant_max,
+  )
+  return stablehlo.custom_call(
+      call_target_name="odml_torch.uniform_quantize",
+      inputs=[input],
+      result=[input_ty],
+      backend_config=ir.StringAttr.get(
+          str(ir.RankedTensorType.get(input_ty.shape, qty))
+      ),
+  )
+
+
+@lower(torch.ops.quantized_decomposed.dequantize_per_tensor)
+@lower(torch.ops.quantized_decomposed.dequantize_per_channel)
+def _dequantize(lctx: LoweringContext, input: ir.Value, *args, **kwargs):
+  result_meta = lctx.node.meta.get("tensor_meta")
+  result_elty = utils.torch_dtype_to_ir_element_type(result_meta.dtype)
+
+  return stablehlo.custom_call(
+      call_target_name="odml_torch.uniform_dequantize",
+      inputs=[input],
+      result=[ir.RankedTensorType.get(result_meta.shape, result_elty)],
+  )
ai_edge_torch/odml_torch/lowerings/utils.py
CHANGED

@@ -14,6 +14,7 @@
 # ==============================================================================
 """Utilities for building MLIR lowerings."""

+import functools
 import numbers
 from typing import Any
 from typing import Optional
@@ -21,6 +22,21 @@ from typing import Optional
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo as stablehlo
 import numpy as np
+import torch
+
+
+def torch_dtype_to_ir_element_type(dtype):
+  ty_get = {
+      torch.double: ir.F64Type.get,
+      torch.float32: ir.F32Type.get,
+      torch.half: ir.F16Type.get,
+      torch.long: functools.partial(ir.IntegerType.get_signless, 64),
+      torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
+      torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
+      torch.int8: functools.partial(ir.IntegerType.get_signless, 8),
+      torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
+  }[dtype]
+  return ty_get()


 def splat(val, ty, shape=tuple(), *, loc: Optional[Any] = None):
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241129
+Version: 0.3.0.dev20241204
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/RECORD
CHANGED

@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=IfBDOY7eb9sSynpfr3Qsw88QCzs0DDWN_jB1B6zi5ss,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=
+ai_edge_torch/_convert/test/test_convert.py,sha256=v6AhfWqRBuHT7uBDueTbntaQtDSMMrvQOqlIDXNUaMA,17250
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -44,20 +44,22 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=bkq2Zk
 ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n79r6yFnCACpms5eMkXNpyQsCn2PYVRdB-jOoIqn14,2227
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma/
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=mrG96_WEGD4NQ4uFEKrHRMAQvVVliOcj1zbI3drGDjI,2199
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=I_tvwCYmtf08D1HqDxYx7dpvj2q5_eaYnuI_3rI6Dlw,2201
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=Brb83sbqBfStUiIZFhfWnYtN7LcNmkKyFn96cZK4sGo,2426
 ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
+ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
+ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=-qDBu3bjUq0jx73SPDMsPIBP0BT1nA_0UgtFKeSuM18,2213
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -69,18 +71,18 @@ ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=ruY-LLwpqBqVZ5z9h_sewYj04ukWRG4j804tUAyDdnA,2186
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=UdMk1SSkcWpv8gosUylx3JSCxdOJBjZNhuQQtT4-Ono,2184
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=nbivDwZREd-sypy_ittO59-yaAdPvHv1YEV6Fo5buCo,3341
 ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=1M3DTkf536TCLYcQB1lu-3TxQ6mV03dFhTdbk0p8i84,2523
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=oYm9hhALUQ4uOn-PO1bF7fCIGP8EWRNK4zClkx2RQs8,4070
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=56CzCjyp9xh_2ZtXKN9tlEv6GayeSc4giTIZsi2Q59E,2194
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZku9kgvmlFCyIBar3kF2XEk,2570
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -107,7 +109,7 @@ ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNH
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=rRodLr-hEqAs_-8x06O8qO-hJ_cqr2AfhJZ9DCptvwo,4332
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=WmEshoN9HgNLbV7NTjxdqWz9Olcim6r_vo4R9eYE98I,2228
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=10X8HwPx4akzclnIMOBNItKQemhRbvxBbTo7nwZtWjM,2650
 ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
@@ -139,13 +141,14 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FP
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=xWV9O2wuRHc4VNBWuWipiuqXa3AJhiV1nmjewAZHHWM,11177
-ai_edge_torch/generative/test/test_quantize.py,sha256=
+ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
+ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
@@ -167,8 +170,8 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
-ai_edge_torch/odml_torch/export_utils.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=dgnNGBVkHBz0brlWALX2hGXpQ4YzCKdwbkF4oAfEu4I,13062
+ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
@@ -177,17 +180,18 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
 ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
-ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=
+ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=
-ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=1lMKPoStK3SUA8yYTPZBRhESN33BghGXnfqOOg4oeVk,995
+ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5lywQ1_RfLIxVE7J_GONI6CQvI,10718
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
+ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=itTt8MLbq2LoHTzRidCF2TTbh0TP7L836u99qCjP3FA,2953
-ai_edge_torch/odml_torch/lowerings/utils.py,sha256=
+ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
 ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
@@ -196,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/METADATA,sha256=jcbPyL8PaZxa79FSq_FwD0sfu-2R_nTRMDJiPNmJXFM,1897
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/RECORD,,