ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
  11. ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
  16. ai_edge_torch/generative/examples/phi/phi2.py +2 -2
  17. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  18. ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
  19. ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
  20. ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
  21. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  22. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  23. ai_edge_torch/generative/examples/t5/t5.py +8 -8
  24. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  25. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
  26. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
  27. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  28. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  29. ai_edge_torch/generative/layers/attention.py +7 -0
  30. ai_edge_torch/generative/layers/builder.py +33 -11
  31. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  32. ai_edge_torch/generative/layers/kv_cache.py +4 -4
  33. ai_edge_torch/generative/layers/model_config.py +24 -15
  34. ai_edge_torch/generative/quantize/example.py +2 -2
  35. ai_edge_torch/generative/test/test_model_conversion.py +28 -51
  36. ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
  37. ai_edge_torch/generative/test/test_quantize.py +5 -5
  38. ai_edge_torch/generative/utilities/loader.py +13 -0
  39. ai_edge_torch/odml_torch/export.py +40 -0
  40. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  41. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  42. ai_edge_torch/version.py +1 -1
  43. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
  44. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +48 -46
  45. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  46. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  47. /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -30,6 +30,7 @@ class ActivationType(enum.Enum):
  GELU_QUICK = enum.auto()
  GE_GLU = enum.auto()
  RELU = enum.auto()
+ SILU_GLU = enum.auto()


  @enum.unique
@@ -58,6 +59,18 @@ class AttentionType(enum.Enum):
  LOCAL_SLIDING = enum.auto()


+ @dataclass
+ class NormalizationConfig:
+ """Normalizater parameters."""
+
+ type: NormalizationType = NormalizationType.NONE
+ enable_hlfb: bool = False
+ epsilon: float = 1e-5
+ zero_centered: bool = False
+ # Number of groups used in group normalization.
+ group_num: Optional[float] = None
+
+
  @dataclass
  class AttentionConfig:
  """Attention model's parameters."""
@@ -81,6 +94,14 @@ class AttentionConfig:
  # Whether to use bias with attention output projection.
  output_proj_use_bias: bool = False
  enable_kv_cache: bool = True
+ # The normalization applied to query projection's output.
+ query_norm_config: NormalizationConfig = field(
+ default_factory=NormalizationConfig
+ )
+ # The normalization applied to key projection's output.
+ key_norm_config: NormalizationConfig = field(
+ default_factory=NormalizationConfig
+ )
  relative_attention_num_buckets: int = 0
  relative_attention_max_distance: int = 0
  # Softcap on the output logits.
@@ -94,21 +115,9 @@ class AttentionConfig:
  @dataclass
  class ActivationConfig:
  type: ActivationType = ActivationType.LINEAR
- # Dimension of input and output, used in GeGLU.
- dim_in: Optional[int] = None
- dim_out: Optional[int] = None
-
-
- @dataclass
- class NormalizationConfig:
- """Normalizater parameters."""
-
- type: NormalizationType = NormalizationType.NONE
- enable_hlfb: bool = False
- epsilon: float = 1e-5
- zero_centered: bool = False
- # Number of groups used in group normalization.
- group_num: Optional[float] = None
+ # Whether to GLU gate is the front part instead of the back part of input
+ # when ActivationType is `GE_GLU` or `SILU_GLU`.
+ gate_is_front: bool = False


  @dataclass
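
Note: the new SILU_GLU activation and the gate_is_front flag replace the GeGLU-specific dim_in/dim_out fields above; the flag only selects which half of the feed-forward projection acts as the GLU gate. A minimal, hypothetical sketch of that split (not the ai_edge_torch implementation):

import torch
import torch.nn.functional as F

def silu_glu(x: torch.Tensor, gate_is_front: bool = False) -> torch.Tensor:
  # Split the last dimension in half; one half gates the other.
  a, b = x.chunk(2, dim=-1)
  # gate_is_front=True: the first half is the gate; otherwise the second half.
  return F.silu(a) * b if gate_is_front else F.silu(b) * a
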
ai_edge_torch/generative/quantize/example.py CHANGED
@@ -25,9 +25,9 @@ def main():
  config = gemma.get_fake_model_config()
  model = gemma.Gemma(config)
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
- tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+ tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
  tokens[0, :4] = idx
- input_pos = torch.arange(0, 10)
+ input_pos = torch.arange(0, 10, dtype=torch.int)

  # Create a quantization recipe to be applied to the model
  quant_config = quant_recipes.full_int8_dynamic_recipe()
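
Note: the recurring torch.long → torch.int change in the examples and tests pins sample token and position tensors to int32, which lines up with the int64-to-int32 handling added in odml_torch/export.py further down. For context, torch.int is int32 and integer torch.arange defaults to int64, so the dtype must be passed explicitly:

import torch

assert torch.arange(0, 10).dtype == torch.int64  # default is int64 ("long")
assert torch.arange(0, 10, dtype=torch.int).dtype == torch.int32
assert torch.full((1, 10), 0, dtype=torch.int).dtype == torch.int32
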
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -42,15 +42,9 @@ class TestModelConversion(googletest.TestCase):
  )
  )

- @googletest.skipIf(
- ai_edge_config.Config.use_torch_xla,
- reason="tests with custom ops are not supported on oss",
- )
- def test_toy_model_with_kv_cache(self):
- config = toy_model_with_kv_cache.get_model_config()
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
- tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
- [10], dtype=torch.int64
+ def _test_model_with_kv_cache(self, config, pytorch_model):
+ tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
+ [10], dtype=torch.int
  )
  kv = kv_cache.KVCache.from_model_config(config)

@@ -83,58 +77,32 @@ class TestModelConversion(googletest.TestCase):
  ai_edge_config.Config.use_torch_xla,
  reason="tests with custom ops are not supported on oss",
  )
- def test_toy_model_with_kv_cache_with_hlfb(self):
+ def test_toy_model_with_kv_cache(self):
  config = toy_model_with_kv_cache.get_model_config()
- config.enable_hlfb = True
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
- tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
- [10], dtype=torch.int64
- )
- kv = kv_cache.KVCache.from_model_config(config)
-
- edge_model = ai_edge_torch.convert(
- pytorch_model,
- sample_kwargs={
- "tokens": tokens,
- "input_pos": input_pos,
- "kv_cache": kv,
- },
- )
- edge_model.set_interpreter_builder(
- self._interpreter_builder(edge_model.tflite_model())
- )
-
- self.assertTrue(
- test_utils.compare_tflite_torch(
- edge_model,
- pytorch_model,
- tokens,
- input_pos,
- kv,
- signature_name="serving_default",
- atol=1e-5,
- rtol=1e-5,
- )
- )
+ self._test_model_with_kv_cache(config, pytorch_model)

  @googletest.skipIf(
  ai_edge_config.Config.use_torch_xla,
  reason="tests with custom ops are not supported on oss",
  )
- def test_tiny_llama_multisig(self):
- config = tiny_llama.get_fake_model_config()
- pytorch_model = tiny_llama.TinyLlama(config).eval()
+ def test_toy_model_with_kv_cache_with_hlfb(self):
+ config = toy_model_with_kv_cache.get_model_config()
+ config.enable_hlfb = True
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+ self._test_model_with_kv_cache(config, pytorch_model)

+ def _test_multisig_model(self, config, pytorch_model, atol, rtol):
  # prefill
  seq_len = 10
- prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
+ prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
  prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
  prefill_tokens[0, : len(prompt_token)] = prompt_token
- prefill_input_pos = torch.arange(0, seq_len)
+ prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)

  # decode
- decode_token = torch.tensor([[1]], dtype=torch.long)
- decode_input_pos = torch.tensor([5], dtype=torch.int64)
+ decode_token = torch.tensor([[1]], dtype=torch.int)
+ decode_input_pos = torch.tensor([5], dtype=torch.int)

  kv = kv_cache.KVCache.from_model_config(config)

@@ -171,8 +139,8 @@ class TestModelConversion(googletest.TestCase):
  prefill_input_pos,
  kv,
  signature_name="prefill",
- atol=1e-5,
- rtol=1e-5,
+ atol=atol,
+ rtol=atol,
  )
  )

@@ -184,11 +152,20 @@ class TestModelConversion(googletest.TestCase):
  decode_input_pos,
  kv,
  signature_name="decode",
- atol=1e-5,
- rtol=1e-5,
+ atol=atol,
+ rtol=atol,
  )
  )

+ @googletest.skipIf(
+ ai_edge_config.Config.use_torch_xla,
+ reason="tests with custom ops are not supported on oss",
+ )
+ def test_tiny_llama_multisig(self):
+ config = tiny_llama.get_fake_model_config()
+ pytorch_model = tiny_llama.TinyLlama(config).eval()
+ self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
+

  if __name__ == "__main__":
  googletest.main()
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -19,7 +19,9 @@ import ai_edge_torch
  from ai_edge_torch import config as ai_edge_config
  from ai_edge_torch.generative.examples.gemma import gemma
  from ai_edge_torch.generative.examples.gemma import gemma2
+ from ai_edge_torch.generative.examples.openelm import openelm
  from ai_edge_torch.generative.examples.phi import phi2
+ from ai_edge_torch.generative.examples.smollm import smollm
  from ai_edge_torch.generative.layers import kv_cache
  from ai_edge_torch.generative.test import utils as test_utils
  import numpy as np
@@ -43,28 +45,22 @@ class TestModelConversion(googletest.TestCase):
  )
  )

- @googletest.skipIf(
- ai_edge_config.Config.use_torch_xla,
- reason="tests with custom ops are not supported on oss",
- )
- def test_gemma(self):
- config = gemma.get_fake_model_config()
- model = gemma.Gemma(config)
-
+ def _test_model(self, config, model, signature_name, atol, rtol):
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
- tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+ tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
  tokens[0, :4] = idx
- input_pos = torch.arange(0, 10)
+ input_pos = torch.arange(0, 10, dtype=torch.int)
  kv = kv_cache.KVCache.from_model_config(config)

- edge_model = ai_edge_torch.convert(
+ edge_model = ai_edge_torch.signature(
+ signature_name,
  model,
  sample_kwargs={
  "tokens": tokens,
  "input_pos": input_pos,
  "kv_cache": kv,
  },
- )
+ ).convert()
  edge_model.set_interpreter_builder(
  self._interpreter_builder(edge_model.tflite_model())
  )
@@ -76,9 +72,9 @@ class TestModelConversion(googletest.TestCase):
  tokens,
  input_pos,
  kv,
- signature_name="serving_default",
- atol=1e-2,
- rtol=1e-5,
+ signature_name=signature_name,
+ atol=atol,
+ rtol=rtol,
  )
  )

@@ -86,42 +82,21 @@ class TestModelConversion(googletest.TestCase):
  ai_edge_config.Config.use_torch_xla,
  reason="tests with custom ops are not supported on oss",
  )
- def test_gemma2(self):
- config = gemma2.get_fake_model_config()
- model = gemma2.Gemma2(config)
- model.eval()
-
- idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
- prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
- prefill_tokens[0, :4] = idx
- prefill_input_pos = torch.arange(0, 10)
- kv = kv_cache.KVCache.from_model_config(config)
-
- edge_model = ai_edge_torch.signature(
- "prefill",
- model,
- sample_kwargs={
- "tokens": prefill_tokens,
- "input_pos": prefill_input_pos,
- "kv_cache": kv,
- },
- ).convert()
- edge_model.set_interpreter_builder(
- self._interpreter_builder(edge_model.tflite_model())
+ def test_gemma(self):
+ config = gemma.get_fake_model_config()
+ pytorch_model = gemma.Gemma(config).eval()
+ self._test_model(
+ config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
  )

- self.assertTrue(
- test_utils.compare_tflite_torch(
- edge_model,
- model,
- prefill_tokens,
- prefill_input_pos,
- kv,
- signature_name="prefill",
- atol=1e-1,
- rtol=1e-3,
- )
- )
+ @googletest.skipIf(
+ ai_edge_config.Config.use_torch_xla,
+ reason="tests with custom ops are not supported on oss",
+ )
+ def test_gemma2(self):
+ config = gemma2.get_fake_model_config()
+ pytorch_model = gemma2.Gemma2(config).eval()
+ self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)

  @googletest.skipIf(
  ai_edge_config.Config.use_torch_xla,
@@ -130,37 +105,27 @@ class TestModelConversion(googletest.TestCase):
  def test_phi2(self):
  config = phi2.get_fake_model_config()
  pytorch_model = phi2.Phi2(config).eval()
-
- idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
- tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
- tokens[0, :4] = idx
- input_pos = torch.arange(0, 10)
- kv = kv_cache.KVCache.from_model_config(config)
-
- edge_model = ai_edge_torch.convert(
- pytorch_model,
- sample_kwargs={
- "tokens": tokens,
- "input_pos": input_pos,
- "kv_cache": kv,
- },
- )
- edge_model.set_interpreter_builder(
- self._interpreter_builder(edge_model.tflite_model())
+ self._test_model(
+ config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
  )

- self.assertTrue(
- test_utils.compare_tflite_torch(
- edge_model,
- pytorch_model,
- tokens,
- input_pos,
- kv,
- signature_name="serving_default",
- atol=1e-3,
- rtol=1e-3,
- )
- )
+ @googletest.skipIf(
+ ai_edge_config.Config.use_torch_xla,
+ reason="tests with custom ops are not supported on oss",
+ )
+ def test_smollm(self):
+ config = smollm.get_fake_model_config()
+ pytorch_model = smollm.SmolLM(config).eval()
+ self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
+ @googletest.skipIf(
+ ai_edge_config.Config.use_torch_xla,
+ reason="tests with custom ops are not supported on oss",
+ )
+ def test_openelm(self):
+ config = openelm.get_fake_model_config()
+ pytorch_model = openelm.OpenELM(config).eval()
+ self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)


  if __name__ == "__main__":
ai_edge_torch/generative/test/test_quantize.py CHANGED
@@ -115,8 +115,8 @@ class TestQuantizeConvert(parameterized.TestCase):
  def test_quantize_convert_toy_sizes(self, quant_config):
  config = toy_model.get_model_config()
  pytorch_model = toy_model.ToySingleLayerModel(config)
- idx = torch.unsqueeze(torch.arange(0, 100), 0)
- input_pos = torch.arange(0, 100)
+ idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+ input_pos = torch.arange(0, 100, dtype=torch.int)

  quantized_model = ai_edge_torch.convert(
  pytorch_model, (idx, input_pos), quant_config=quant_config
@@ -131,8 +131,8 @@ class TestQuantizeConvert(parameterized.TestCase):
  def test_quantize_convert_toy_weight_sharing(self):
  config = toy_model.get_model_config()
  pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
- idx = torch.unsqueeze(torch.arange(0, 100), 0)
- input_pos = torch.arange(0, 100)
+ idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+ input_pos = torch.arange(0, 100, dtype=torch.int)

  quant_config = quant_recipes.full_int8_dynamic_recipe()
  quantized_model = ai_edge_torch.convert(
@@ -149,7 +149,7 @@ class TestQuantizeConvert(parameterized.TestCase):
  self.skipTest("b/338288901")
  config = toy_model_with_kv_cache.get_model_config()
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+ idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
  [10], dtype=torch.int64
  )

ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -101,6 +101,8 @@ class ModelLoader:
  attn_value_proj: str = None
  attn_fused_qkv_proj: str = None
  attn_output_proj: str = None
+ attn_query_norm: str = None
+ attn_key_norm: str = None

  ff_up_proj: str = None
  ff_down_proj: str = None
@@ -323,6 +325,17 @@ class ModelLoader:
  )
  )

+ if self._names.attn_query_norm is not None:
+ attn_query_norm_name = self._names.attn_query_norm.format(idx)
+ converted_state[f"{prefix}.atten_func.query_norm.weight"] = state.pop(
+ f"{attn_query_norm_name}.weight"
+ )
+ if self._names.attn_key_norm is not None:
+ attn_key_norm_name = self._names.attn_key_norm.format(idx)
+ converted_state[f"{prefix}.atten_func.key_norm.weight"] = state.pop(
+ f"{attn_key_norm_name}.weight"
+ )
+
  o_name = self._names.attn_output_proj.format(idx)
  converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
  state.pop(f"{o_name}.weight")
ai_edge_torch/odml_torch/export.py CHANGED
@@ -223,6 +223,41 @@ class MlirLowered:
  return tf_integration.mlir_to_flatbuffer(self)


+ # TODO(b/331481564) Make this a ai_edge_torch FX pass.
+ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
+ """Convert internal constant aten ops' output from int64 to int32.
+
+ Int32 generally has better performance and compatibility than int64 in
+ runtime. This pass converts aten op where the output(s) are int64 constant
+ tensors to return int32 constant tensors.
+
+ Args:
+ exported_program: The exported program to apply the pass.
+ """
+
+ def in_i32(x: int):
+ return -2147483648 <= x <= 2147483647
+
+ def rewrite_arange(node: torch.fx.Node):
+ tensor_meta = node.meta.get("tensor_meta", None)
+ if not tensor_meta:
+ return
+
+ start, end = node.args[:2]
+ if tensor_meta.dtype != torch.int64:
+ return
+ if not (in_i32(start) and in_i32(end)):
+ return
+ op = node.target
+ node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
+
+ graph_module = exported_program.graph_module
+ for node in graph_module.graph.nodes:
+
+ if node.target == torch.ops.aten.arange.start_step:
+ rewrite_arange(node)
+
+
  def exported_program_to_mlir(
  exported_program: torch.export.ExportedProgram,
  ) -> MlirLowered:
@@ -231,6 +266,11 @@ def exported_program_to_mlir(
  lowerings.decompositions()
  )

+ _convert_i64_to_i32(exported_program)
+ exported_program = exported_program.run_decompositions(
+ lowerings.decompositions()
+ )
+
  with export_utils.create_ir_context() as context, ir.Location.unknown():

  module = ir.Module.create()
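
Note: _convert_i64_to_i32 does not rewrite graph structure; it retargets matching aten.arange.start_step nodes so the original op still runs and its result is cast to int32 when the graph is later interpreted or re-traced. A standalone sketch of that retargeting trick on a hand-built FX graph, not the ai_edge_torch pass itself:

import torch
from torch import fx

# One-node graph calling aten.arange.start_step(0, 10, 1).
g = fx.Graph()
node = g.call_function(torch.ops.aten.arange.start_step, args=(0, 10, 1))
g.output(node)
gm = fx.GraphModule(torch.nn.Module(), g)
assert fx.Interpreter(gm).run().dtype == torch.int64  # arange defaults to int64

# Wrap the original op so its output is cast to int32, as the pass does.
op = node.target
node.target = lambda *args, **kwargs: op(*args, **kwargs).to(torch.int32)
assert fx.Interpreter(gm).run().dtype == torch.int32
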
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -202,3 +202,47 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
  x, y = utils.broadcast_args_if_needed(x, y)

  return stablehlo.divide(x, y)
+
+
+ # Schema:
+ # - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
+ # start=None, SymInt? end=None, SymInt step=1) -> Tensor
+ # Torch Reference:
+ # - https://pytorch.org/docs/stable/generated/torch.slice_scatter.html
+ # - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
+ @lower(torch.ops.aten.slice_scatter)
+ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
+ start = start or 0
+ end = end or self.type.shape[dim]
+ if start < 0:
+ start = self.type.shape[dim] + start
+ if end < 0:
+ end = self.type.shape[dim] + end
+
+ end = start + step * math.ceil((end - start) / step) - (step - 1)
+
+ padding_low = start
+ padding_high = self.type.shape[dim] - end
+
+ rank = len(self.type.shape)
+ src = stablehlo.pad(
+ src,
+ utils.splat(0, src.type.element_type, []),
+ edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
+ edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
+ interior_padding=[step - 1 if i == dim else 0 for i in range(rank)],
+ )
+ pred = np.ones(self.type.shape, dtype=np.bool_)
+ pred[*[
+ slice(start, end, step) if i == dim else slice(None, None, None)
+ for i in range(rank)
+ ]] = False
+ pred = stablehlo.constant(
+ ir.DenseElementsAttr.get(
+ np.packbits(pred, bitorder="little"),
+ type=ir.IntegerType.get_signless(1),
+ shape=pred.shape,
+ )
+ )
+ out = stablehlo.select(pred, self, src)
+ return out
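
Note: as a reference for the lowering above, torch.slice_scatter embeds src into self at the positions selected by the slice, which is what the pad-plus-select construction reproduces. A small numeric check in plain PyTorch, independent of the StableHLO lowering:

import torch

base = torch.zeros(8)
src = torch.ones(2)
# slice(2, 6, 2) selects indices 2 and 4, so src lands there.
out = torch.slice_scatter(base, src, dim=0, start=2, end=6, step=2)
assert out.tolist() == [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]
# Matching the lowering's arithmetic for this case: padding_low=2,
# adjusted end = 2 + 2*ceil(4/2) - 1 = 5, padding_high=3, interior_padding=1.
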
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -203,7 +203,6 @@ lower_by_torch_xla2(torch.ops.aten.sin)
  lower_by_torch_xla2(torch.ops.aten.sinh)
  lower_by_torch_xla2(torch.ops.aten.slice)
  lower_by_torch_xla2(torch.ops.aten.slice_copy)
- lower_by_torch_xla2(torch.ops.aten.slice_scatter)
  lower_by_torch_xla2(torch.ops.aten.sort)
  lower_by_torch_xla2(torch.ops.aten.split)
  lower_by_torch_xla2(torch.ops.aten.split_copy)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.3.0.dev20240913"
+ __version__ = "0.3.0.dev20240914"
{ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20240913
+ Version: 0.3.0.dev20240914
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI