ai-edge-torch-nightly 0.3.0.dev20241201__py3-none-any.whl → 0.3.0.dev20241205__py3-none-any.whl

Files changed (28)
  1. ai_edge_torch/_convert/test/test_convert.py +48 -0
  2. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +6 -6
  3. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -6
  4. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +6 -6
  5. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → moonshine/convert_moonshine_to_tflite.py} +11 -29
  7. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  8. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +9 -6
  9. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +6 -6
  10. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -6
  11. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +8 -6
  12. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -6
  13. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +8 -6
  14. ai_edge_torch/generative/test/test_quantize.py +5 -0
  15. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  16. ai_edge_torch/odml_torch/export.py +45 -7
  17. ai_edge_torch/odml_torch/export_utils.py +2 -13
  18. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +1 -3
  19. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  20. ai_edge_torch/odml_torch/lowerings/_basic.py +1 -3
  21. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +174 -0
  22. ai_edge_torch/odml_torch/lowerings/utils.py +16 -0
  23. ai_edge_torch/version.py +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20241201.dist-info → ai_edge_torch_nightly-0.3.0.dev20241205.dist-info}/METADATA +1 -1
  25. {ai_edge_torch_nightly-0.3.0.dev20241201.dist-info → ai_edge_torch_nightly-0.3.0.dev20241205.dist-info}/RECORD +28 -24
  26. {ai_edge_torch_nightly-0.3.0.dev20241201.dist-info → ai_edge_torch_nightly-0.3.0.dev20241205.dist-info}/LICENSE +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20241201.dist-info → ai_edge_torch_nightly-0.3.0.dev20241205.dist-info}/WHEEL +0 -0
  28. {ai_edge_torch_nightly-0.3.0.dev20241201.dist-info → ai_edge_torch_nightly-0.3.0.dev20241205.dist-info}/top_level.txt +0 -0

ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -21,10 +21,12 @@ from typing import Tuple
  import ai_edge_torch
  from ai_edge_torch import config
  from ai_edge_torch._convert import conversion_utils
+ from ai_edge_torch.quantize import pt2e_quantizer
  from ai_edge_torch.testing import model_coverage
  import numpy as np
  import torch
  from torch import nn
+ from torch.ao.quantization import quantize_pt2e
  import torchvision
 
  from absl.testing import absltest as googletest
@@ -506,6 +508,52 @@ class TestConvert(googletest.TestCase):
  model_coverage.compare_tflite_torch(edge_model, torch_module, args)
  )
 
+ def test_convert_resnet18_pt2e_per_layer(self):
+ # Step 1: export resnet18
+ args = (torch.randn(1, 3, 224, 224),)
+ m = torchvision.models.resnet18().eval()
+ m = torch._export.capture_pre_autograd_graph(m, args)
+
+ # Step 2: Insert observers or fake quantize modules
+ quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
+ pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=False)
+ )
+ m = quantize_pt2e.prepare_pt2e(m, quantizer)
+
+ # Step 3: Quantize the model
+ m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)
+
+ # pylint: disable=broad-except
+ try:
+ ai_edge_torch.convert(m, args)
+ except Exception as err:
+ self.fail(f"PT2E conversion failed: {err}")
+ # pylint: enable=broad-except
+
+ def test_convert_resnet18_pt2e_per_channel(self):
+ # Step 1: export resnet18
+ args = (torch.randn(1, 3, 224, 224),)
+ m = torchvision.models.resnet18().eval()
+ m = torch._export.capture_pre_autograd_graph(m, args)
+
+ # Step 2: Insert observers or fake quantize modules
+ quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
+ pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=True)
+ )
+ m = quantize_pt2e.prepare_pt2e(m, quantizer)
+ # Step 3: Run through example inputs, otherwise per-channel
+ # quant may have scalar scale/zero_point
+ m(*args)
+ # Step 4: Quantize the model
+ m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)
+
+ # pylint: disable=broad-except
+ try:
+ ai_edge_torch.convert(m, args)
+ except Exception as err:
+ self.fail(f"PT2E conversion failed: {err}")
+ # pylint: enable=broad-except
+
 
  if __name__ == "__main__":
  googletest.main()
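
The two new tests cover both PT2E flavors, per-layer and per-channel. Pulled out of the test harness, the same flow looks roughly like this; a minimal sketch assuming a local output path, with the final export call mirroring the converter examples elsewhere in this release:

import ai_edge_torch
from ai_edge_torch.quantize import pt2e_quantizer
import torch
from torch.ao.quantization import quantize_pt2e
import torchvision

# Export, then insert observers with a global symmetric config.
args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
m = torch._export.capture_pre_autograd_graph(m, args)
quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
    pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=True)
)
m = quantize_pt2e.prepare_pt2e(m, quantizer)
m(*args)  # calibration pass; gives per-channel scales/zero-points real shapes
m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)

# Convert with the Q/DQ ops left in the graph for the new lowerings.
edge_model = ai_edge_torch.convert(m, args)
edge_model.export('/tmp/resnet18_pt2e.tflite')  # path is illustrative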

ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
  '/tmp/',
  'The tflite file path to export.',
  )
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
- 'prefill_seq_len',
- 1024,
- 'The maximum size of prefill input tensor.',
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+ 'prefill_seq_lens',
+ (8, 64, 128, 256, 512, 1024),
+ 'List of the maximum sizes of prefill input tensors.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
  'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
  )
 
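This flag change repeats across the LLM converter examples below: DEFINE_integer becomes absl's DEFINE_multi_integer, so one invocation can request several prefill lengths. A hypothetical command line (checkpoint path illustrative; the flag is simply repeated per length):

python convert_gemma1_to_tflite.py \
  --checkpoint_path ~/Downloads/llm_data/gemma-2b \
  --prefill_seq_lens 64 --prefill_seq_lens 1024 \
  --kv_cache_max_len 1280

converter.convert_to_tflite then receives the whole list via prefill_seq_len=_PREFILL_SEQ_LENS.value and is expected to emit one prefill signature per requested length, which is also why the seq{...} component is dropped from the output filename.
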
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
  '/tmp/',
  'The tflite file path to export.',
  )
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
- 'prefill_seq_len',
- 1024,
- 'The maximum size of prefill input tensor.',
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+ 'prefill_seq_lens',
+ (8, 64, 128, 256, 512, 1024),
+ 'List of the maximum sizes of prefill input tensors.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
  'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
  )
 
ai_edge_torch/generative/examples/llama/convert_to_tflite.py CHANGED
@@ -39,10 +39,10 @@ _TFLITE_PATH = flags.DEFINE_string(
  '/tmp/',
  'The tflite file path to export.',
  )
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
- 'prefill_seq_len',
- 1024,
- 'The maximum size of prefill input tensor.',
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+ 'prefill_seq_lens',
+ (8, 64, 128, 256, 512, 1024),
+ 'List of the maximum sizes of prefill input tensors.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
  'kv_cache_max_len',
@@ -66,11 +66,11 @@ def main(_):
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
  )
 
ai_edge_torch/generative/examples/moonshine/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================

ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → moonshine/convert_moonshine_to_tflite.py} RENAMED
@@ -13,19 +13,21 @@
  # limitations under the License.
  # ==============================================================================
 
- """Example to convert a Gemma2 model to multiple prefill length tflite model."""
+ """Example of converting a Moonshine model to multi-signature tflite model."""
 
  import os
  import pathlib
 
  from absl import app
  from absl import flags
- from ai_edge_torch.generative.examples.gemma import gemma2
+ import ai_edge_torch
+ from ai_edge_torch.generative.examples.moonshine import moonshine
  from ai_edge_torch.generative.utilities import converter
+ import torch
 
  _CHECKPOINT_PATH = flags.DEFINE_string(
  'checkpoint_path',
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/moonshine'),
  'The path to the model checkpoint, or directory holding the checkpoint.',
  )
  _TFLITE_PATH = flags.DEFINE_string(
@@ -33,35 +35,15 @@ _TFLITE_PATH = flags.DEFINE_string(
  '/tmp/',
  'The tflite file path to export.',
  )
- _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
- 'prefill_seq_lens',
- (8, 64, 128, 256, 512, 1024),
- 'List of the maximum sizes of prefill input tensors.',
- )
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
- 'kv_cache_max_len',
- 1280,
- 'The maximum size of KV cache buffer, including both prefill and decode.',
- )
- _QUANTIZE = flags.DEFINE_bool(
- 'quantize',
- True,
- 'Whether the model should be quantized.',
- )
 
 
  def main(_):
- pytorch_model = gemma2.build_2b_model(
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
- )
- quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'gemma2_{quant_suffix}_multi-prefill-seq_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
- converter.convert_to_tflite(
- pytorch_model,
- tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
- prefill_seq_len=_PREFILL_SEQ_LENS.value,
- quantize=_QUANTIZE.value,
- )
+ p_model = moonshine.build_preprocessor(_CHECKPOINT_PATH.value)
+ output_filename = f'moonshine_preprocessor.tflite'
+ _input = torch.randn((1, 1, 159414), dtype=torch.float)
+ edge_model = ai_edge_torch.convert(p_model, (_input,), quant_config=None)
+ tflite_path = os.path.join(_TFLITE_PATH.value, output_filename)
+ edge_model.export(tflite_path)
 
 
  if __name__ == '__main__':
ai_edge_torch/generative/examples/moonshine/moonshine.py ADDED
@@ -0,0 +1,103 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building the Moonshine model."""
+
+ import os
+ import pathlib
+ from typing import Optional, Tuple
+ from absl import app
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.layers.normalization as normalization
+ import ai_edge_torch.generative.utilities.moonshine_loader as loading_utils
+ import h5py
+ import torch
+ from torch import nn
+ import torch.nn as nn
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+ conv1D_0="layers/sequential/layers/conv1d/vars",
+ conv1D_1="layers/sequential/layers/conv1d_1/vars",
+ conv1D_2="layers/sequential/layers/conv1d_2/vars",
+ group_norm="layers/sequential/layers/group_normalization/vars",
+ )
+
+
+ class AudioPreprocessor(nn.Module):
+
+ def __init__(self, dim):
+ super(AudioPreprocessor, self).__init__()
+ self.conv1 = nn.Conv1d(
+ in_channels=1, out_channels=dim, kernel_size=127, stride=64, bias=False
+ )
+ self.tanh = nn.Tanh()
+ self.group_norm = normalization.GroupNorm(group_num=1, dim=dim, eps=1e-5)
+ self.conv2 = nn.Conv1d(
+ in_channels=dim,
+ out_channels=2 * dim,
+ kernel_size=7,
+ stride=3,
+ padding=0,  # Equivalent to padding="valid"
+ )
+ self.gelu1 = nn.GELU()
+ self.conv3 = nn.Conv1d(
+ in_channels=2 * dim,
+ out_channels=dim,
+ kernel_size=3,
+ stride=2,
+ padding=0,  # Equivalent to padding="valid"
+ )
+ self.gelu2 = nn.GELU()
+
+ def forward(self, inputs):
+ x = self.conv1(inputs)
+ x = self.tanh(x)
+ x = self.group_norm(x)
+ x = self.conv2(x)
+ x = self.gelu1(x)
+ x = self.conv3(x)
+ x = self.gelu2(x)
+ return x
+
+
+ def build_preprocessor(checkpoint_path: str, **kwargs) -> nn.Module:
+ ap = AudioPreprocessor(dim=416)
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+ loader.load(ap, strict=True)
+ return ap
+
+
+ def main(_):
+ # TODO(b/375421767) Remove golden checks once full model is implemented.
+ HF_PATH = os.path.join(pathlib.Path.home(), "Downloads/llm_data/moonshine")
+
+ test_data_path = pathlib.Path(__file__).parent.resolve()
+ INPUT_PATH = test_data_path / "data" / "pp_input.pt"
+ GOLDEN_PATH = test_data_path / "data" / "pp_output.pt"
+
+ ap = build_preprocessor(HF_PATH)
+ ap.eval()
+ inputs = torch.load(INPUT_PATH).reshape((1, 1, 159414))
+ out = ap(inputs)
+ golden = torch.load(GOLDEN_PATH).transpose(1, 2)
+ assert torch.allclose(out, golden, atol=1e-4, rtol=1e-4)
+
+
+ if __name__ == "__main__":
+ app.run(main)
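
The preprocessor is three strided, unpadded Conv1d stages, so each output length follows the standard floor((L - kernel) / stride) + 1 rule. A quick shape check for the 159414-sample golden input used above, reusing the AudioPreprocessor class with dim=416 as in build_preprocessor (randomly initialized weights suffice for a shape check):

import torch

ap = AudioPreprocessor(dim=416)
x = torch.randn(1, 1, 159414)
# conv1: (159414 - 127) // 64 + 1 = 2489
# conv2: (2489 - 7) // 3 + 1 = 828
# conv3: (828 - 3) // 2 + 1 = 413
print(ap(x).shape)  # expected: torch.Size([1, 416, 413])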

ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
  '/tmp/',
  'The tflite file path to export.',
  )
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
- 'prefill_seq_len',
- 1024,
- 'The maximum size of prefill input tensor.',
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+ 'prefill_seq_lens',
+ (8, 64, 128, 256, 512, 1024),
+ 'List of the maximum sizes of prefill input tensors.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
  'kv_cache_max_len',
@@ -55,11 +55,14 @@ def main(_):
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = (
+ f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ )
+
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
  )
 
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
  '/tmp/',
  'The tflite file path to export.',
  )
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
- 'prefill_seq_len',
- 1024,
- 'The maximum size of prefill input tensor.',
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+ 'prefill_seq_lens',
+ (8, 64, 128, 256, 512, 1024),
+ 'List of the maximum sizes of prefill input tensors.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
  'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
  )
 
ai_edge_torch/generative/examples/phi/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
  '/tmp/',
  'The tflite file path to export.',
  )
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
- 'prefill_seq_len',
- 1024,
- 'The maximum size of prefill input tensor.',
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+ 'prefill_seq_lens',
+ (8, 64, 128, 256, 512, 1024),
+ 'List of the maximum sizes of prefill input tensors.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
  'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
  )
 
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py CHANGED
@@ -39,10 +39,10 @@ _TFLITE_PATH = flags.DEFINE_string(
  '/tmp/',
  'The tflite file path to export.',
  )
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
- 'prefill_seq_len',
- 1024,
- 'The maximum size of prefill input tensor.',
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+ 'prefill_seq_lens',
+ (8, 64, 128, 256, 512, 1024),
+ 'List of the maximum sizes of prefill input tensors.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
  'kv_cache_max_len',
@@ -68,11 +68,13 @@ def main(_):
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
  model_size = _MODEL_SIZE.value.replace('.', '_')
- output_filename = f'qwen_{model_size}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = (
+ f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ )
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
  )
 
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
  '/tmp/',
  'The tflite file path to export.',
  )
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
- 'prefill_seq_len',
- 1024,
- 'The maximum size of prefill input tensor.',
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+ 'prefill_seq_lens',
+ (8, 64, 128, 256, 512, 1024),
+ 'List of the maximum sizes of prefill input tensors.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
  'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
  )
 
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
  '/tmp/',
  'The tflite file path to export.',
  )
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
- 'prefill_seq_len',
- 1024,
- 'The maximum size of prefill input tensor.',
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+ 'prefill_seq_lens',
+ (8, 64, 128, 256, 512, 1024),
+ 'List of the maximum sizes of prefill input tensors.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
  'kv_cache_max_len',
@@ -55,11 +55,13 @@ def main(_):
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = (
+ f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ )
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
  )
 
ai_edge_torch/generative/test/test_quantize.py CHANGED
@@ -91,6 +91,11 @@ class TestVerifyRecipes(parameterized.TestCase):
  class TestQuantizeConvert(parameterized.TestCase):
  """Test conversion with quantization."""
 
+ def setUp(self):
+ super().setUp()
+ torch.manual_seed(0)
+ torch._dynamo.reset()
+
  def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
  return quant_config.QuantConfig(
  generative_recipe=quant_recipe.GenerativeQuantRecipe(
ai_edge_torch/generative/utilities/moonshine_loader.py ADDED
@@ -0,0 +1,154 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Common utility functions for data loading etc.
+ from dataclasses import dataclass
+ import glob
+ import os
+ from typing import Callable, Dict
+
+ import h5py
+ import torch
+
+
+ def transpose_if_needed(t):
+ """We assume the file is from Keras, i.e. channel last format."""
+ if len(t.shape) > 2:
+ return t.permute(2, 1, 0)
+ return t
+
+
+ def load_h5_statedict(full_path: str):
+ """Loads the HDF5 DataSets into a single dictionary.
+
+ Args:
+ full_path (string): the HDF5 filename or directory that contains the HDF5
+ files.
+
+ Returns:
+ A state dictionary containing loaded tensors.
+
+ Raises:
+ ValueError: If no tensors are loaded from the provided directory or file.
+ """
+ pattern = (
+ os.path.join(full_path, "*.h5") if os.path.isdir(full_path) else full_path
+ )
+ files = []
+ for file in glob.glob(pattern):
+ files.append(file)
+
+ tensors = {}
+
+ def collect_datasets(name, obj):
+ if isinstance(obj, h5py.Dataset):
+ tensors[name] = transpose_if_needed(torch.from_numpy(obj[:]))
+
+ for file in files:
+ with h5py.File(file) as f:
+ f.visititems(collect_datasets)
+
+ if not tensors:
+ raise ValueError("Failed to load HDF5 file.")
+ return tensors
+
+
+ class ModelLoader:
+ """Utility class for loading and converting checkpoints to ODML transformer layer format."""
+
+ @dataclass
+ class TensorNames:
+ conv1D_0: str = None
+ conv1D_1: str = None
+ conv1D_2: str = None
+ group_norm: str = None
+
+ def __init__(self, file_name: str, names: TensorNames) -> None:
+ """ModelLoader constructor.
+
+ Can be used to load multiple models of the same type.
+
+ Args:
+ file_name (str): Path to the checkpoint. Can be a directory or an exact
+ file.
+ names (TensorNames): An instance of `TensorNames` to determine mappings.
+ """
+ self._file_name = file_name
+ self._names = names
+ self._loader = load_h5_statedict
+
+ def load(
+ self,
+ model: torch.nn.Module,
+ strict: bool = True,
+ ):
+ """Load the model from the checkpoint.
+
+ Args:
+ model (torch.nn.Module): The pytorch model that needs to be loaded.
+ strict (bool, optional): Whether the converted keys are strictly
+ matched. Defaults to True.
+
+ Raises:
+ ValueError: If conversion results in unmapped tensors and strict mode is
+ enabled.
+ """
+ state = self._loader(self._file_name)
+
+ if isinstance(self._names, ModelLoader.TensorNames):
+ converted_state = self._do_load(model, state, self._names)
+ else:
+ raise ValueError(f"Unknown type for names: {type(self._names)}")
+
+ if strict and state:
+ raise ValueError(
+ "Failed to map all tensors. Remaining tensors are:"
+ f" {list(state.keys())}"
+ )
+ model.load_state_dict(converted_state, strict=strict)
+
+ def _do_load(self, model, state, names, additional_prefix=""):
+ """Load the model from the checkpoint.
+
+ Args:
+ model (torch.nn.Module): The pytorch model that needs to be loaded.
+ state (Dict[str, torch.Tensor]): The pytorch state dictionary.
+ names (TensorNames): The TensorNames for the model we are loading.
+
+ Returns:
+ Dict[str, torch.Tensor]: Map of name to tensor for loading.
+ """
+ converted_state = dict()
+ if names.conv1D_0 is not None:
+ converted_state["conv1.weight"] = state.pop(f"{names.conv1D_0}/0")
+ if f"{names.conv1D_0}/1" in state:
+ converted_state["conv1.bias"] = state.pop(f"{names.conv1D_0}/1")
+
+ if names.conv1D_1 is not None:
+ converted_state["conv2.weight"] = state.pop(f"{names.conv1D_1}/0")
+ if f"{names.conv1D_1}/1" in state:
+ converted_state["conv2.bias"] = state.pop(f"{names.conv1D_1}/1")
+
+ if names.conv1D_2 is not None:
+ converted_state["conv3.weight"] = state.pop(f"{names.conv1D_2}/0")
+ if f"{names.conv1D_2}/1" in state:
+ converted_state["conv3.bias"] = state.pop(f"{names.conv1D_2}/1")
+
+ if names.group_norm is not None:
+ group_norm_name = names.group_norm
+ converted_state[f"group_norm.weight"] = state.pop(f"{group_norm_name}/0")
+ if f"{group_norm_name}/1" in state:
+ converted_state["group_norm.bias"] = state.pop(f"{group_norm_name}/1")
+
+ return converted_state
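
The loader flattens every HDF5 dataset into one dict keyed by its in-file path, permuting tensors with more than two dimensions out of Keras channel-last layout. A small sketch of using it directly (checkpoint path hypothetical):

import os
import pathlib

state = load_h5_statedict(
    os.path.join(pathlib.Path.home(), "Downloads/llm_data/moonshine")
)
# Keys mirror the HDF5 hierarchy, e.g.
#   "layers/sequential/layers/conv1d/vars/0"
# Keras Conv1D kernels are stored (kernel, in, out); permute(2, 1, 0)
# turns them into the (out, in, kernel) layout torch.nn.Conv1d expects.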

ai_edge_torch/odml_torch/export.py CHANGED
@@ -35,9 +35,7 @@ from . import lowerings
  LoweringContext = lowerings.context.LoweringContext
 
 
- def _build_flat_inputs(
- ctx: ir.Context, exported_program: torch.export.ExportedProgram
- ):
+ def _build_flat_inputs(exported_program: torch.export.ExportedProgram):
  """Build flattened inputs and metadata from exported program's signature."""
  placeholder_nodes = [
  n for n in exported_program.graph.nodes if n.op == "placeholder"
@@ -49,9 +47,11 @@ def _build_flat_inputs(
  ir_inputs = []
  tensor_metas = []
  for node, arg in zip(placeholder_nodes, export_flat_args):
- tensor_meta = node.meta.get("tensor_meta")
+ tensor_meta = node.meta.get("tensor_meta") or node.meta.get("val")
  if tensor_meta is None:
- raise RuntimeError(f"{type(arg)} (for {node.name}) is not a tensor")
+ raise RuntimeError(
+ f"{type(arg)} (for {node.name}) does not have tensor meta"
+ )
 
  tensor_metas.append(tensor_meta)
  # Assume all dynamic dimensions are unbounded.
@@ -63,7 +63,7 @@
  ir_inputs.append(
  ir.RankedTensorType.get(
  shape,
- export_utils.torch_dtype_to_ir_element_type(ctx, tensor_meta.dtype),
+ export_utils.torch_dtype_to_ir_element_type(tensor_meta.dtype),
  )
  )
  return tuple(ir_inputs), tuple(export_flat_args), tuple(tensor_metas)
@@ -258,6 +258,43 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
  rewrite_arange(node)
 
 
+ # TODO(b/331481564) Make this a ai_edge_torch FX pass.
+ def _convert_q_dq_per_channel_args_to_list(
+ exported_program: torch.export.ExportedProgram,
+ ):
+ """Resolve tensor inputs to Q/DQ ops as static number list for lowering.
+
+ This pass makes the ExportedProgram in a non-executable state. This pass must
+ be run after all run_decompositions calls.
+ """
+ placeholder_nodes = [
+ n for n in exported_program.graph.nodes if n.op == "placeholder"
+ ]
+ export_flat_args = _torch_future.graph_module_flat_inputs(
+ exported_program, *exported_program.example_inputs
+ )
+
+ placeholder_tensor = {
+ n: tensor for n, tensor in zip(placeholder_nodes, export_flat_args)
+ }
+
+ graph_module = exported_program.graph_module
+ for node in graph_module.graph.nodes:
+ if node.target in (
+ torch.ops.quantized_decomposed.quantize_per_channel.default,
+ torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+ ):
+ input, scale_node, zero_point_node = node.args[:3]
+ scale = placeholder_tensor[scale_node]
+ zero_point = placeholder_tensor[zero_point_node]
+
+ scale = scale.detach().numpy().tolist()
+ zero_point = zero_point.detach().numpy().tolist()
+ node.args = (input, scale, zero_point, *node.args[3:])
+
+
  def exported_program_to_mlir(
  exported_program: torch.export.ExportedProgram,
  ) -> MlirLowered:
@@ -270,6 +307,7 @@ def exported_program_to_mlir(
  exported_program = _torch_future.safe_run_decompositions(
  exported_program, lowerings.decompositions()
  )
+ _convert_q_dq_per_channel_args_to_list(exported_program)
 
  with export_utils.create_ir_context() as context, ir.Location.unknown():
 
@@ -277,7 +315,7 @@
  lctx = LoweringContext(context, module)
  interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
  ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
- context, exported_program
+ exported_program
  )
 
  # HACK: OSS MLIR pybinding could mysteriously transform func.func under
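
Conceptually, the new pass turns the tensor-valued scale/zero-point placeholders on Q/DQ nodes into literal lists (a hypothetical before/after for a 2-channel weight quantized on axis 0):

# before: scale/zero_point are graph placeholders backed by example inputs
#   quantized_decomposed.quantize_per_channel(x, scale_p, zp_p, 0, -128, 127, torch.int8)
# after: literals the lowering can bake into the MLIR quant type
#   quantized_decomposed.quantize_per_channel(x, [0.02, 0.05], [0, 0], 0, -128, 127, torch.int8)

This is also why the docstring warns that the ExportedProgram is no longer executable afterwards: the rewritten node args no longer match the op schema.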

ai_edge_torch/odml_torch/export_utils.py CHANGED
@@ -14,9 +14,9 @@
  # ==============================================================================
  """Utilities for ODML Torch export."""
 
- import functools
  import re
  from typing import Sequence, cast
+ from ai_edge_torch.odml_torch.lowerings import utils as lowering_utils
  import jax._src.interpreters.mlir
  from jax._src.lib.mlir import ir
  from jax._src.lib.mlir.dialects import func
@@ -47,7 +47,6 @@ def create_ir_context():
  # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
  context = jax._src.interpreters.mlir.make_ir_context()
  context.allow_unregistered_dialects = True
-
  return context
 
 
@@ -135,17 +134,7 @@ def build_ir_attr(val):
  return ir.StringAttr.get(str(val))
 
 
- def torch_dtype_to_ir_element_type(ctx, dtype):
- ty_get = {
- torch.double: ir.F64Type.get,
- torch.float32: ir.F32Type.get,
- torch.half: ir.F16Type.get,
- torch.long: functools.partial(ir.IntegerType.get_signless, 64),
- torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
- torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
- torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
- }.get(dtype)
- return ty_get(ctx)
+ torch_dtype_to_ir_element_type = lowering_utils.torch_dtype_to_ir_element_type
 
 
  def ir_element_type_to_torch_dtype(ty):
ai_edge_torch/odml_torch/jax_bridge/_wrap.py CHANGED
@@ -163,9 +163,7 @@ def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
  if aval is None:
  return result
 
- target_elty = export_utils.torch_dtype_to_ir_element_type(
- lctx.ir_context, aval.dtype
- )
+ target_elty = export_utils.torch_dtype_to_ir_element_type(aval.dtype)
  if result.type.element_type == target_elty:
  return result
  return stablehlo.convert(
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -17,6 +17,7 @@ from . import _batch_norm
  from . import _convolution
  from . import _jax_lowerings
  from . import _layer_norm
+ from . import _quantized_decomposed
  from . import context
  from . import registry
  from . import utils
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -227,9 +227,7 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
  if not non_empty_tensors:
  return utils.splat(
  0,
- export_utils.torch_dtype_to_ir_element_type(
- lctx.ir_context, out_aval.dtype
- ),
+ export_utils.torch_dtype_to_ir_element_type(out_aval.dtype),
  out_aval.shape,
  )
 
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py ADDED
@@ -0,0 +1,174 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Lowerings for PT2E torch.ops.quantized_decomposed ops."""
+ from typing import Union, cast
+
+ from ai_edge_torch.odml_torch.lowerings import context
+ from ai_edge_torch.odml_torch.lowerings import utils
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import torch
+ import torch.ao.quantization.fx._decomposed
+ import torch.utils._pytree as pytree
+
+ from . import registry
+
+ lower = registry.lower
+ LoweringContext = context.LoweringContext
+
+
+ def _uniform_quantized_type(
+ stored_type: str | ir.Type,
+ expressed_type: str | ir.Type,
+ *,
+ scale: float | list[float] | tuple[float],
+ zero_point: float | list[float] | tuple[float],
+ storage_type_min: int | None = None,
+ storage_type_max: int | None = None,
+ channel_axis: int | None = None,
+ channel_axis_size: int | None = None,
+ ):
+ """Polyfill for quant.UniformQuantizedType."""
+ if storage_type_min and storage_type_max:
+ storage_min_max = f"<{storage_type_min}:{storage_type_max}>"
+ else:
+ storage_min_max = ""
+
+ if channel_axis is not None:
+ # Per-channel quantization
+ # https://mlir.llvm.org/docs/Dialects/QuantDialect/#per-channel-quantization
+ assert isinstance(scale, (list, tuple))
+ assert isinstance(zero_point, (list, tuple))
+
+ if len(scale) == 1:
+ scale *= channel_axis_size
+ if len(zero_point) == 1:
+ zero_point *= channel_axis_size
+
+ assert len(scale) == len(zero_point) == channel_axis_size
+ scale_zp_strs = []
+ for s, zp in zip(scale, zero_point):
+ scale_zp_strs.append(f"{s}:{zp}")
+ scale_zp = "{" + ",".join(scale_zp_strs) + "}"
+ return ir.Type.parse(
+ f"!quant.uniform<{stored_type}{storage_min_max}:{expressed_type}:{channel_axis},{scale_zp}>"
+ )
+ else:
+ # Per-layer quantization
+ # https://mlir.llvm.org/docs/Dialects/QuantDialect/#per-layer-quantization
+ scale = pytree.tree_flatten([scale])[0][-1]
+ zero_point = pytree.tree_flatten([zero_point])[0][-1]
+ scale_zp = f"{scale}:{zero_point}"
+ return ir.Type.parse(
+ f"!quant.uniform<{stored_type}{storage_min_max}:{expressed_type},{scale_zp}>"
+ )
+
+
+ # Quant dialect is not registered in the Python MLIR pybinding used by
+ # odml-torch. Therefore, stablehlo.uniform_quantize/uniform_dequantize ops and
+ # quant types are represented in stablehlo.custom_call to pass MLIR verification
+ # and VHLO serialization before converter.
+ # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+
+
+ # Schema:
+ # - quantized_decomposed::quantize_per_tensor(Tensor input, float scale,
+ # int zero_point, int quant_min, int quant_max,
+ # ScalarType dtype) -> Tensor
+ # - quantized_decomposed::quantize_per_tensor.tensor(Tensor input,
+ # Tensor scale, Tensor zero_point, int quant_min, int quant_max,
+ # ScalarType dtype) -> Tensor
+ #
+ # Scale and zero_point in tensors are automatically converted to list before
+ # lowering.
+ @lower(torch.ops.quantized_decomposed.quantize_per_tensor)
+ def _quantize_per_tensor(
+ lctx: LoweringContext,
+ input: ir.Value,
+ scale: Union[float, list[float]],
+ zero_point: Union[float, list[float]],
+ quant_min: int,
+ quant_max: int,
+ dtype: torch.dtype,
+ ):
+ input_ty = cast(ir.RankedTensorType, input.type)
+ qty = _uniform_quantized_type(
+ utils.torch_dtype_to_ir_element_type(dtype),
+ input_ty.element_type,
+ scale=scale,
+ zero_point=zero_point,
+ storage_type_min=quant_min,
+ storage_type_max=quant_max,
+ )
+ return stablehlo.custom_call(
+ call_target_name="odml_torch.uniform_quantize",
+ inputs=[input],
+ result=[input_ty],
+ backend_config=ir.StringAttr.get(
+ str(ir.RankedTensorType.get(input_ty.shape, qty))
+ ),
+ )
+
+
+ # Schema:
+ # - quantized_decomposed::quantize_per_channel(Tensor input, Tensor scales,
+ # Tensor zero_points, int axis, int quant_min, int quant_max,
+ # ScalarType dtype) -> Tensor
+ #
+ # Scale and zero_point in tensors are automatically converted to list before
+ # lowering.
+ @lower(torch.ops.quantized_decomposed.quantize_per_channel)
+ def _quantize_per_channel(
+ lctx: LoweringContext,
+ input: ir.Value,
+ scale: list[float],
+ zero_point: list[float],
+ axis: int,
+ quant_min: int,
+ quant_max: int,
+ dtype: torch.dtype,
+ ):
+ input_ty = cast(ir.RankedTensorType, input.type)
+ qty = _uniform_quantized_type(
+ utils.torch_dtype_to_ir_element_type(dtype),
+ input_ty.element_type,
+ scale=scale,
+ zero_point=zero_point,
+ channel_axis=axis,
+ channel_axis_size=input_ty.shape[axis],
+ storage_type_min=quant_min,
+ storage_type_max=quant_max,
+ )
+ return stablehlo.custom_call(
+ call_target_name="odml_torch.uniform_quantize",
+ inputs=[input],
+ result=[input_ty],
+ backend_config=ir.StringAttr.get(
+ str(ir.RankedTensorType.get(input_ty.shape, qty))
+ ),
+ )
+
+
+ @lower(torch.ops.quantized_decomposed.dequantize_per_tensor)
+ @lower(torch.ops.quantized_decomposed.dequantize_per_channel)
+ def _dequantize(lctx: LoweringContext, input: ir.Value, *args, **kwargs):
+ result_meta = lctx.node.meta.get("tensor_meta")
+ result_elty = utils.torch_dtype_to_ir_element_type(result_meta.dtype)
+
+ return stablehlo.custom_call(
+ call_target_name="odml_torch.uniform_dequantize",
+ inputs=[input],
+ result=[ir.RankedTensorType.get(result_meta.shape, result_elty)],
+ )
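
_uniform_quantized_type only assembles the textual form of the MLIR quant type. Plugging illustrative numbers into the two f-strings above, a per-tensor int8 quantization with scale 0.1 and zero point 0 parses as

!quant.uniform<i8<-128:127>:f32,0.1:0>

and a per-channel variant on axis 0 with two channels as

!quant.uniform<i8<-128:127>:f32:0,{0.1:0,0.2:0}>

Because the quant dialect is unavailable in the bundled pybinding, these types travel as the backend_config string of an odml_torch.uniform_quantize/uniform_dequantize custom_call rather than as real operand types.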

ai_edge_torch/odml_torch/lowerings/utils.py CHANGED
@@ -14,6 +14,7 @@
  # ==============================================================================
  """Utilities for building MLIR lowerings."""
 
+ import functools
  import numbers
  from typing import Any
  from typing import Optional
@@ -21,6 +22,21 @@ from typing import Optional
  from jax._src.lib.mlir import ir
  from jax._src.lib.mlir.dialects import hlo as stablehlo
  import numpy as np
+ import torch
+
+
+ def torch_dtype_to_ir_element_type(dtype):
+ ty_get = {
+ torch.double: ir.F64Type.get,
+ torch.float32: ir.F32Type.get,
+ torch.half: ir.F16Type.get,
+ torch.long: functools.partial(ir.IntegerType.get_signless, 64),
+ torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
+ torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
+ torch.int8: functools.partial(ir.IntegerType.get_signless, 8),
+ torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
+ }[dtype]
+ return ty_get()
 
 
  def splat(val, ty, shape=tuple(), *, loc: Optional[Any] = None):
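
Unlike the removed export_utils version, the relocated helper takes no ctx argument: the ir.*Type.get callables resolve the ambient MLIR context, so callers must hold one open. A minimal sketch (standalone, outside a lowering; assumes the MLIR python bindings' usual context-manager behavior):

import torch
from jax._src.lib.mlir import ir

with ir.Context():
    i8 = torch_dtype_to_ir_element_type(torch.int8)      # signless i8
    f32 = torch_dtype_to_ir_element_type(torch.float32)  # f32

The new torch.int8 entry is what lets the quantized_decomposed lowerings build int8 storage types.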
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20241201"
+ __version__ = "0.3.0.dev20241205"

{ai_edge_torch_nightly-0.3.0.dev20241201.dist-info → ai_edge_torch_nightly-0.3.0.dev20241205.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20241201
+ Version: 0.3.0.dev20241205
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.3.0.dev20241201.dist-info → ai_edge_torch_nightly-0.3.0.dev20241205.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=OUqcy-x2l3EVJNsWANXG1NaPkhKDz4-EkU_yVTe0f1Y,706
+ ai_edge_torch/version.py,sha256=UKNQIv9LGNIpDQkZXBrHuhDFIYET3G8pLZ5njXu6KJc,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
+ ai_edge_torch/_convert/test/test_convert.py,sha256=v6AhfWqRBuHT7uBDueTbntaQtDSMMrvQOqlIDXNUaMA,17250
  ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
  ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -44,20 +44,22 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=bkq2Zk
  ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n79r6yFnCACpms5eMkXNpyQsCn2PYVRdB-jOoIqn14,2227
  ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
- ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py,sha256=6d9wG5MnStEys34_gFXwKTMRXUBFLTW1jEzCoWkAtwM,2224
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=mrG96_WEGD4NQ4uFEKrHRMAQvVVliOcj1zbI3drGDjI,2199
+ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=I_tvwCYmtf08D1HqDxYx7dpvj2q5_eaYnuI_3rI6Dlw,2201
  ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
  ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
+ ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=Brb83sbqBfStUiIZFhfWnYtN7LcNmkKyFn96cZK4sGo,2426
  ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
  ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
+ ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
+ ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
+ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=-qDBu3bjUq0jx73SPDMsPIBP0BT1nA_0UgtFKeSuM18,2213
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -69,18 +71,18 @@ ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
+ ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=ruY-LLwpqBqVZ5z9h_sewYj04ukWRG4j804tUAyDdnA,2186
+ ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=UdMk1SSkcWpv8gosUylx3JSCxdOJBjZNhuQQtT4-Ono,2184
  ai_edge_torch/generative/examples/phi/phi2.py,sha256=nbivDwZREd-sypy_ittO59-yaAdPvHv1YEV6Fo5buCo,3341
  ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
  ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=QAAVoSKDVf2rHAChzumGloVCWIU0Oe5UYKgv3T192Iw,2496
+ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=1M3DTkf536TCLYcQB1lu-3TxQ6mV03dFhTdbk0p8i84,2523
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=oYm9hhALUQ4uOn-PO1bF7fCIGP8EWRNK4zClkx2RQs8,4070
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
+ ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=56CzCjyp9xh_2ZtXKN9tlEv6GayeSc4giTIZsi2Q59E,2194
  ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZku9kgvmlFCyIBar3kF2XEk,2570
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -107,7 +109,7 @@ ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNH
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=rRodLr-hEqAs_-8x06O8qO-hJ_cqr2AfhJZ9DCptvwo,4332
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
+ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=WmEshoN9HgNLbV7NTjxdqWz9Olcim6r_vo4R9eYE98I,2228
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=10X8HwPx4akzclnIMOBNItKQemhRbvxBbTo7nwZtWjM,2650
  ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
@@ -139,13 +141,14 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FP
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
  ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=xWV9O2wuRHc4VNBWuWipiuqXa3AJhiV1nmjewAZHHWM,11177
- ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
+ ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
  ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
  ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
  ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
  ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
  ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
+ ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
  ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
@@ -167,8 +170,8 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
  ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
- ai_edge_torch/odml_torch/export.py,sha256=4xwrsDeOAgzoB9m7EeNsBj6dC5Ajtn5aKDRQkdHxa-o,11584
- ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
+ ai_edge_torch/odml_torch/export.py,sha256=dgnNGBVkHBz0brlWALX2hGXpQ4YzCKdwbkF4oAfEu4I,13062
+ ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
  ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
  ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
@@ -177,17 +180,18 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
  ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
  ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=oQo9nxH08NnEDeZaGoCUk1kRtoEOM_f0DUOyd9nfxjg,6673
+ ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
- ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
- ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=eH9eJqFO-BI9l4WdXfjsItODPRa18SAR_qSvJ6-7gxc,9987
+ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=1lMKPoStK3SUA8yYTPZBRhESN33BghGXnfqOOg4oeVk,995
+ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
  ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5lywQ1_RfLIxVE7J_GONI6CQvI,10718
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
+ ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=itTt8MLbq2LoHTzRidCF2TTbh0TP7L836u99qCjP3FA,2953
- ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
+ ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
  ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
@@ -196,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20241201.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20241201.dist-info/METADATA,sha256=Imp9XnPMYxNMskFOdV5J8IzWfU4Ox84qo5-ghCYKDJU,1897
- ai_edge_torch_nightly-0.3.0.dev20241201.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.3.0.dev20241201.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20241201.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20241205.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20241205.dist-info/METADATA,sha256=q0YQggf3bWL7q67R2IpsvyUlncZRjjJRfsqL8yLNJ_Y,1897
+ ai_edge_torch_nightly-0.3.0.dev20241205.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.3.0.dev20241205.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20241205.dist-info/RECORD,,