ai-edge-torch-nightly 0.3.0.dev20241129__py3-none-any.whl → 0.3.0.dev20241204__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. ai_edge_torch/_convert/test/test_convert.py +48 -0
  2. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +6 -6
  3. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -6
  4. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +6 -6
  5. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → moonshine/convert_moonshine_to_tflite.py} +11 -29
  7. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  8. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +9 -6
  9. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +6 -6
  10. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -6
  11. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +8 -6
  12. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -6
  13. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +8 -6
  14. ai_edge_torch/generative/test/test_quantize.py +5 -0
  15. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  16. ai_edge_torch/odml_torch/export.py +45 -7
  17. ai_edge_torch/odml_torch/export_utils.py +2 -13
  18. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +1 -3
  19. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  20. ai_edge_torch/odml_torch/lowerings/_basic.py +1 -3
  21. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +174 -0
  22. ai_edge_torch/odml_torch/lowerings/utils.py +16 -0
  23. ai_edge_torch/version.py +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/METADATA +1 -1
  25. {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/RECORD +28 -24
  26. {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/LICENSE +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/WHEEL +0 -0
  28. {ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -21,10 +21,12 @@ from typing import Tuple
 import ai_edge_torch
 from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
+from ai_edge_torch.quantize import pt2e_quantizer
 from ai_edge_torch.testing import model_coverage
 import numpy as np
 import torch
 from torch import nn
+from torch.ao.quantization import quantize_pt2e
 import torchvision

 from absl.testing import absltest as googletest
@@ -506,6 +508,52 @@ class TestConvert(googletest.TestCase):
         model_coverage.compare_tflite_torch(edge_model, torch_module, args)
     )

+  def test_convert_resnet18_pt2e_per_layer(self):
+    # Step 1: export resnet18
+    args = (torch.randn(1, 3, 224, 224),)
+    m = torchvision.models.resnet18().eval()
+    m = torch._export.capture_pre_autograd_graph(m, args)
+
+    # Step 2: Insert observers or fake quantize modules
+    quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
+        pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=False)
+    )
+    m = quantize_pt2e.prepare_pt2e(m, quantizer)
+
+    # Step 3: Quantize the model
+    m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)
+
+    # pylint: disable=broad-except
+    try:
+      ai_edge_torch.convert(m, args)
+    except Exception as err:
+      self.fail(f"PT2E conversion failed: {err}")
+    # pylint: enable=broad-except
+
+  def test_convert_resnet18_pt2e_per_channel(self):
+    # Step 1: export resnet18
+    args = (torch.randn(1, 3, 224, 224),)
+    m = torchvision.models.resnet18().eval()
+    m = torch._export.capture_pre_autograd_graph(m, args)
+
+    # Step 2: Insert observers or fake quantize modules
+    quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
+        pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=True)
+    )
+    m = quantize_pt2e.prepare_pt2e(m, quantizer)
+    # Step 3: Run through example inputs, otherwise per-channel
+    # quant may have scalar scale/zero_point
+    m(*args)
+    # Step 4: Quantize the model
+    m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)
+
+    # pylint: disable=broad-except
+    try:
+      ai_edge_torch.convert(m, args)
+    except Exception as err:
+      self.fail(f"PT2E conversion failed: {err}")
+    # pylint: enable=broad-except
+


 if __name__ == "__main__":
   googletest.main()
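The two new tests above exercise the PT2E quantization path end to end. For context, a minimal sketch (not part of the diff) of the same flow used outside the test to produce a quantized .tflite file, reusing only the APIs shown in the hunk plus the `export` call that appears in the Moonshine example further down; the output path is a placeholder.

    # Sketch: PT2E per-channel quantization, conversion, and export.
    import ai_edge_torch
    import torch
    import torchvision
    from ai_edge_torch.quantize import pt2e_quantizer
    from torch.ao.quantization import quantize_pt2e

    args = (torch.randn(1, 3, 224, 224),)
    m = torchvision.models.resnet18().eval()
    m = torch._export.capture_pre_autograd_graph(m, args)

    quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
        pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=True)
    )
    m = quantize_pt2e.prepare_pt2e(m, quantizer)
    m(*args)  # calibrate so per-channel scales/zero points become tensors
    m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)

    edge_model = ai_edge_torch.convert(m, args)
    edge_model.export("/tmp/resnet18_pt2e_int8.tflite")  # placeholder path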
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/llama/convert_to_tflite.py CHANGED
@@ -39,10 +39,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -66,11 +66,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/moonshine/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/{gemma/convert_gemma2_multi_prefills.py → moonshine/convert_moonshine_to_tflite.py} RENAMED
@@ -13,19 +13,21 @@
 # limitations under the License.
 # ==============================================================================

-"""Example to convert a Gemma2 model to multiple prefill length tflite model."""
+"""Example of converting a Moonshine model to multi-signature tflite model."""

 import os
 import pathlib

 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma2
+import ai_edge_torch
+from ai_edge_torch.generative.examples.moonshine import moonshine
 from ai_edge_torch.generative.utilities import converter
+import torch

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/moonshine'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -33,35 +35,15 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)


 def main(_):
-  pytorch_model = gemma2.build_2b_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
-  )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_multi-prefill-seq_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  converter.convert_to_tflite(
-      pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-  )
+  p_model = moonshine.build_preprocessor(_CHECKPOINT_PATH.value)
+  output_filename = f'moonshine_preprocessor.tflite'
+  _input = torch.randn((1, 1, 159414), dtype=torch.float)
+  edge_model = ai_edge_torch.convert(p_model, (_input,), quant_config=None)
+  tflite_path = os.path.join(_TFLITE_PATH.value, output_filename)
+  edge_model.export(tflite_path)


 if __name__ == '__main__':
ai_edge_torch/generative/examples/moonshine/moonshine.py ADDED
@@ -0,0 +1,103 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building the Moonshine model."""
+
+import os
+import pathlib
+from typing import Optional, Tuple
+from absl import app
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.normalization as normalization
+import ai_edge_torch.generative.utilities.moonshine_loader as loading_utils
+import h5py
+import torch
+from torch import nn
+import torch.nn as nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    conv1D_0="layers/sequential/layers/conv1d/vars",
+    conv1D_1="layers/sequential/layers/conv1d_1/vars",
+    conv1D_2="layers/sequential/layers/conv1d_2/vars",
+    group_norm="layers/sequential/layers/group_normalization/vars",
+)
+
+
+class AudioPreprocessor(nn.Module):
+
+  def __init__(self, dim):
+    super(AudioPreprocessor, self).__init__()
+    self.conv1 = nn.Conv1d(
+        in_channels=1, out_channels=dim, kernel_size=127, stride=64, bias=False
+    )
+    self.tanh = nn.Tanh()
+    self.group_norm = normalization.GroupNorm(group_num=1, dim=dim, eps=1e-5)
+    self.conv2 = nn.Conv1d(
+        in_channels=dim,
+        out_channels=2 * dim,
+        kernel_size=7,
+        stride=3,
+        padding=0,  # Equivalent to padding="valid"
+    )
+    self.gelu1 = nn.GELU()
+    self.conv3 = nn.Conv1d(
+        in_channels=2 * dim,
+        out_channels=dim,
+        kernel_size=3,
+        stride=2,
+        padding=0,  # Equivalent to padding="valid"
+    )
+    self.gelu2 = nn.GELU()
+
+  def forward(self, inputs):
+    x = self.conv1(inputs)
+    x = self.tanh(x)
+    x = self.group_norm(x)
+    x = self.conv2(x)
+    x = self.gelu1(x)
+    x = self.conv3(x)
+    x = self.gelu2(x)
+    return x
+
+
+def build_preprocessor(checkpoint_path: str, **kwargs) -> nn.Module:
+  ap = AudioPreprocessor(dim=416)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(ap, strict=True)
+  return ap
+
+
+def main(_):
+  # TODO(b/375421767) Remove golden checks once full model is implemented.
+  HF_PATH = os.path.join(pathlib.Path.home(), "Downloads/llm_data/moonshine")
+
+  test_data_path = pathlib.Path(__file__).parent.resolve()
+  INPUT_PATH = test_data_path / "data" / "pp_input.pt"
+  GOLDEN_PATH = test_data_path / "data" / "pp_output.pt"
+
+  ap = build_preprocessor(HF_PATH)
+  ap.eval()
+  inputs = torch.load(INPUT_PATH).reshape((1, 1, 159414))
+  out = ap(inputs)
+  golden = torch.load(GOLDEN_PATH).transpose(1, 2)
+  assert torch.allclose(out, golden, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__":
+  app.run(main)
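For orientation (not derived from the diff itself): with the 159414-sample input used above, the three strided convolutions reduce the time dimension to 413 frames of 416 channels. A quick shape check, assuming the AudioPreprocessor class defined in this file:

    # Sketch: sanity-check the preprocessor's output shape with random weights.
    import torch

    ap = AudioPreprocessor(dim=416)
    with torch.no_grad():
      out = ap(torch.randn(1, 1, 159414))
    # conv1: (159414 - 127) // 64 + 1 = 2489
    # conv2: (2489 - 7) // 3 + 1 = 828
    # conv3: (828 - 3) // 2 + 1 = 413
    print(out.shape)  # expected torch.Size([1, 416, 413])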
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,14 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = (
+      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  )
+
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/phi/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py CHANGED
@@ -39,10 +39,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -68,11 +68,13 @@ def main(_):
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
   model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = f'qwen_{model_size}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = (
+      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  )
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,11 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py CHANGED
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
     '/tmp/',
     'The tflite file path to export.',
 )
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
@@ -55,11 +55,13 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = (
+      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  )
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/test/test_quantize.py CHANGED
@@ -91,6 +91,11 @@ class TestVerifyRecipes(parameterized.TestCase):
 class TestQuantizeConvert(parameterized.TestCase):
   """Test conversion with quantization."""

+  def setUp(self):
+    super().setUp()
+    torch.manual_seed(0)
+    torch._dynamo.reset()
+
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
ai_edge_torch/generative/utilities/moonshine_loader.py ADDED
@@ -0,0 +1,154 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Common utility functions for data loading etc.
+from dataclasses import dataclass
+import glob
+import os
+from typing import Callable, Dict
+
+import h5py
+import torch
+
+
+def transpose_if_needed(t):
+  """We assume the file is from Keras, i.e. channel last format."""
+  if len(t.shape) > 2:
+    return t.permute(2, 1, 0)
+  return t
+
+
+def load_h5_statedict(full_path: str):
+  """Loads the HDF5 DataSets into a single dctionary.
+
+  Args:
+    full_path (string): the HDF5 filename or directory that contains the HDF5
+      files.
+
+  Returns:
+    A state dictionary contating loaded tensors.
+
+  Raises:
+    ValueError: If no tensors are loaded from the provided directory or file.
+  """
+  pattern = (
+      os.path.join(full_path, "*.h5") if os.path.isdir(full_path) else full_path
+  )
+  files = []
+  for file in glob.glob(pattern):
+    files.append(file)
+
+  tensors = {}
+
+  def collect_datasets(name, obj):
+    if isinstance(obj, h5py.Dataset):
+      tensors[name] = transpose_if_needed(torch.from_numpy(obj[:]))
+
+  for file in files:
+    with h5py.File(file) as f:
+      f.visititems(collect_datasets)
+
+  if not tensors:
+    raise ValueError("Failed to load HDF5 file.")
+  return tensors
+
+
+class ModelLoader:
+  """Utility class for loading and converting checkpoints to ODML transformer layer format."""
+
+  @dataclass
+  class TensorNames:
+    conv1D_0: str = None
+    conv1D_1: str = None
+    conv1D_2: str = None
+    group_norm: str = None
+
+  def __init__(self, file_name: str, names: TensorNames) -> None:
+    """ModelLoader constructor.
+
+    Can be used to load multiple models of the same type.
+
+    Args:
+      file_name (str): Path to the checkpoint. Can be a directory or an exact
+        file.
+      names (TensorNames): An instance of `TensorNames` to determine mappings.
+    """
+    self._file_name = file_name
+    self._names = names
+    self._loader = load_h5_statedict
+
+  def load(
+      self,
+      model: torch.nn.Module,
+      strict: bool = True,
+  ):
+    """Load the model from the checkpoint
+
+    Args:
+      model (torch.nn.Module): The pytorch model that needs to be loaded.
+      strict (bool, optional): Whether the converted keys are strictly
+        matched. Defaults to True.
+
+    Raises:
+      ValueError: If conversion results in unmapped tensors and strict mode is
+        enabled.
+    """
+    state = self._loader(self._file_name)
+
+    if isinstance(self._names, ModelLoader.TensorNames):
+      converted_state = self._do_load(model, state, self._names)
+    else:
+      raise ValueError(f"Unkown type for names: {type(self._names)}")
+
+    if strict and state:
+      raise ValueError(
+          "Failed to map all tensor. Remaining tensor are:"
+          f" {list(state.keys())}"
+      )
+    model.load_state_dict(converted_state, strict=strict)
+
+  def _do_load(self, model, state, names, additional_prefix=""):
+    """Load the model from the checkpoint
+
+    Args:
+      model (torch.nn.Module): The pytorch model that needs to be loaded.
+      state (Dict[str, torch.Tensor]): The pytorch state dictionary
+      names (TensorNames]): The TensorNames for the model we are loading.
+
+    Returns:
+      Dict[str, torch.Tensor]: Map of name to tensor for loading.
+    """
+    converted_state = dict()
+    if names.conv1D_0 is not None:
+      converted_state["conv1.weight"] = state.pop(f"{names.conv1D_0}/0")
+      if f"{names.conv1D_0}/1" in state:
+        converted_state["conv1.bias"] = state.pop(f"{names.conv1D_0}/1")
+
+    if names.conv1D_1 is not None:
+      converted_state["conv2.weight"] = state.pop(f"{names.conv1D_1}/0")
+      if f"{names.conv1D_1}/1" in state:
+        converted_state["conv2.bias"] = state.pop(f"{names.conv1D_1}/1")
+
+    if names.conv1D_2 is not None:
+      converted_state["conv3.weight"] = state.pop(f"{names.conv1D_2}/0")
+      if f"{names.conv1D_2}/1" in state:
+        converted_state["conv3.bias"] = state.pop(f"{names.conv1D_2}/1")
+
+    if names.group_norm is not None:
+      group_norm_name = names.group_norm
+      converted_state[f"group_norm.weight"] = state.pop(f"{group_norm_name}/0")
+      if f"{group_norm_name}/1" in state:
+        converted_state["group_norm.bias"] = state.pop(f"{group_norm_name}/1")
+
+    return converted_state
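A side note on transpose_if_needed (not from the diff): Keras stores Conv1D kernels channels-last as (kernel_size, in_channels, out_channels), while torch.nn.Conv1d expects (out_channels, in_channels, kernel_size), so the permute(2, 1, 0) above converts between the two layouts. A quick check with the Moonshine conv1 dimensions:

    # Sketch: Keras -> PyTorch Conv1d kernel layout conversion.
    import torch

    keras_kernel = torch.zeros(127, 1, 416)       # (kernel_size, in_ch, out_ch)
    torch_weight = keras_kernel.permute(2, 1, 0)  # (out_ch, in_ch, kernel_size)
    assert torch_weight.shape == torch.Size([416, 1, 127])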
ai_edge_torch/odml_torch/export.py CHANGED
@@ -35,9 +35,7 @@ from . import lowerings
 LoweringContext = lowerings.context.LoweringContext


-def _build_flat_inputs(
-    ctx: ir.Context, exported_program: torch.export.ExportedProgram
-):
+def _build_flat_inputs(exported_program: torch.export.ExportedProgram):
   """Build flattened inputs and metadata from exported program's signature."""
   placeholder_nodes = [
       n for n in exported_program.graph.nodes if n.op == "placeholder"
@@ -49,9 +47,11 @@ def _build_flat_inputs(
   ir_inputs = []
   tensor_metas = []
   for node, arg in zip(placeholder_nodes, export_flat_args):
-    tensor_meta = node.meta.get("tensor_meta")
+    tensor_meta = node.meta.get("tensor_meta") or node.meta.get("val")
     if tensor_meta is None:
-      raise RuntimeError(f"{type(arg)} (for {node.name}) is not a tensor")
+      raise RuntimeError(
+          f"{type(arg)} (for {node.name}) does not have tensor meta"
+      )

     tensor_metas.append(tensor_meta)
     # Assume all dynamic dimensions are unbounded.
@@ -63,7 +63,7 @@ def _build_flat_inputs(
     ir_inputs.append(
         ir.RankedTensorType.get(
            shape,
-            export_utils.torch_dtype_to_ir_element_type(ctx, tensor_meta.dtype),
+            export_utils.torch_dtype_to_ir_element_type(tensor_meta.dtype),
        )
    )
   return tuple(ir_inputs), tuple(export_flat_args), tuple(tensor_metas)
@@ -258,6 +258,43 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
       rewrite_arange(node)


+# TODO(b/331481564) Make this a ai_edge_torch FX pass.
+def _convert_q_dq_per_channel_args_to_list(
+    exported_program: torch.export.ExportedProgram,
+):
+  """Resolve tensor inputs to Q/DQ ops as static number list for lowering.
+
+  This pass makes the ExportedProgram in a non-executable state. This pass must
+  be run after all run_decompositions calls.
+  """
+  placeholder_nodes = [
+      n for n in exported_program.graph.nodes if n.op == "placeholder"
+  ]
+  export_flat_args = _torch_future.graph_module_flat_inputs(
+      exported_program, *exported_program.example_inputs
+  )
+
+  placeholder_tensor = {
+      n: tensor for n, tensor in zip(placeholder_nodes, export_flat_args)
+  }
+
+  graph_module = exported_program.graph_module
+  for node in graph_module.graph.nodes:
+    if node.target in (
+        torch.ops.quantized_decomposed.quantize_per_channel.default,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.dequantize_per_channel.default,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    ):
+      input, scale_node, zero_point_node = node.args[:3]
+      scale = placeholder_tensor[scale_node]
+      zero_point = placeholder_tensor[zero_point_node]
+
+      scale = scale.detach().numpy().tolist()
+      zero_point = zero_point.detach().numpy().tolist()
+      node.args = (input, scale, zero_point, *node.args[3:])
+
+
 def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
@@ -270,6 +307,7 @@ def exported_program_to_mlir(
   exported_program = _torch_future.safe_run_decompositions(
       exported_program, lowerings.decompositions()
   )
+  _convert_q_dq_per_channel_args_to_list(exported_program)

   with export_utils.create_ir_context() as context, ir.Location.unknown():

@@ -277,7 +315,7 @@ def exported_program_to_mlir(
     lctx = LoweringContext(context, module)
     interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
     ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
-        context, exported_program
+        exported_program
     )

     # HACK: OSS MLIR pybinding could mysteriously transform func.func under
ai_edge_torch/odml_torch/export_utils.py CHANGED
@@ -14,9 +14,9 @@
 # ==============================================================================
 """Utilities for ODML Torch export."""

-import functools
 import re
 from typing import Sequence, cast
+from ai_edge_torch.odml_torch.lowerings import utils as lowering_utils
 import jax._src.interpreters.mlir
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func
@@ -47,7 +47,6 @@ def create_ir_context():
   # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
   context = jax._src.interpreters.mlir.make_ir_context()
   context.allow_unregistered_dialects = True
-
   return context


@@ -135,17 +134,7 @@ def build_ir_attr(val):
   return ir.StringAttr.get(str(val))


-def torch_dtype_to_ir_element_type(ctx, dtype):
-  ty_get = {
-      torch.double: ir.F64Type.get,
-      torch.float32: ir.F32Type.get,
-      torch.half: ir.F16Type.get,
-      torch.long: functools.partial(ir.IntegerType.get_signless, 64),
-      torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
-      torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
-      torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
-  }.get(dtype)
-  return ty_get(ctx)
+torch_dtype_to_ir_element_type = lowering_utils.torch_dtype_to_ir_element_type


 def ir_element_type_to_torch_dtype(ty):
ai_edge_torch/odml_torch/jax_bridge/_wrap.py CHANGED
@@ -163,9 +163,7 @@ def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
     if aval is None:
       return result

-    target_elty = export_utils.torch_dtype_to_ir_element_type(
-        lctx.ir_context, aval.dtype
-    )
+    target_elty = export_utils.torch_dtype_to_ir_element_type(aval.dtype)
     if result.type.element_type == target_elty:
       return result
     return stablehlo.convert(
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -17,6 +17,7 @@ from . import _batch_norm
 from . import _convolution
 from . import _jax_lowerings
 from . import _layer_norm
+from . import _quantized_decomposed
 from . import context
 from . import registry
 from . import utils
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -227,9 +227,7 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
   if not non_empty_tensors:
     return utils.splat(
         0,
-        export_utils.torch_dtype_to_ir_element_type(
-            lctx.ir_context, out_aval.dtype
-        ),
+        export_utils.torch_dtype_to_ir_element_type(out_aval.dtype),
         out_aval.shape,
     )

ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py ADDED
@@ -0,0 +1,174 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Lowerings for PT2E torch.ops.quantized_decomposed ops."""
+from typing import Union, cast
+
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import utils
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+import torch.ao.quantization.fx._decomposed
+import torch.utils._pytree as pytree
+
+from . import registry
+
+lower = registry.lower
+LoweringContext = context.LoweringContext
+
+
+def _uniform_quantized_type(
+    stored_type: str | ir.Type,
+    expressed_type: str | ir.Type,
+    *,
+    scale=float | list[float] | tuple[float],
+    zero_point=float | list[float] | tuple[float],
+    storage_type_min: int | None = None,
+    storage_type_max: int | None = None,
+    channel_axis: int | None = None,
+    channel_axis_size: int | None = None,
+):
+  """Polyfill for quant.UniformQuantizedType."""
+  if storage_type_min and storage_type_max:
+    storage_min_max = f"<{storage_type_min}:{storage_type_max}>"
+  else:
+    storage_min_max = ""
+
+  if channel_axis is not None:
+    # Per-channel quantization
+    # https://mlir.llvm.org/docs/Dialects/QuantDialect/#per-channel-quantization
+    assert isinstance(scale, (list, tuple))
+    assert isinstance(zero_point, (list, tuple))
+
+    if len(scale) == 1:
+      scale *= channel_axis_size
+    if len(zero_point) == 1:
+      zero_point *= channel_axis_size
+
+    assert len(scale) == len(zero_point) == channel_axis_size
+    scale_zp_strs = []
+    for s, zp in zip(scale, zero_point):
+      scale_zp_strs.append(f"{s}:{zp}")
+    scale_zp = "{" + ",".join(scale_zp_strs) + "}"
+    return ir.Type.parse(
+        f"!quant.uniform<{stored_type}{storage_min_max}:{expressed_type}:{channel_axis},{scale_zp}>"
+    )
+  else:
+    # Per-layer quantization
+    # https://mlir.llvm.org/docs/Dialects/QuantDialect/#per-layer-quantization
+    scale = pytree.tree_flatten([scale])[0][-1]
+    zero_point = pytree.tree_flatten([zero_point])[0][-1]
+    scale_zp = f"{scale}:{zero_point}"
+    return ir.Type.parse(
+        f"!quant.uniform<{stored_type}{storage_min_max}:{expressed_type},{scale_zp}>"
+    )
+
+
+# Quant dialect is not registered in the Python MLIR pybinding used by
+# odml-torch. Therefore, stablehlo.uniform_quantize/uniform_dequantize ops and
+# quant types are represented in stablehlo.custom_call to pass MLIR verification
+# and VHLO serialization before converter.
+# TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+
+
+# Schema:
+#   - quantized_decomposed::quantize_per_tensor(Tensor input, float scale,
+#       int zero_point, int quant_min, int quant_max,
+#       ScalarType dtype) -> Tensor
+#   - quantized_decomposed::quantize_per_tensor.tensor(Tensor input,
+#       Tensor scale, Tensor zero_point, int quant_min, int quant_max,
+#       ScalarType dtype) -> Tensor
+#
+# Scale and zero_point in tensors are automatically converted to list before
+# lowering.
+@lower(torch.ops.quantized_decomposed.quantize_per_tensor)
+def _quantize_per_tensor(
+    lctx: LoweringContext,
+    input: ir.Value,
+    scale: Union[float, list[float]],
+    zero_point: Union[float, list[float]],
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+):
+  input_ty = cast(ir.RankedTensorType, input.type)
+  qty = _uniform_quantized_type(
+      utils.torch_dtype_to_ir_element_type(dtype),
+      input_ty.element_type,
+      scale=scale,
+      zero_point=zero_point,
+      storage_type_min=quant_min,
+      storage_type_max=quant_max,
+  )
+  return stablehlo.custom_call(
+      call_target_name="odml_torch.uniform_quantize",
+      inputs=[input],
+      result=[input_ty],
+      backend_config=ir.StringAttr.get(
+          str(ir.RankedTensorType.get(input_ty.shape, qty))
+      ),
+  )
+
+
+# Schema:
+#   - quantized_decomposed::quantize_per_channel(Tensor input, Tensor scales,
+#       Tensor zero_points, int axis, int quant_min, int quant_max,
+#       ScalarType dtype) -> Tensor
+#
+# Scale and zero_point in tensors are automatically converted to list before
+# lowering.
+@lower(torch.ops.quantized_decomposed.quantize_per_channel)
+def _quantize_per_channel(
+    lctx: LoweringContext,
+    input: ir.Value,
+    scale: list[float],
+    zero_point: list[float],
+    axis: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+):
+  input_ty = cast(ir.RankedTensorType, input.type)
+  qty = _uniform_quantized_type(
+      utils.torch_dtype_to_ir_element_type(dtype),
+      input_ty.element_type,
+      scale=scale,
+      zero_point=zero_point,
+      channel_axis=axis,
+      channel_axis_size=input_ty.shape[axis],
+      storage_type_min=quant_min,
+      storage_type_max=quant_max,
+  )
+  return stablehlo.custom_call(
+      call_target_name="odml_torch.uniform_quantize",
+      inputs=[input],
+      result=[input_ty],
+      backend_config=ir.StringAttr.get(
+          str(ir.RankedTensorType.get(input_ty.shape, qty))
+      ),
+  )
+
+
+@lower(torch.ops.quantized_decomposed.dequantize_per_tensor)
+@lower(torch.ops.quantized_decomposed.dequantize_per_channel)
+def _dequantize(lctx: LoweringContext, input: ir.Value, *args, **kwargs):
+  result_meta = lctx.node.meta.get("tensor_meta")
+  result_elty = utils.torch_dtype_to_ir_element_type(result_meta.dtype)
+
+  return stablehlo.custom_call(
+      call_target_name="odml_torch.uniform_dequantize",
+      inputs=[input],
+      result=[ir.RankedTensorType.get(result_meta.shape, result_elty)],
+  )
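Because the quant dialect is not registered in the bundled MLIR pybinding, the quantized element type is carried as a string inside the custom_call's backend_config. For illustration only (not from the diff), the string that _uniform_quantized_type builds for typical per-layer int8 parameters, reproduced in plain Python so it runs without an MLIR context:

    # Sketch: the per-layer quantized type string, without ir.Type.parse.
    stored_type, expressed_type = "i8", "f32"
    scale, zero_point = 0.02, 0
    quant_min, quant_max = -128, 127

    storage_min_max = f"<{quant_min}:{quant_max}>"
    qty = (
        f"!quant.uniform<{stored_type}{storage_min_max}"
        f":{expressed_type},{scale}:{zero_point}>"
    )
    print(qty)  # !quant.uniform<i8<-128:127>:f32,0.02:0>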
ai_edge_torch/odml_torch/lowerings/utils.py CHANGED
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Utilities for building MLIR lowerings."""

+import functools
 import numbers
 from typing import Any
 from typing import Optional
@@ -21,6 +22,21 @@ from typing import Optional
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo as stablehlo
 import numpy as np
+import torch
+
+
+def torch_dtype_to_ir_element_type(dtype):
+  ty_get = {
+      torch.double: ir.F64Type.get,
+      torch.float32: ir.F32Type.get,
+      torch.half: ir.F16Type.get,
+      torch.long: functools.partial(ir.IntegerType.get_signless, 64),
+      torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
+      torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
+      torch.int8: functools.partial(ir.IntegerType.get_signless, 8),
+      torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
+  }[dtype]
+  return ty_get()


 def splat(val, ty, shape=tuple(), *, loc: Optional[Any] = None):
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================

-__version__ = "0.3.0.dev20241129"
+__version__ = "0.3.0.dev20241204"
{ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241129
+Version: 0.3.0.dev20241204
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241129.dist-info → ai_edge_torch_nightly-0.3.0.dev20241204.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=_XjQel8A-E5ufJfKJYOHGc5_ZGO6hbkXorvp4wgs8SU,706
+ai_edge_torch/version.py,sha256=IfBDOY7eb9sSynpfr3Qsw88QCzs0DDWN_jB1B6zi5ss,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
+ai_edge_torch/_convert/test/test_convert.py,sha256=v6AhfWqRBuHT7uBDueTbntaQtDSMMrvQOqlIDXNUaMA,17250
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -44,20 +44,22 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=bkq2Zk
 ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n79r6yFnCACpms5eMkXNpyQsCn2PYVRdB-jOoIqn14,2227
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
-ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py,sha256=6d9wG5MnStEys34_gFXwKTMRXUBFLTW1jEzCoWkAtwM,2224
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=mrG96_WEGD4NQ4uFEKrHRMAQvVVliOcj1zbI3drGDjI,2199
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=I_tvwCYmtf08D1HqDxYx7dpvj2q5_eaYnuI_3rI6Dlw,2201
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
+ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=Brb83sbqBfStUiIZFhfWnYtN7LcNmkKyFn96cZK4sGo,2426
 ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
+ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
+ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=-qDBu3bjUq0jx73SPDMsPIBP0BT1nA_0UgtFKeSuM18,2213
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -69,18 +71,18 @@ ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
-ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
+ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=ruY-LLwpqBqVZ5z9h_sewYj04ukWRG4j804tUAyDdnA,2186
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=UdMk1SSkcWpv8gosUylx3JSCxdOJBjZNhuQQtT4-Ono,2184
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=nbivDwZREd-sypy_ittO59-yaAdPvHv1YEV6Fo5buCo,3341
 ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=QAAVoSKDVf2rHAChzumGloVCWIU0Oe5UYKgv3T192Iw,2496
+ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=1M3DTkf536TCLYcQB1lu-3TxQ6mV03dFhTdbk0p8i84,2523
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=oYm9hhALUQ4uOn-PO1bF7fCIGP8EWRNK4zClkx2RQs8,4070
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=56CzCjyp9xh_2ZtXKN9tlEv6GayeSc4giTIZsi2Q59E,2194
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZku9kgvmlFCyIBar3kF2XEk,2570
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -107,7 +109,7 @@ ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNH
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=rRodLr-hEqAs_-8x06O8qO-hJ_cqr2AfhJZ9DCptvwo,4332
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=WmEshoN9HgNLbV7NTjxdqWz9Olcim6r_vo4R9eYE98I,2228
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=10X8HwPx4akzclnIMOBNItKQemhRbvxBbTo7nwZtWjM,2650
 ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
@@ -139,13 +141,14 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FP
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=xWV9O2wuRHc4VNBWuWipiuqXa3AJhiV1nmjewAZHHWM,11177
-ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
+ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
+ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
@@ -167,8 +170,8 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=4xwrsDeOAgzoB9m7EeNsBj6dC5Ajtn5aKDRQkdHxa-o,11584
-ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
+ai_edge_torch/odml_torch/export.py,sha256=dgnNGBVkHBz0brlWALX2hGXpQ4YzCKdwbkF4oAfEu4I,13062
+ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
@@ -177,17 +180,18 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
 ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
-ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=oQo9nxH08NnEDeZaGoCUk1kRtoEOM_f0DUOyd9nfxjg,6673
+ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
-ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=eH9eJqFO-BI9l4WdXfjsItODPRa18SAR_qSvJ6-7gxc,9987
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=1lMKPoStK3SUA8yYTPZBRhESN33BghGXnfqOOg4oeVk,995
+ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5lywQ1_RfLIxVE7J_GONI6CQvI,10718
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
+ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=itTt8MLbq2LoHTzRidCF2TTbh0TP7L836u99qCjP3FA,2953
-ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
+ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
 ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
@@ -196,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241129.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241129.dist-info/METADATA,sha256=ag4k5qLDRZe_u8bdu-RTqFL0q70kPk3KlmCpbI-TTSI,1897
-ai_edge_torch_nightly-0.3.0.dev20241129.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241129.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241129.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/METADATA,sha256=jcbPyL8PaZxa79FSq_FwD0sfu-2R_nTRMDJiPNmJXFM,1897
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241204.dist-info/RECORD,,