ai-edge-torch-nightly 0.2.0.dev20240626__py3-none-any.whl → 0.2.0.dev20240701__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (26)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +23 -5
  3. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  4. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  5. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +2 -2
  6. ai_edge_torch/generative/examples/gemma/gemma.py +1 -1
  7. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +2 -2
  8. ai_edge_torch/generative/examples/phi2/phi2.py +1 -1
  9. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -2
  10. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  11. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +12 -12
  12. ai_edge_torch/generative/examples/t5/t5.py +2 -2
  13. ai_edge_torch/generative/examples/test_models/toy_model.py +1 -1
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +2 -2
  15. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +2 -2
  16. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -2
  17. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -1
  18. ai_edge_torch/generative/quantize/example.py +1 -1
  19. ai_edge_torch/generative/test/test_model_conversion.py +8 -8
  20. ai_edge_torch/generative/test/test_quantize.py +2 -2
  21. ai_edge_torch/hlfb/mark_pattern/pattern.py +22 -9
  22. {ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/METADATA +1 -1
  23. {ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/RECORD +26 -24
  24. {ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/LICENSE +0 -0
  25. {ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/WHEEL +0 -0
  26. {ai_edge_torch_nightly-0.2.0.dev20240626.dist-info → ai_edge_torch_nightly-0.2.0.dev20240701.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py CHANGED
@@ -15,6 +15,7 @@
15
15
 
16
16
  from .convert.converter import convert
17
17
  from .convert.converter import signature
18
+ from .convert.to_channel_last_io import to_channel_last_io
18
19
  from .model import Model
19
20
 
20
21
 
@@ -17,10 +17,22 @@ import functools
17
17
 
18
18
  import torch
19
19
 
20
- from ai_edge_torch.convert.fx_passes import FxPassBase
21
- from ai_edge_torch.convert.fx_passes import FxPassResult
20
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
21
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
22
22
  from ai_edge_torch.hlfb import mark_pattern
23
23
 
24
+ # For torch nightly released after mid June 2024,
25
+ # torch.nn.functional.interpolate no longer gets exported into decomposed graph
26
+ # but single aten op torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
27
+ # This behavior would break our pattern matching based composite builder.
28
+ # It requires the pattern and model graph to get decomposed first for backward compatibility.
29
+ _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions(
30
+ [
31
+ torch.ops.aten.upsample_bilinear2d.vec,
32
+ torch.ops.aten.upsample_nearest2d.vec,
33
+ ]
34
+ )
35
+
24
36
 
25
37
  @functools.cache
26
38
  def _get_upsample_bilinear2d_pattern():
@@ -30,6 +42,7 @@ def _get_upsample_bilinear2d_pattern():
30
42
  x, scale_factor=2, mode="bilinear", align_corners=False
31
43
  ),
32
44
  export_args=(torch.rand(1, 3, 100, 100),),
45
+ decomp_table=_INTERPOLATE_DECOMPOSITIONS,
33
46
  )
34
47
 
35
48
  @pattern.register_attr_builder
@@ -52,6 +65,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
52
65
  x, scale_factor=2, mode="bilinear", align_corners=True
53
66
  ),
54
67
  export_args=(torch.rand(1, 3, 100, 100),),
68
+ decomp_table=_INTERPOLATE_DECOMPOSITIONS,
55
69
  )
56
70
 
57
71
  @pattern.register_attr_builder
@@ -72,6 +86,7 @@ def _get_interpolate_nearest2d_pattern():
72
86
  "tfl.resize_nearest_neighbor",
73
87
  lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
74
88
  export_args=(torch.rand(1, 3, 100, 100),),
89
+ decomp_table=_INTERPOLATE_DECOMPOSITIONS,
75
90
  )
76
91
 
77
92
  @pattern.register_attr_builder
@@ -86,7 +101,7 @@ def _get_interpolate_nearest2d_pattern():
86
101
  return pattern
87
102
 
88
103
 
89
- class BuildInterpolateCompositePass(FxPassBase):
104
+ class BuildInterpolateCompositePass(ExportedProgramPassBase):
90
105
 
91
106
  def __init__(self):
92
107
  super().__init__()
@@ -96,10 +111,13 @@ class BuildInterpolateCompositePass(FxPassBase):
96
111
  _get_interpolate_nearest2d_pattern(),
97
112
  ]
98
113
 
99
- def call(self, graph_module: torch.fx.GraphModule):
114
+ def call(self, exported_program: torch.export.ExportedProgram):
115
+ exported_program = exported_program.run_decompositions(_INTERPOLATE_DECOMPOSITIONS)
116
+
117
+ graph_module = exported_program.graph_module
100
118
  for pattern in self._patterns:
101
119
  graph_module = mark_pattern.mark_pattern(graph_module, pattern)
102
120
 
103
121
  graph_module.graph.lint()
104
122
  graph_module.recompile()
105
- return FxPassResult(graph_module, True)
123
+ return ExportedProgramPassResult(exported_program, True)
@@ -0,0 +1,96 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import torch
19
+
20
+ import ai_edge_torch
21
+
22
+
23
+ class Identity(torch.nn.Module):
24
+
25
+ def forward(self, x):
26
+ return x
27
+
28
+
29
+ class TestToChannelLastIO(unittest.TestCase):
30
+ """Tests to_channel_last_io API and module wrapper."""
31
+
32
+ def test_no_transformations(self):
33
+ x = torch.rand(1, 3, 10, 10)
34
+ y = ai_edge_torch.to_channel_last_io(Identity())(x)
35
+ self.assertEqual(y.shape, (1, 3, 10, 10))
36
+
37
+ def test_args(self):
38
+ x = torch.rand(1, 10, 10, 3)
39
+ y = ai_edge_torch.to_channel_last_io(Identity(), args=[0])(x)
40
+ self.assertEqual(y.shape, (1, 3, 10, 10))
41
+
42
+ def test_outputs(self):
43
+ x = torch.rand(1, 3, 10, 10)
44
+ y = ai_edge_torch.to_channel_last_io(Identity(), outputs=[0])(x)
45
+ self.assertEqual(y.shape, (1, 10, 10, 3))
46
+
47
+ def test_args_outputs(self):
48
+ x = torch.rand(1, 10, 10, 3)
49
+ y = ai_edge_torch.to_channel_last_io(Identity(), args=[0], outputs=[0])(x)
50
+ self.assertEqual(y.shape, (1, 10, 10, 3))
51
+
52
+ def test_args_5d(self):
53
+ x = torch.rand(1, 10, 10, 10, 3)
54
+ y = ai_edge_torch.to_channel_last_io(Identity(), args=[0])(x)
55
+ self.assertEqual(y.shape, (1, 3, 10, 10, 10))
56
+
57
+ def test_outputs_5d(self):
58
+ x = torch.rand(1, 3, 10, 10, 10)
59
+ y = ai_edge_torch.to_channel_last_io(Identity(), outputs=[0])(x)
60
+ self.assertEqual(y.shape, (1, 10, 10, 10, 3))
61
+
62
+ def test_chained_wrappers(self):
63
+ x = torch.rand(1, 10, 10, 3)
64
+
65
+ m = Identity()
66
+ m = ai_edge_torch.to_channel_last_io(m, args=[0])
67
+ m = ai_edge_torch.to_channel_last_io(m, outputs=[0])
68
+
69
+ y = m(x)
70
+ self.assertEqual(y.shape, (1, 10, 10, 3))
71
+
72
+ def test_list_args(self):
73
+ class Add(torch.nn.Module):
74
+
75
+ def forward(self, x, y):
76
+ return x + y
77
+
78
+ x = (torch.rand(1, 10, 10, 3), torch.rand(1, 10, 10, 3))
79
+ y = ai_edge_torch.to_channel_last_io(Add(), args=[0, 1])(*x)
80
+ self.assertEqual(y.shape, (1, 3, 10, 10))
81
+
82
+ def test_list_outputs(self):
83
+ class TwoIdentity(torch.nn.Module):
84
+
85
+ def forward(self, x):
86
+ return x, x
87
+
88
+ x = torch.rand(1, 3, 10, 10)
89
+ y = ai_edge_torch.to_channel_last_io(TwoIdentity(), outputs=[0])(x)
90
+ self.assertIsInstance(y, tuple)
91
+ self.assertEqual(y[0].shape, (1, 10, 10, 3))
92
+ self.assertEqual(y[1].shape, (1, 3, 10, 10))
93
+
94
+
95
+ if __name__ == "__main__":
96
+ unittest.main()
@@ -0,0 +1,85 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Optional
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
+ class ChannelLastIOWrapper(nn.Module):
23
+
24
+ def __init__(self, wrapped, *, args=None, outputs=None):
25
+ super().__init__()
26
+ self.wrapped = wrapped
27
+ self._args = args or []
28
+ self._outputs = outputs or []
29
+
30
+ def _to_channel_last(self, x):
31
+ if not torch.is_tensor(x):
32
+ raise ValueError("Input must be a torch tensor")
33
+ if x.ndim < 3:
34
+ raise ValueError("Input must be a tensor with rank >= 3 in layout (N, C, ...)")
35
+ dims = [0, *range(2, x.ndim), 1]
36
+ return torch.permute(x, dims)
37
+
38
+ def _to_channel_first(self, x):
39
+ if not torch.is_tensor(x):
40
+ raise ValueError("Input must be a torch tensor.")
41
+ if x.ndim < 3:
42
+ raise ValueError("Input must be a tensor with rank >= 3 in layout (N, ..., C)")
43
+ dims = [0, x.ndim - 1, *range(1, x.ndim - 1)]
44
+ return torch.permute(x, dims)
45
+
46
+ def forward(self, *args, **kwargs):
47
+ args = list(args)
48
+ for i in self._args:
49
+ args[i] = self._to_channel_first(args[i])
50
+
51
+ outputs = self.wrapped(*args, **kwargs)
52
+
53
+ if not isinstance(outputs, (list, tuple)):
54
+ outputs_is_list = False
55
+ output_list = [outputs]
56
+ else:
57
+ outputs_is_list = True
58
+ output_list = list(outputs)
59
+
60
+ for i in self._outputs:
61
+ output_list[i] = self._to_channel_last(output_list[i])
62
+
63
+ if not outputs_is_list:
64
+ return output_list[0]
65
+ else:
66
+ return type(outputs)(output_list)
67
+
68
+
69
+ def to_channel_last_io(
70
+ module: nn.Module,
71
+ args: Optional[list[int]] = None,
72
+ outputs: Optional[list[int]] = None,
73
+ ):
74
+ """Wraps the module with channel first to channel last layout transformations.
75
+
76
+ Args:
77
+ args (list[int]): Indices of positional args the wrapper accepts in channel last
78
+ (N, ..., C) layout and transposes to channel first (N, C, ...) for the module.
79
+ outputs (list[int]): Transform outputs with indices in the list from channel
80
+ first (N, C, ...) to channel last (N, ..., C).
81
+ Returns:
82
+ The wrapped nn.Module with additional layout transposes after inputs and/or before
83
+ outputs.
84
+ """
85
+ return ChannelLastIOWrapper(module, args=args, outputs=outputs)
@@ -45,9 +45,9 @@ def convert_gemma_to_tflite(
45
45
  checkpoint_path, kv_cache_max_len=kv_cache_max_len
46
46
  )
47
47
  # Tensors used to trace the model graph during conversion.
48
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
48
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
49
49
  prefill_input_pos = torch.arange(0, prefill_seq_len)
50
- decode_token = torch.tensor([[0]], dtype=torch.long)
50
+ decode_token = torch.tensor([[0]], dtype=torch.int)
51
51
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
52
52
 
53
53
  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
@@ -163,7 +163,7 @@ def define_and_run_2b() -> None:
163
163
  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
164
164
  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
165
165
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
166
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
166
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
167
167
  tokens[0, :4] = idx
168
168
  input_pos = torch.arange(0, kv_cache_max_len)
169
169
  print("running an inference")
@@ -43,9 +43,9 @@ def convert_phi2_to_tflite(
43
43
  """
44
44
  pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
45
45
  # Tensors used to trace the model graph during conversion.
46
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
46
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
47
47
  prefill_input_pos = torch.arange(0, prefill_seq_len)
48
- decode_token = torch.tensor([[0]], dtype=torch.long)
48
+ decode_token = torch.tensor([[0]], dtype=torch.int)
49
49
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
50
50
 
51
51
  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
@@ -153,7 +153,7 @@ def define_and_run() -> None:
153
153
  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
154
154
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
155
155
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
156
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
156
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
157
157
  tokens[0, :4] = idx
158
158
  input_pos = torch.arange(0, kv_cache_max_len)
159
159
  print("running an inference")
@@ -60,8 +60,8 @@ class CLIP(nn.Module):
60
60
  )
61
61
 
62
62
  @torch.inference_mode
63
- def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
64
- tokens = tokens.type(torch.long)
63
+ def forward(self, tokens: torch.IntTensor) -> torch.FloatTensor:
64
+ tokens = tokens.type(torch.int)
65
65
 
66
66
  state = self.tok_embedding(tokens) + self.tok_embedding_position
67
67
  for layer in self.transformer_blocks:
@@ -61,7 +61,7 @@ def convert_stable_diffusion_to_tflite(
61
61
  n_tokens = 77
62
62
  timestamp = 0
63
63
  len_prompt = 1
64
- prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
64
+ prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
65
65
  input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
66
66
  noise = torch.full(
67
67
  (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
@@ -30,23 +30,23 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):
30
30
 
31
31
  # encoder
32
32
  seq_len = 512
33
- prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
33
+ prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
34
34
  prompt_e_token = [1, 2, 3, 4, 5, 6]
35
35
  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
36
- prompt_e_token, dtype=torch.long
36
+ prompt_e_token, dtype=torch.int
37
37
  )
38
38
  prefill_e_input_pos = torch.arange(0, seq_len)
39
- prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
39
+ prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
40
40
  prompt_d_token = [1, 2, 3, 4, 5, 6]
41
41
  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
42
- prompt_d_token, dtype=torch.long
42
+ prompt_d_token, dtype=torch.int
43
43
  )
44
44
  prefill_d_input_pos = torch.arange(0, seq_len)
45
45
 
46
46
  # decoder
47
- decode_token = torch.tensor([[1]], dtype=torch.long)
47
+ decode_token = torch.tensor([[1]], dtype=torch.int)
48
48
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
49
- decode_d_token = torch.tensor([[1]], dtype=torch.long)
49
+ decode_d_token = torch.tensor([[1]], dtype=torch.int)
50
50
  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
51
51
 
52
52
  # Pad mask for self attention only on "real" tokens.
@@ -78,23 +78,23 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
78
78
 
79
79
  # encoder
80
80
  seq_len = 512
81
- prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
81
+ prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
82
82
  prompt_e_token = [1, 2, 3, 4, 5, 6]
83
83
  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
84
- prompt_e_token, dtype=torch.long
84
+ prompt_e_token, dtype=torch.int
85
85
  )
86
86
  prefill_e_input_pos = torch.arange(0, seq_len)
87
- prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
87
+ prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
88
88
  prompt_d_token = [1, 2, 3, 4, 5, 6]
89
89
  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
90
- prompt_d_token, dtype=torch.long
90
+ prompt_d_token, dtype=torch.int
91
91
  )
92
92
  prefill_d_input_pos = torch.arange(0, seq_len)
93
93
 
94
94
  # decoder
95
- decode_token = torch.tensor([[1]], dtype=torch.long)
95
+ decode_token = torch.tensor([[1]], dtype=torch.int)
96
96
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
97
- decode_d_token = torch.tensor([[1]], dtype=torch.long)
97
+ decode_d_token = torch.tensor([[1]], dtype=torch.int)
98
98
  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
99
99
 
100
100
  # Pad mask for self attention only on "real" tokens.
@@ -562,7 +562,7 @@ def define_and_run_t5(checkpoint_path: str) -> None:
562
562
  model = build_t5_model(checkpoint_path)
563
563
 
564
564
  idx = get_sample_encoder_input_ids()
565
- tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
565
+ tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
566
566
  tokens[0, :77] = idx
567
567
  input_pos = torch.arange(0, 512)
568
568
 
@@ -586,7 +586,7 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
586
586
  t5_decoder_model = build_t5_decoder_model(config, embedding_layer, checkpoint_path)
587
587
  idx = get_sample_encoder_input_ids()
588
588
 
589
- tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
589
+ tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
590
590
  tokens[0, :77] = idx
591
591
  input_pos = torch.arange(0, 512)
592
592
 
@@ -93,7 +93,7 @@ def define_and_run() -> None:
93
93
  )
94
94
 
95
95
  model = ToySingleLayerModel(config)
96
- idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
96
+ idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN, dtype=torch.int), 0)
97
97
  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
98
98
  print('running an inference')
99
99
  print(
@@ -115,13 +115,13 @@ def get_model_config() -> cfg.ModelConfig:
115
115
 
116
116
 
117
117
  def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
118
- idx = torch.unsqueeze(torch.arange(0, 100), 0)
118
+ idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
119
119
  input_pos = torch.arange(0, 100)
120
120
  return idx, input_pos
121
121
 
122
122
 
123
123
  def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
124
- idx = torch.tensor([[1]], dtype=torch.long)
124
+ idx = torch.tensor([[1]], dtype=torch.int)
125
125
  input_pos = torch.tensor([10])
126
126
  return idx, input_pos
127
127
 
@@ -103,13 +103,13 @@ def get_model_config() -> cfg.ModelConfig:
103
103
 
104
104
 
105
105
  def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
106
- idx = torch.unsqueeze(torch.arange(0, 100), 0)
106
+ idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
107
107
  input_pos = torch.arange(0, 100)
108
108
  return idx, input_pos
109
109
 
110
110
 
111
111
  def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
112
- idx = torch.tensor([[1]], dtype=torch.long)
112
+ idx = torch.tensor([[1]], dtype=torch.int)
113
113
  input_pos = torch.tensor([10], dtype=torch.int64)
114
114
  return idx, input_pos
115
115
 
@@ -45,9 +45,9 @@ def convert_tiny_llama_to_tflite(
45
45
  checkpoint_path, kv_cache_max_len=kv_cache_max_len
46
46
  )
47
47
  # Tensors used to trace the model graph during conversion.
48
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
48
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
49
49
  prefill_input_pos = torch.arange(0, prefill_seq_len)
50
- decode_token = torch.tensor([[0]], dtype=torch.long)
50
+ decode_token = torch.tensor([[0]], dtype=torch.int)
51
51
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
52
52
 
53
53
  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
@@ -153,7 +153,7 @@ def define_and_run() -> None:
153
153
  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
154
154
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
155
155
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
156
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
156
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
157
157
  tokens[0, :4] = idx
158
158
  input_pos = torch.arange(0, kv_cache_max_len)
159
159
  print("running an inference")
@@ -26,7 +26,7 @@ def main():
26
26
  config = gemma.get_fake_model_config_2b_for_test()
27
27
  model = gemma.Gemma(config)
28
28
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
29
- tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
29
+ tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
30
30
  tokens[0, :4] = idx
31
31
  input_pos = torch.arange(0, 10)
32
32
 
@@ -35,7 +35,7 @@ class TestModelConversion(unittest.TestCase):
35
35
  def test_toy_model_with_kv_cache(self):
36
36
  config = toy_model_with_kv_cache.get_model_config()
37
37
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
38
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
38
+ idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
39
39
  [10], dtype=torch.int64
40
40
  )
41
41
 
@@ -59,7 +59,7 @@ class TestModelConversion(unittest.TestCase):
59
59
  config = toy_model_with_kv_cache.get_model_config()
60
60
  config.batch_size = 2
61
61
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
62
- idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
62
+ idx, input_pos = torch.tensor([[1], [2]], dtype=torch.int), torch.tensor(
63
63
  [10], dtype=torch.int64
64
64
  )
65
65
 
@@ -83,7 +83,7 @@ class TestModelConversion(unittest.TestCase):
83
83
  config = toy_model_with_kv_cache.get_model_config()
84
84
  config.enable_hlfb = True
85
85
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
86
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
86
+ idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
87
87
  [10], dtype=torch.int64
88
88
  )
89
89
 
@@ -109,7 +109,7 @@ class TestModelConversion(unittest.TestCase):
109
109
  pytorch_model = tiny_llama.TinyLLamma(config)
110
110
 
111
111
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
112
- tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
112
+ tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
113
113
  tokens[0, :4] = idx
114
114
  input_pos = torch.arange(0, 10)
115
115
 
@@ -135,13 +135,13 @@ class TestModelConversion(unittest.TestCase):
135
135
 
136
136
  # prefill
137
137
  seq_len = 10
138
- prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
138
+ prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
139
139
  prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
140
140
  prefill_tokens[0, : len(prompt_token)] = prompt_token
141
141
  prefill_input_pos = torch.arange(0, seq_len)
142
142
 
143
143
  # decode
144
- decode_token = torch.tensor([[1]], dtype=torch.long)
144
+ decode_token = torch.tensor([[1]], dtype=torch.int)
145
145
  decode_input_pos = torch.tensor([5], dtype=torch.int64)
146
146
 
147
147
  edge_model = (
@@ -183,7 +183,7 @@ class TestModelConversion(unittest.TestCase):
183
183
  model = gemma.Gemma(config)
184
184
 
185
185
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
186
- tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
186
+ tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
187
187
  tokens[0, :4] = idx
188
188
  input_pos = torch.arange(0, 10)
189
189
 
@@ -210,7 +210,7 @@ class TestModelConversion(unittest.TestCase):
210
210
  pytorch_model = phi2.Phi2(config)
211
211
 
212
212
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
213
- tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
213
+ tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
214
214
  tokens[0, :4] = idx
215
215
  input_pos = torch.arange(0, 10)
216
216
 
@@ -119,7 +119,7 @@ class TestQuantizeConvert(unittest.TestCase):
119
119
  self.skipTest("b/346896669")
120
120
  config = toy_model_with_kv_cache.get_model_config()
121
121
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
122
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
122
+ idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
123
123
  [10], dtype=torch.int64
124
124
  )
125
125
 
@@ -137,7 +137,7 @@ class TestQuantizeConvert(unittest.TestCase):
137
137
  self.skipTest("b/338288901")
138
138
  config = toy_model_with_kv_cache.get_model_config()
139
139
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
140
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
140
+ idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
141
141
  [10], dtype=torch.int64
142
142
  )
143
143
 
@@ -100,13 +100,18 @@ class ScalarAttrLocation:
100
100
 
101
101
 
102
102
  def _find_scalar_attr(
103
- pattern_module: torch.nn.Module, export_args: tuple[Any], tracker: ScalarAttrTracker
103
+ pattern_module: torch.nn.Module,
104
+ export_args: tuple[Any],
105
+ tracker: ScalarAttrTracker,
106
+ decomp_table=None,
104
107
  ) -> ScalarAttrLocation:
105
108
  scalar_loc_intersections = None
106
109
  for source, target in tracker._source_targets:
107
110
  track_args = list(export_args)
108
111
  track_args[tracker.pattern_arg_pos] = source
109
112
  ep = torch.export.export(pattern_module, tuple(track_args))
113
+ if decomp_table is not None:
114
+ ep = ep.run_decompositions(decomp_table)
110
115
 
111
116
  scalar_locs = set()
112
117
  nodes = ep.graph_module.graph.nodes
@@ -145,6 +150,7 @@ class Pattern:
145
150
  ["Pattern", GraphModule, InternalMatch], Optional[dict[str, Any]]
146
151
  ] = None,
147
152
  scalar_attr_trackers: list[ScalarAttrTracker] = None,
153
+ decomp_table: Optional[dict[torch._ops.OperatorBase, Callable]] = None,
148
154
  ):
149
155
  """The PyTorch computation pattern to match against a model.
150
156
 
@@ -165,6 +171,8 @@ class Pattern:
165
171
  for scalar args in `export_args`, which are used to track
166
172
  the attr occurrence(s) and retrieve their values from the
167
173
  matched subgraph.
174
+ decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]):
175
+ The decomposition table to be run on the pattern's exported program.
168
176
  """
169
177
  if not isinstance(module, torch.nn.Module):
170
178
 
@@ -180,23 +188,28 @@ class Pattern:
180
188
  module = PatternModule(module).eval()
181
189
 
182
190
  self.name = name
183
- self.exported_program = torch.export.export(module, export_args)
184
- self.graph_module = self.exported_program.graph_module
185
191
  self.attr_builder = attr_builder
186
192
  self._scalar_attr_trackers = scalar_attr_trackers if scalar_attr_trackers else []
187
193
 
188
- # Sanitize graph_module for more precise pattern matching.
189
- # The graph_module to match against this pattern should apply equivalent
190
- # sanitization.
191
- self.graph_module = passes.remove_clone_ops(self.graph_module)
192
- self.graph_module = passes.remove_dangling_args(self.graph_module)
194
+ exported_program = torch.export.export(module, export_args)
195
+ if decomp_table is not None:
196
+ exported_program = exported_program.run_decompositions(decomp_table)
197
+
198
+ self.exported_program = exported_program
199
+ self.graph_module = self.exported_program.graph_module
193
200
 
194
201
  self._scalar_attr_locations = []
195
202
  for tracker in self._scalar_attr_trackers:
196
203
  self._scalar_attr_locations.append(
197
- _find_scalar_attr(module, export_args, tracker)
204
+ _find_scalar_attr(module, export_args, tracker, decomp_table=decomp_table)
198
205
  )
199
206
 
207
+ # Sanitize graph_module for more precise pattern matching.
208
+ # The graph_module to match against this pattern should apply equivalent
209
+ # sanitization.
210
+ self.graph_module = passes.remove_clone_ops(self.graph_module)
211
+ self.graph_module = passes.remove_dangling_args(self.graph_module)
212
+
200
213
  # Builds list of ordered input and output nodes.
201
214
  self.graph_nodes_map = {}
202
215
  for node in self.graph_module.graph.nodes:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240626
3
+ Version: 0.2.0.dev20240701
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -1,13 +1,14 @@
1
- ai_edge_torch/__init__.py,sha256=FPMmuFU3pyMREtjB_san1fy_0PFtAsgA0VZfOYvDrb4,1008
1
+ ai_edge_torch/__init__.py,sha256=CNDboRP4zQBpz2hznNCQWcQCARvNXUm3DMa1Dw_XXFg,1067
2
2
  ai_edge_torch/model.py,sha256=kmcgELjsYl8YzF8nUF6P7q4i8MWS-pLGpfsy-yTUXmE,4243
3
3
  ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
4
4
  ai_edge_torch/convert/conversion.py,sha256=8K8jQuaCjlUWoj7jiimxp_zpN6mYThLOcQ858UDcYnE,4159
5
5
  ai_edge_torch/convert/conversion_utils.py,sha256=9BqCL38DErv1vEVGtT3BIJVhdwZjw2EQ-_m5UpvVVYE,11294
6
6
  ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
7
+ ai_edge_torch/convert/to_channel_last_io.py,sha256=zo5tY3yDhY_EPCkrL1XSXs2uRFS8B4_qu08dSjNsUGk,2778
7
8
  ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
8
9
  ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
9
10
  ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=2yqUwJJ2R233_X9FNMOP9oYRTTzH34TR_BIUj-wfnKw,7080
10
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=jB27GlDC8x36nn35aiq97uKERiq4KXSUZ7tv7yc0Gl4,3223
11
+ ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=6m_vcycd9f3OQgQLx2hhQjsKfOqdxE5EkjzqrxqyAQM,4168
11
12
  ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
12
13
  ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
13
14
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=VA9bekxPVhLk4MYlIRXnOzrSnbCtUmGj7OQ_fJcKQtc,795
@@ -24,6 +25,7 @@ ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
24
25
  ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
25
26
  ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
26
27
  ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
28
+ ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=I8c4ZG3v1vo0yxQYzLK_BTId4AOL9vadHGDtfCUZ4UI,2930
27
29
  ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
28
30
  ai_edge_torch/debug/culprit.py,sha256=urtCKPXORPvn6oyDxDSCSjgvngUnjjcsUMwAOeIl15E,14236
29
31
  ai_edge_torch/debug/utils.py,sha256=hjVmQVVl1dKxEF0D6KB4a3ouQ3wBkTsebOX2YsUObZM,1430
@@ -34,15 +36,15 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
34
36
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
35
37
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
36
38
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
37
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=dZv3r24uHsTMokEdnl3nf7LpmV0q7FLnVtCuHn5AuUs,2538
38
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=1lZfXGHmbII4rFu0U2B9NzlJCRhphxtmQtkCHQ39_uw,5935
39
+ ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=yl36sMjVqDlf9I41DF9C5wx6ztMxYB5xukD1NltUS04,2536
40
+ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=caBQrJTK4tXFeGM-i2cNXd2Tb8GXi463MwFHc7N65WE,5934
39
41
  ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
40
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlYcjXRRXSr_3M2JKqdJ-vUf-uE3VFYHE,2512
41
- ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTcXD-B6PuehaoDccRqk,5562
42
+ ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=jIFnJY9BtAUJtVkFnms7byZp-jshhQIx59DmK0OjJ8M,2510
43
+ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=tYo-WWOyw6LBF5wBVL_CsuTxKhur0SroDe2sYkTPdvI,5561
42
44
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
43
45
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
44
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
45
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=wVEjsKd5JCIiYf5GF19rOXs2NHscZh0D69mxaS4f0Sk,4182
46
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=cIOqRZ76Pb8ywuCa3LUQnKnBVmvcaAPqvA5bdHfgaWw,3718
47
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=AwGVIY_tNlea1c4Rz3lEAeAqBvb-GxGzleX2dle98DE,4181
46
48
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=RgxedILk7iNMb0mhE4VkCs6d7BnFzYhR3vspUkC0-1o,11425
47
49
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=sRevfsmCun7zbceJbOstLKNUsLwzQDsGm7Mi2JmlREg,26021
48
50
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
@@ -55,16 +57,16 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py
55
57
  ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=iPYX9ZSaxwSak2KI44j6TEr_g4pdxS3xpka4u0trjbo,2788
56
58
  ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5iRfU5MO6GR6K3WrdddIU_9U7ZZGEEb7zGKVY1WFl-8,1340
57
59
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
58
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
59
- ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
60
+ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=vqYip3JnjFMN5s0VnIzwhlBJl3up75WMb_VtRTfMOK0,4524
61
+ ai_edge_torch/generative/examples/t5/t5.py,sha256=TyBDb50NbtKaHyGLGOMJ8dR2_GPF43oe8WhRfb1SMZ4,21085
60
62
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rkMwi-NJGBXHm5S57Rsj1LbcoVdyRkS7GmIBuU6F_2E,8274
61
63
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
62
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=CUXsQ_IU96NaCg9jyfeKI0Zz2iWDkJUsPJyPR1Pgz7I,3813
63
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
64
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=lfYUiem_Pbn3vGgPx84BeI8n7rN3-1fImwCLm8Eo2U8,4853
64
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=UNj7IgyCLOlaOu2xJKHciHqcd-_NXLNwMphCWwuaCes,3830
65
+ ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=u1M38wuev-YG-exMG0_HzAw3yaikPRhEZCqGaHO7Cw0,5559
66
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=xy7kT1f8V0jqAfglggoeKVqR5KviKdfQun0WBYJ0jS8,4869
65
67
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
66
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
67
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
68
+ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=rAsrP0CnhcYYMCcEJAGoFMRMur0LPoAcYpjFlBiZd2s,2564
69
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=__9QjZ-7Xi8MX5M7iRvw_jpic0-Fau2IO0QatleUzL0,5650
68
70
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=aXvYiaHDvETIrh0Q9DDZA_ZBiazGk80DT6nt7lLtC1o,1172
69
71
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=IehLwFNwa0C9fnk1pmNmyfuAwwWbuwdyKy46BSqNVdI,1948
70
72
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -82,7 +84,7 @@ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=H45wsXA6iJi_Mjd66NiQrh7
82
84
  ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
83
85
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=FrIO-CR8aRIV2i8aFqom_4S7WCEDLMyYwo6U0oFyn7A,9097
84
86
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
85
- ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
87
+ ai_edge_torch/generative/quantize/example.py,sha256=zgBgMyZ8RlSIjRhbaeodLyt6sj_dYuM7oI6Zyx7xQIg,1542
86
88
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
87
89
  ai_edge_torch/generative/quantize/quant_recipe.py,sha256=Y8zahKw7b_h7ajPaJZVef4jG-MoqImRCpVSbFtV_i24,5139
88
90
  ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=-vd6Qp0BdXJVKg4f0_hhwbKOi3QPIAPVqyXnJ-ZnISQ,1915
@@ -92,8 +94,8 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DE
92
94
  ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=qUB4f2DoB14dLkNPWf6TZodpT81mfAJeWM-lCAmkuHY,5735
93
95
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
94
96
  ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
95
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
96
- ai_edge_torch/generative/test/test_quantize.py,sha256=IjCbCPWzIgXk3s7y7SJsg2usIxhOqs3PuhFvEYR4Sdw,5388
97
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=tXES8gePl8BptETyUgpIznbIK1SvGEd8mKz1bT_a8Mw,7581
98
+ ai_edge_torch/generative/test/test_quantize.py,sha256=Emp-8oLHyGddELSCkncuSQt8ZJIhZ2-y0-ghR92s10g,5386
97
99
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
98
100
  ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
99
101
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=7ChqrnthD7I-Be6vkRvYTRhbGQ3tqMbikLpjY5HpSzE,30890
@@ -101,7 +103,7 @@ ai_edge_torch/generative/utilities/t5_loader.py,sha256=h1FQzt4x8wiQMX4NzYNVIaJGL
101
103
  ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
102
104
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
103
105
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=YV2YKBkh7y7j7sd7EA81vf_1hUKUvTRiy1pfqZustXc,1539
104
- ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=H4047w-xwx27rYPKNqmeOSQ9M1Adkpd7drp81YdV7Hw,9206
106
+ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=BEMKgkX8IrsX70h2CxwA_tpsBm6qWWe2bOeOufMYNkA,9722
105
107
  ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
106
108
  ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=qYR3PRGS9W3OG-qX7UFqL69VxXuUSfyDBUJtCXtXcOE,4262
107
109
  ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=aUAPKnH4_Jxpp3mLlD5rzdT1g_VBm7OrwwLJ9DeJlzQ,8190
@@ -112,8 +114,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
112
114
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
113
115
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
114
116
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
115
- ai_edge_torch_nightly-0.2.0.dev20240626.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
116
- ai_edge_torch_nightly-0.2.0.dev20240626.dist-info/METADATA,sha256=taaGTe-WdFG7HWUB43xnjYmG_iFcRwaIaZCK91-QP7M,1748
117
- ai_edge_torch_nightly-0.2.0.dev20240626.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
118
- ai_edge_torch_nightly-0.2.0.dev20240626.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
119
- ai_edge_torch_nightly-0.2.0.dev20240626.dist-info/RECORD,,
117
+ ai_edge_torch_nightly-0.2.0.dev20240701.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
118
+ ai_edge_torch_nightly-0.2.0.dev20240701.dist-info/METADATA,sha256=QItv3j92_LbO-UPcMqZHRhwTeSTnVQiy-cam2ciLaGM,1748
119
+ ai_edge_torch_nightly-0.2.0.dev20240701.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
120
+ ai_edge_torch_nightly-0.2.0.dev20240701.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
121
+ ai_edge_torch_nightly-0.2.0.dev20240701.dist-info/RECORD,,