ai-edge-torch-nightly 0.3.0.dev20241129__py3-none-any.whl → 0.3.0.dev20241204__py3-none-any.whl
- ai_edge_torch/_convert/test/test_convert.py +48 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
- ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → moonshine/convert_moonshine_to_tflite.py} +11 -29
- ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +9 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +8 -6
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -6
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +8 -6
- ai_edge_torch/generative/test/test_quantize.py +5 -0
- ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
- ai_edge_torch/odml_torch/export.py +45 -7
- ai_edge_torch/odml_torch/export_utils.py +2 -13
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +1 -3
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +1 -3
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +174 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +16 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/RECORD +28 -24
- {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -21,10 +21,12 @@ from typing import Tuple
 import ai_edge_torch
 from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
+from ai_edge_torch.quantize import pt2e_quantizer
 from ai_edge_torch.testing import model_coverage
 import numpy as np
 import torch
 from torch import nn
+from torch.ao.quantization import quantize_pt2e
 import torchvision
 
 from absl.testing import absltest as googletest
@@ -506,6 +508,52 @@ class TestConvert(googletest.TestCase):
         model_coverage.compare_tflite_torch(edge_model, torch_module, args)
     )
 
+  def test_convert_resnet18_pt2e_per_layer(self):
+    # Step 1: export resnet18
+    args = (torch.randn(1, 3, 224, 224),)
+    m = torchvision.models.resnet18().eval()
+    m = torch._export.capture_pre_autograd_graph(m, args)
+
+    # Step 2: Insert observers or fake quantize modules
+    quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
+        pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=False)
+    )
+    m = quantize_pt2e.prepare_pt2e(m, quantizer)
+
+    # Step 3: Quantize the model
+    m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)
+
+    # pylint: disable=broad-except
+    try:
+      ai_edge_torch.convert(m, args)
+    except Exception as err:
+      self.fail(f"PT2E conversion failed: {err}")
+    # pylint: enable=broad-except
+
+  def test_convert_resnet18_pt2e_per_channel(self):
+    # Step 1: export resnet18
+    args = (torch.randn(1, 3, 224, 224),)
+    m = torchvision.models.resnet18().eval()
+    m = torch._export.capture_pre_autograd_graph(m, args)
+
+    # Step 2: Insert observers or fake quantize modules
+    quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
+        pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=True)
+    )
+    m = quantize_pt2e.prepare_pt2e(m, quantizer)
+    # Step 3: Run through example inputs, otherwise per-channel
+    # quant may have scalar scale/zero_point
+    m(*args)
+    # Step 4: Quantize the model
+    m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)
+
+    # pylint: disable=broad-except
+    try:
+      ai_edge_torch.convert(m, args)
+    except Exception as err:
+      self.fail(f"PT2E conversion failed: {err}")
+    # pylint: enable=broad-except
+
 
 if __name__ == "__main__":
   googletest.main()
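
Taken out of the test class, the PT2E quantize-then-convert sequence these tests exercise looks like this (a minimal sketch reusing only calls that appear in the diff above):

    import ai_edge_torch
    from ai_edge_torch.quantize import pt2e_quantizer
    import torch
    from torch.ao.quantization import quantize_pt2e
    import torchvision

    # 1. Capture the FP32 model ahead of autograd.
    args = (torch.randn(1, 3, 224, 224),)
    m = torchvision.models.resnet18().eval()
    m = torch._export.capture_pre_autograd_graph(m, args)

    # 2. Insert observers, calibrate, then materialize Q/DQ ops. The
    #    calibration call m(*args) matters for per-channel quantization,
    #    otherwise scales/zero-points may stay scalar.
    quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
        pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=True)
    )
    m = quantize_pt2e.prepare_pt2e(m, quantizer)
    m(*args)
    m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)

    # 3. Convert; fold_quantize=False keeps the explicit quantized_decomposed
    #    ops that the new odml_torch lowerings below handle.
    edge_model = ai_edge_torch.convert(m, args)
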
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
 
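Because `prefill_seq_lens` is an absl `DEFINE_multi_integer` flag, callers override the default list by repeating the flag once per prefill length (hypothetical invocation):

    python convert_gemma1_to_tflite.py --prefill_seq_lens=256 --prefill_seq_lens=1024

The same single-flag-to-multi-flag change is applied to each of the converter scripts below.
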
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
 converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
 
ai_edge_torch/generative/examples/llama/convert_to_tflite.py CHANGED
@@ -39,10 +39,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -66,11 +66,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
 converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
 
ai_edge_torch/generative/examples/moonshine/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → moonshine/convert_moonshine_to_tflite.py} RENAMED
@@ -13,19 +13,21 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Example
+"""Example of converting a Moonshine model to multi-signature tflite model."""
 
 import os
 import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma2
+import ai_edge_torch
+from ai_edge_torch.generative.examples.moonshine import moonshine
 from ai_edge_torch.generative.utilities import converter
+import torch
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/moonshine'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -33,35 +35,15 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
 
 
 def main(_):
-  pytorch_model = gemma2.build_2b_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
-  )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  converter.convert_to_tflite(
-      pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-  )
+  p_model = moonshine.build_preprocessor(_CHECKPOINT_PATH.value)
+  output_filename = f'moonshine_preprocessor.tflite'
+  _input = torch.randn((1, 1, 159414), dtype=torch.float)
+  edge_model = ai_edge_torch.convert(p_model, (_input,), quant_config=None)
+  tflite_path = os.path.join(_TFLITE_PATH.value, output_filename)
+  edge_model.export(tflite_path)
 
 
 if __name__ == '__main__':
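
A typical run of the rewritten script (hypothetical local paths) needs only the two path flags and writes a single-signature `moonshine_preprocessor.tflite`:

    python convert_moonshine_to_tflite.py \
        --checkpoint_path ~/Downloads/llm_data/moonshine \
        --tflite_path /tmp/
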
ai_edge_torch/generative/examples/moonshine/moonshine.py ADDED
@@ -0,0 +1,103 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building the Moonshine model."""
+
+import os
+import pathlib
+from typing import Optional, Tuple
+from absl import app
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.normalization as normalization
+import ai_edge_torch.generative.utilities.moonshine_loader as loading_utils
+import h5py
+import torch
+from torch import nn
+import torch.nn as nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    conv1D_0="layers/sequential/layers/conv1d/vars",
+    conv1D_1="layers/sequential/layers/conv1d_1/vars",
+    conv1D_2="layers/sequential/layers/conv1d_2/vars",
+    group_norm="layers/sequential/layers/group_normalization/vars",
+)
+
+
+class AudioPreprocessor(nn.Module):
+
+  def __init__(self, dim):
+    super(AudioPreprocessor, self).__init__()
+    self.conv1 = nn.Conv1d(
+        in_channels=1, out_channels=dim, kernel_size=127, stride=64, bias=False
+    )
+    self.tanh = nn.Tanh()
+    self.group_norm = normalization.GroupNorm(group_num=1, dim=dim, eps=1e-5)
+    self.conv2 = nn.Conv1d(
+        in_channels=dim,
+        out_channels=2 * dim,
+        kernel_size=7,
+        stride=3,
+        padding=0,  # Equivalent to padding="valid"
+    )
+    self.gelu1 = nn.GELU()
+    self.conv3 = nn.Conv1d(
+        in_channels=2 * dim,
+        out_channels=dim,
+        kernel_size=3,
+        stride=2,
+        padding=0,  # Equivalent to padding="valid"
+    )
+    self.gelu2 = nn.GELU()
+
+  def forward(self, inputs):
+    x = self.conv1(inputs)
+    x = self.tanh(x)
+    x = self.group_norm(x)
+    x = self.conv2(x)
+    x = self.gelu1(x)
+    x = self.conv3(x)
+    x = self.gelu2(x)
+    return x
+
+
+def build_preprocessor(checkpoint_path: str, **kwargs) -> nn.Module:
+  ap = AudioPreprocessor(dim=416)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(ap, strict=True)
+  return ap
+
+
+def main(_):
+  # TODO(b/375421767) Remove golden checks once full model is implemented.
+  HF_PATH = os.path.join(pathlib.Path.home(), "Downloads/llm_data/moonshine")
+
+  test_data_path = pathlib.Path(__file__).parent.resolve()
+  INPUT_PATH = test_data_path / "data" / "pp_input.pt"
+  GOLDEN_PATH = test_data_path / "data" / "pp_output.pt"
+
+  ap = build_preprocessor(HF_PATH)
+  ap.eval()
+  inputs = torch.load(INPUT_PATH).reshape((1, 1, 159414))
+  out = ap(inputs)
+  golden = torch.load(GOLDEN_PATH).transpose(1, 2)
+  assert torch.allclose(out, golden, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__":
+  app.run(main)
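
The golden check pins down the preprocessor's shape contract. With `padding=0` each `nn.Conv1d` emits `(L - kernel_size) // stride + 1` frames, so the fixed 159414-sample input works out to:

    conv1 (kernel=127, stride=64): (159414 - 127) // 64 + 1 = 2489
    conv2 (kernel=7,   stride=3):  (2489 - 7)     // 3  + 1 = 828
    conv3 (kernel=3,   stride=2):  (828 - 3)      // 2  + 1 = 413

so `ap(inputs)` is a (1, 416, 413) channel-first tensor, which is why the golden output is transposed with `.transpose(1, 2)` before the `allclose` comparison.
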
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,14 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = (
+      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  )
+
 converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
 
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
 converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
 
ai_edge_torch/generative/examples/phi/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
 converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
 
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py CHANGED
@@ -39,10 +39,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -68,11 +68,13 @@ def main(_):
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
   model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = f'qwen_{model_size}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = (
+      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  )
 converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
 
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
 converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
 
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,13 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = (
+      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  )
 converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
 
ai_edge_torch/generative/test/test_quantize.py CHANGED
@@ -91,6 +91,11 @@ class TestVerifyRecipes(parameterized.TestCase):
 class TestQuantizeConvert(parameterized.TestCase):
   """Test conversion with quantization."""
 
+  def setUp(self):
+    super().setUp()
+    torch.manual_seed(0)
+    torch._dynamo.reset()
+
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(

ai_edge_torch/generative/utilities/moonshine_loader.py ADDED
@@ -0,0 +1,154 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Common utility functions for data loading etc.
+from dataclasses import dataclass
+import glob
+import os
+from typing import Callable, Dict
+
+import h5py
+import torch
+
+
+def transpose_if_needed(t):
+  """We assume the file is from Keras, i.e. channel last format."""
+  if len(t.shape) > 2:
+    return t.permute(2, 1, 0)
+  return t
+
+
+def load_h5_statedict(full_path: str):
+  """Loads the HDF5 DataSets into a single dictionary.
+
+  Args:
+    full_path (string): the HDF5 filename or directory that contains the HDF5
+      files.
+
+  Returns:
+    A state dictionary containing loaded tensors.
+
+  Raises:
+    ValueError: If no tensors are loaded from the provided directory or file.
+  """
+  pattern = (
+      os.path.join(full_path, "*.h5") if os.path.isdir(full_path) else full_path
+  )
+  files = []
+  for file in glob.glob(pattern):
+    files.append(file)
+
+  tensors = {}
+
+  def collect_datasets(name, obj):
+    if isinstance(obj, h5py.Dataset):
+      tensors[name] = transpose_if_needed(torch.from_numpy(obj[:]))
+
+  for file in files:
+    with h5py.File(file) as f:
+      f.visititems(collect_datasets)
+
+  if not tensors:
+    raise ValueError("Failed to load HDF5 file.")
+  return tensors
+
+
+class ModelLoader:
+  """Utility class for loading and converting checkpoints to ODML transformer layer format."""
+
+  @dataclass
+  class TensorNames:
+    conv1D_0: str = None
+    conv1D_1: str = None
+    conv1D_2: str = None
+    group_norm: str = None
+
+  def __init__(self, file_name: str, names: TensorNames) -> None:
+    """ModelLoader constructor.
+
+    Can be used to load multiple models of the same type.
+
+    Args:
+      file_name (str): Path to the checkpoint. Can be a directory or an exact
+        file.
+      names (TensorNames): An instance of `TensorNames` to determine mappings.
+    """
+    self._file_name = file_name
+    self._names = names
+    self._loader = load_h5_statedict
+
+  def load(
+      self,
+      model: torch.nn.Module,
+      strict: bool = True,
+  ):
+    """Load the model from the checkpoint.
+
+    Args:
+      model (torch.nn.Module): The pytorch model that needs to be loaded.
+      strict (bool, optional): Whether the converted keys are strictly
+        matched. Defaults to True.
+
+    Raises:
+      ValueError: If conversion results in unmapped tensors and strict mode is
+        enabled.
+    """
+    state = self._loader(self._file_name)
+
+    if isinstance(self._names, ModelLoader.TensorNames):
+      converted_state = self._do_load(model, state, self._names)
+    else:
+      raise ValueError(f"Unknown type for names: {type(self._names)}")
+
+    if strict and state:
+      raise ValueError(
+          "Failed to map all tensors. Remaining tensors are:"
+          f" {list(state.keys())}"
+      )
+    model.load_state_dict(converted_state, strict=strict)
+
+  def _do_load(self, model, state, names, additional_prefix=""):
+    """Load the model from the checkpoint.
+
+    Args:
+      model (torch.nn.Module): The pytorch model that needs to be loaded.
+      state (Dict[str, torch.Tensor]): The pytorch state dictionary.
+      names (TensorNames): The TensorNames for the model we are loading.
+
+    Returns:
+      Dict[str, torch.Tensor]: Map of name to tensor for loading.
+    """
+    converted_state = dict()
+    if names.conv1D_0 is not None:
+      converted_state["conv1.weight"] = state.pop(f"{names.conv1D_0}/0")
+      if f"{names.conv1D_0}/1" in state:
+        converted_state["conv1.bias"] = state.pop(f"{names.conv1D_0}/1")
+
+    if names.conv1D_1 is not None:
+      converted_state["conv2.weight"] = state.pop(f"{names.conv1D_1}/0")
+      if f"{names.conv1D_1}/1" in state:
+        converted_state["conv2.bias"] = state.pop(f"{names.conv1D_1}/1")
+
+    if names.conv1D_2 is not None:
+      converted_state["conv3.weight"] = state.pop(f"{names.conv1D_2}/0")
+      if f"{names.conv1D_2}/1" in state:
+        converted_state["conv3.bias"] = state.pop(f"{names.conv1D_2}/1")
+
+    if names.group_norm is not None:
+      group_norm_name = names.group_norm
+      converted_state[f"group_norm.weight"] = state.pop(f"{group_norm_name}/0")
+      if f"{group_norm_name}/1" in state:
+        converted_state["group_norm.bias"] = state.pop(f"{group_norm_name}/1")
+
+    return converted_state
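
The `permute(2, 1, 0)` in `transpose_if_needed` bridges the two frameworks' kernel layouts: Keras stores `Conv1D` kernels as (kernel_size, in_channels, out_channels) while `nn.Conv1d.weight` is (out_channels, in_channels, kernel_size). A shape-only sketch for the first Moonshine conv:

    import torch

    w_keras = torch.zeros(127, 1, 416)  # Keras layout: (kernel, in, out)
    w_torch = w_keras.permute(2, 1, 0)  # (416, 1, 127)
    assert w_torch.shape == torch.nn.Conv1d(1, 416, 127).weight.shape
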
ai_edge_torch/odml_torch/export.py CHANGED
@@ -35,9 +35,7 @@ from . import lowerings
 LoweringContext = lowerings.context.LoweringContext
 
 
-def _build_flat_inputs(
-    ctx: ir.Context, exported_program: torch.export.ExportedProgram
-):
+def _build_flat_inputs(exported_program: torch.export.ExportedProgram):
   """Build flattened inputs and metadata from exported program's signature."""
   placeholder_nodes = [
       n for n in exported_program.graph.nodes if n.op == "placeholder"
@@ -49,9 +47,11 @@ def _build_flat_inputs(
   ir_inputs = []
   tensor_metas = []
   for node, arg in zip(placeholder_nodes, export_flat_args):
-    tensor_meta = node.meta.get("tensor_meta")
+    tensor_meta = node.meta.get("tensor_meta") or node.meta.get("val")
     if tensor_meta is None:
-      raise RuntimeError(f"{type(arg)} does not have tensor meta")
+      raise RuntimeError(
+          f"{type(arg)} (for {node.name}) does not have tensor meta"
+      )
 
     tensor_metas.append(tensor_meta)
     # Assume all dynamic dimensions are unbounded.
@@ -63,7 +63,7 @@ def _build_flat_inputs(
     ir_inputs.append(
         ir.RankedTensorType.get(
             shape,
-            export_utils.torch_dtype_to_ir_element_type(ctx, tensor_meta.dtype),
+            export_utils.torch_dtype_to_ir_element_type(tensor_meta.dtype),
         )
     )
   return tuple(ir_inputs), tuple(export_flat_args), tuple(tensor_metas)
@@ -258,6 +258,43 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
       rewrite_arange(node)
 
 
+# TODO(b/331481564) Make this a ai_edge_torch FX pass.
+def _convert_q_dq_per_channel_args_to_list(
+    exported_program: torch.export.ExportedProgram,
+):
+  """Resolve tensor inputs to Q/DQ ops as static number list for lowering.
+
+  This pass leaves the ExportedProgram in a non-executable state. This pass
+  must be run after all run_decompositions calls.
+  """
+  placeholder_nodes = [
+      n for n in exported_program.graph.nodes if n.op == "placeholder"
+  ]
+  export_flat_args = _torch_future.graph_module_flat_inputs(
+      exported_program, *exported_program.example_inputs
+  )
+
+  placeholder_tensor = {
+      n: tensor for n, tensor in zip(placeholder_nodes, export_flat_args)
+  }
+
+  graph_module = exported_program.graph_module
+  for node in graph_module.graph.nodes:
+    if node.target in (
+        torch.ops.quantized_decomposed.quantize_per_channel.default,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.dequantize_per_channel.default,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    ):
+      input, scale_node, zero_point_node = node.args[:3]
+      scale = placeholder_tensor[scale_node]
+      zero_point = placeholder_tensor[zero_point_node]
+
+      scale = scale.detach().numpy().tolist()
+      zero_point = zero_point.detach().numpy().tolist()
+      node.args = (input, scale, zero_point, *node.args[3:])
+
+
 def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
@@ -270,6 +307,7 @@ def exported_program_to_mlir(
   exported_program = _torch_future.safe_run_decompositions(
       exported_program, lowerings.decompositions()
   )
+  _convert_q_dq_per_channel_args_to_list(exported_program)
 
   with export_utils.create_ir_context() as context, ir.Location.unknown():
 
@@ -277,7 +315,7 @@ def exported_program_to_mlir(
   lctx = LoweringContext(context, module)
   interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
   ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
-      context, exported_program
+      exported_program
   )
 
   # HACK: OSS MLIR pybinding could mysteriously transform func.func under
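
The effect of the new pass on a matched node is just an argument rewrite: scale/zero-point placeholder tensors become Python lists the lowering can fold into the quantized MLIR type. A standalone sketch of the per-node conversion (hypothetical values):

    import torch

    scale = torch.tensor([0.5, 0.25])  # per-channel scales as a placeholder tensor
    zero_point = torch.tensor([0, 0])

    scale_list = scale.detach().numpy().tolist()            # [0.5, 0.25]
    zero_point_list = zero_point.detach().numpy().tolist()  # [0, 0]
    # node.args = (input, scale_list, zero_point_list, *node.args[3:])

After this rewrite the op schemas no longer match their arguments (they still expect tensors), which is why the pass must run after every run_decompositions call.
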
@@ -14,9 +14,9 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Utilities for ODML Torch export."""
|
16
16
|
|
17
|
-
import functools
|
18
17
|
import re
|
19
18
|
from typing import Sequence, cast
|
19
|
+
from ai_edge_torch.odml_torch.lowerings import utils as lowering_utils
|
20
20
|
import jax._src.interpreters.mlir
|
21
21
|
from jax._src.lib.mlir import ir
|
22
22
|
from jax._src.lib.mlir.dialects import func
|
@@ -47,7 +47,6 @@ def create_ir_context():
|
|
47
47
|
# TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
|
48
48
|
context = jax._src.interpreters.mlir.make_ir_context()
|
49
49
|
context.allow_unregistered_dialects = True
|
50
|
-
|
51
50
|
return context
|
52
51
|
|
53
52
|
|
@@ -135,17 +134,7 @@ def build_ir_attr(val):
|
|
135
134
|
return ir.StringAttr.get(str(val))
|
136
135
|
|
137
136
|
|
138
|
-
|
139
|
-
ty_get = {
|
140
|
-
torch.double: ir.F64Type.get,
|
141
|
-
torch.float32: ir.F32Type.get,
|
142
|
-
torch.half: ir.F16Type.get,
|
143
|
-
torch.long: functools.partial(ir.IntegerType.get_signless, 64),
|
144
|
-
torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
|
145
|
-
torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
|
146
|
-
torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
|
147
|
-
}.get(dtype)
|
148
|
-
return ty_get(ctx)
|
137
|
+
torch_dtype_to_ir_element_type = lowering_utils.torch_dtype_to_ir_element_type
|
149
138
|
|
150
139
|
|
151
140
|
def ir_element_type_to_torch_dtype(ty):
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py CHANGED
@@ -163,9 +163,7 @@ def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
     if aval is None:
       return result
 
-    target_elty = export_utils.torch_dtype_to_ir_element_type(
-        lctx.ir_context, aval.dtype
-    )
+    target_elty = export_utils.torch_dtype_to_ir_element_type(aval.dtype)
     if result.type.element_type == target_elty:
       return result
     return stablehlo.convert(
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -227,9 +227,7 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
   if not non_empty_tensors:
     return utils.splat(
         0,
-        export_utils.torch_dtype_to_ir_element_type(
-            lctx.ir_context, out_aval.dtype
-        ),
+        export_utils.torch_dtype_to_ir_element_type(out_aval.dtype),
         out_aval.shape,
     )
 
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py ADDED
@@ -0,0 +1,174 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Lowerings for PT2E torch.ops.quantized_decomposed ops."""
+from typing import Union, cast
+
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import utils
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+import torch.ao.quantization.fx._decomposed
+import torch.utils._pytree as pytree
+
+from . import registry
+
+lower = registry.lower
+LoweringContext = context.LoweringContext
+
+
+def _uniform_quantized_type(
+    stored_type: str | ir.Type,
+    expressed_type: str | ir.Type,
+    *,
+    scale=float | list[float] | tuple[float],
+    zero_point=float | list[float] | tuple[float],
+    storage_type_min: int | None = None,
+    storage_type_max: int | None = None,
+    channel_axis: int | None = None,
+    channel_axis_size: int | None = None,
+):
+  """Polyfill for quant.UniformQuantizedType."""
+  if storage_type_min and storage_type_max:
+    storage_min_max = f"<{storage_type_min}:{storage_type_max}>"
+  else:
+    storage_min_max = ""
+
+  if channel_axis is not None:
+    # Per-channel quantization
+    # https://mlir.llvm.org/docs/Dialects/QuantDialect/#per-channel-quantization
+    assert isinstance(scale, (list, tuple))
+    assert isinstance(zero_point, (list, tuple))
+
+    if len(scale) == 1:
+      scale *= channel_axis_size
+    if len(zero_point) == 1:
+      zero_point *= channel_axis_size
+
+    assert len(scale) == len(zero_point) == channel_axis_size
+    scale_zp_strs = []
+    for s, zp in zip(scale, zero_point):
+      scale_zp_strs.append(f"{s}:{zp}")
+    scale_zp = "{" + ",".join(scale_zp_strs) + "}"
+    return ir.Type.parse(
+        f"!quant.uniform<{stored_type}{storage_min_max}:{expressed_type}:{channel_axis},{scale_zp}>"
+    )
+  else:
+    # Per-layer quantization
+    # https://mlir.llvm.org/docs/Dialects/QuantDialect/#per-layer-quantization
+    scale = pytree.tree_flatten([scale])[0][-1]
+    zero_point = pytree.tree_flatten([zero_point])[0][-1]
+    scale_zp = f"{scale}:{zero_point}"
+    return ir.Type.parse(
+        f"!quant.uniform<{stored_type}{storage_min_max}:{expressed_type},{scale_zp}>"
+    )
+
+
+# Quant dialect is not registered in the Python MLIR pybinding used by
+# odml-torch. Therefore, stablehlo.uniform_quantize/uniform_dequantize ops and
+# quant types are represented in stablehlo.custom_call to pass MLIR verification
+# and VHLO serialization before converter.
+# TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+
+
+# Schema:
+#   - quantized_decomposed::quantize_per_tensor(Tensor input, float scale,
+#       int zero_point, int quant_min, int quant_max,
+#       ScalarType dtype) -> Tensor
+#   - quantized_decomposed::quantize_per_tensor.tensor(Tensor input,
+#       Tensor scale, Tensor zero_point, int quant_min, int quant_max,
+#       ScalarType dtype) -> Tensor
+#
+# Scale and zero_point in tensors are automatically converted to list before
+# lowering.
+@lower(torch.ops.quantized_decomposed.quantize_per_tensor)
+def _quantize_per_tensor(
+    lctx: LoweringContext,
+    input: ir.Value,
+    scale: Union[float, list[float]],
+    zero_point: Union[float, list[float]],
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+):
+  input_ty = cast(ir.RankedTensorType, input.type)
+  qty = _uniform_quantized_type(
+      utils.torch_dtype_to_ir_element_type(dtype),
+      input_ty.element_type,
+      scale=scale,
+      zero_point=zero_point,
+      storage_type_min=quant_min,
+      storage_type_max=quant_max,
+  )
+  return stablehlo.custom_call(
+      call_target_name="odml_torch.uniform_quantize",
+      inputs=[input],
+      result=[input_ty],
+      backend_config=ir.StringAttr.get(
+          str(ir.RankedTensorType.get(input_ty.shape, qty))
+      ),
+  )
+
+
+# Schema:
+#   - quantized_decomposed::quantize_per_channel(Tensor input, Tensor scales,
+#       Tensor zero_points, int axis, int quant_min, int quant_max,
+#       ScalarType dtype) -> Tensor
+#
+# Scale and zero_point in tensors are automatically converted to list before
+# lowering.
+@lower(torch.ops.quantized_decomposed.quantize_per_channel)
+def _quantize_per_channel(
+    lctx: LoweringContext,
+    input: ir.Value,
+    scale: list[float],
+    zero_point: list[float],
+    axis: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+):
+  input_ty = cast(ir.RankedTensorType, input.type)
+  qty = _uniform_quantized_type(
+      utils.torch_dtype_to_ir_element_type(dtype),
+      input_ty.element_type,
+      scale=scale,
+      zero_point=zero_point,
+      channel_axis=axis,
+      channel_axis_size=input_ty.shape[axis],
+      storage_type_min=quant_min,
+      storage_type_max=quant_max,
+  )
+  return stablehlo.custom_call(
+      call_target_name="odml_torch.uniform_quantize",
+      inputs=[input],
+      result=[input_ty],
+      backend_config=ir.StringAttr.get(
+          str(ir.RankedTensorType.get(input_ty.shape, qty))
+      ),
+  )
+
+
+@lower(torch.ops.quantized_decomposed.dequantize_per_tensor)
+@lower(torch.ops.quantized_decomposed.dequantize_per_channel)
+def _dequantize(lctx: LoweringContext, input: ir.Value, *args, **kwargs):
+  result_meta = lctx.node.meta.get("tensor_meta")
+  result_elty = utils.torch_dtype_to_ir_element_type(result_meta.dtype)
+
+  return stablehlo.custom_call(
+      call_target_name="odml_torch.uniform_dequantize",
+      inputs=[input],
+      result=[ir.RankedTensorType.get(result_meta.shape, result_elty)],
+  )
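
For concreteness, `_uniform_quantized_type` emits textual quant types like the following (illustrative i8/f32 values), which ride along as the `backend_config` string of the `odml_torch.uniform_quantize` custom call:

    Per-tensor:  !quant.uniform<i8<-128:127>:f32,0.5:0>
    Per-channel: !quant.uniform<i8<-128:127>:f32:0,{0.5:0,0.25:0}>
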
ai_edge_torch/odml_torch/lowerings/utils.py CHANGED
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Utilities for building MLIR lowerings."""
 
+import functools
 import numbers
 from typing import Any
 from typing import Optional
@@ -21,6 +22,21 @@ from typing import Optional
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo as stablehlo
 import numpy as np
+import torch
+
+
+def torch_dtype_to_ir_element_type(dtype):
+  ty_get = {
+      torch.double: ir.F64Type.get,
+      torch.float32: ir.F32Type.get,
+      torch.half: ir.F16Type.get,
+      torch.long: functools.partial(ir.IntegerType.get_signless, 64),
+      torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
+      torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
+      torch.int8: functools.partial(ir.IntegerType.get_signless, 8),
+      torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
+  }[dtype]
+  return ty_get()
 
 
 def splat(val, ty, shape=tuple(), *, loc: Optional[Any] = None):
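
With the explicit context argument gone, lowerings fetch element types directly; the lookup still relies on an ambient MLIR context, so a usage sketch (assuming it runs inside `export_utils.create_ir_context()`, as `exported_program_to_mlir` does) is:

    from ai_edge_torch.odml_torch.lowerings import utils
    import torch

    # Must execute while an ir.Context is active.
    elty = utils.torch_dtype_to_ir_element_type(torch.int8)  # signless i8; new entry in this release
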
ai_edge_torch/version.py CHANGED
-__version__ = "0.3.0.dev20241129"
+__version__ = "0.3.0.dev20241204"

{ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241129
+Version: 0.3.0.dev20241204
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=IfBDOY7eb9sSynpfr3Qsw88QCzs0DDWN_jB1B6zi5ss,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=
+ai_edge_torch/_convert/test/test_convert.py,sha256=v6AhfWqRBuHT7uBDueTbntaQtDSMMrvQOqlIDXNUaMA,17250
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -44,20 +44,22 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=bkq2Zk
 ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n79r6yFnCACpms5eMkXNpyQsCn2PYVRdB-jOoIqn14,2227
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py,sha256=
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=mrG96_WEGD4NQ4uFEKrHRMAQvVVliOcj1zbI3drGDjI,2199
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=I_tvwCYmtf08D1HqDxYx7dpvj2q5_eaYnuI_3rI6Dlw,2201
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=Brb83sbqBfStUiIZFhfWnYtN7LcNmkKyFn96cZK4sGo,2426
 ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
+ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
+ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=-qDBu3bjUq0jx73SPDMsPIBP0BT1nA_0UgtFKeSuM18,2213
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -69,18 +71,18 @@ ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=ruY-LLwpqBqVZ5z9h_sewYj04ukWRG4j804tUAyDdnA,2186
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=UdMk1SSkcWpv8gosUylx3JSCxdOJBjZNhuQQtT4-Ono,2184
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=nbivDwZREd-sypy_ittO59-yaAdPvHv1YEV6Fo5buCo,3341
 ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=1M3DTkf536TCLYcQB1lu-3TxQ6mV03dFhTdbk0p8i84,2523
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=oYm9hhALUQ4uOn-PO1bF7fCIGP8EWRNK4zClkx2RQs8,4070
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=56CzCjyp9xh_2ZtXKN9tlEv6GayeSc4giTIZsi2Q59E,2194
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZku9kgvmlFCyIBar3kF2XEk,2570
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -107,7 +109,7 @@ ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNH
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=rRodLr-hEqAs_-8x06O8qO-hJ_cqr2AfhJZ9DCptvwo,4332
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=WmEshoN9HgNLbV7NTjxdqWz9Olcim6r_vo4R9eYE98I,2228
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=10X8HwPx4akzclnIMOBNItKQemhRbvxBbTo7nwZtWjM,2650
 ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
@@ -139,13 +141,14 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FP
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=xWV9O2wuRHc4VNBWuWipiuqXa3AJhiV1nmjewAZHHWM,11177
-ai_edge_torch/generative/test/test_quantize.py,sha256=
+ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
+ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
@@ -167,8 +170,8 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
-ai_edge_torch/odml_torch/export_utils.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=dgnNGBVkHBz0brlWALX2hGXpQ4YzCKdwbkF4oAfEu4I,13062
+ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
@@ -177,17 +180,18 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
 ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
-ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=
+ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=
-ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=1lMKPoStK3SUA8yYTPZBRhESN33BghGXnfqOOg4oeVk,995
+ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5lywQ1_RfLIxVE7J_GONI6CQvI,10718
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
+ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=itTt8MLbq2LoHTzRidCF2TTbh0TP7L836u99qCjP3FA,2953
-ai_edge_torch/odml_torch/lowerings/utils.py,sha256=
+ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
 ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
@@ -196,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241129.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241129.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.3.0.dev20241129.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241129.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241129.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/METADATA,sha256=jcbPyL8PaZxa79FSq_FwD0sfu-2R_nTRMDJiPNmJXFM,1897
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/RECORD,,