ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240915__py3-none-any.whl

Files changed (50)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
  11. ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
  16. ai_edge_torch/generative/examples/phi/phi2.py +2 -2
  17. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  18. ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
  19. ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
  20. ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
  21. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  22. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  23. ai_edge_torch/generative/examples/t5/t5.py +8 -8
  24. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  25. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
  26. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
  27. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  28. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  29. ai_edge_torch/generative/layers/attention.py +7 -0
  30. ai_edge_torch/generative/layers/builder.py +33 -11
  31. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  32. ai_edge_torch/generative/layers/kv_cache.py +4 -4
  33. ai_edge_torch/generative/layers/model_config.py +24 -15
  34. ai_edge_torch/generative/quantize/example.py +2 -2
  35. ai_edge_torch/generative/test/test_model_conversion.py +28 -51
  36. ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
  37. ai_edge_torch/generative/test/test_quantize.py +5 -5
  38. ai_edge_torch/generative/utilities/loader.py +13 -0
  39. ai_edge_torch/odml_torch/export.py +40 -0
  40. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  41. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  42. ai_edge_torch/version.py +1 -1
  43. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/METADATA +1 -1
  44. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/RECORD +48 -46
  45. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  46. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  47. /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -30,6 +30,7 @@ class ActivationType(enum.Enum):
   GELU_QUICK = enum.auto()
   GE_GLU = enum.auto()
   RELU = enum.auto()
+  SILU_GLU = enum.auto()


 @enum.unique
@@ -58,6 +59,18 @@ class AttentionType(enum.Enum):
   LOCAL_SLIDING = enum.auto()


+@dataclass
+class NormalizationConfig:
+  """Normalizater parameters."""
+
+  type: NormalizationType = NormalizationType.NONE
+  enable_hlfb: bool = False
+  epsilon: float = 1e-5
+  zero_centered: bool = False
+  # Number of groups used in group normalization.
+  group_num: Optional[float] = None
+
+
 @dataclass
 class AttentionConfig:
   """Attention model's parameters."""
@@ -81,6 +94,14 @@ class AttentionConfig:
   # Whether to use bias with attention output projection.
   output_proj_use_bias: bool = False
   enable_kv_cache: bool = True
+  # The normalization applied to query projection's output.
+  query_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
+  # The normalization applied to key projection's output.
+  key_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
   relative_attention_num_buckets: int = 0
   relative_attention_max_distance: int = 0
   # Softcap on the output logits.
@@ -94,21 +115,9 @@ class AttentionConfig:
 @dataclass
 class ActivationConfig:
   type: ActivationType = ActivationType.LINEAR
-  # Dimension of input and output, used in GeGLU.
-  dim_in: Optional[int] = None
-  dim_out: Optional[int] = None
-
-
-@dataclass
-class NormalizationConfig:
-  """Normalizater parameters."""
-
-  type: NormalizationType = NormalizationType.NONE
-  enable_hlfb: bool = False
-  epsilon: float = 1e-5
-  zero_centered: bool = False
-  # Number of groups used in group normalization.
-  group_num: Optional[float] = None
+  # Whether to GLU gate is the front part instead of the back part of input
+  # when ActivationType is `GE_GLU` or `SILU_GLU`.
+  gate_is_front: bool = False


 @dataclass
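
The new `SILU_GLU` activation and the `gate_is_front` flag replace `ActivationConfig`'s old `dim_in`/`dim_out` fields. Below is a minimal sketch of the gated behavior these options describe, assuming the usual GLU convention of splitting the input in half along the last dimension; it is not the library's implementation (which lives in the also-changed feed_forward.py/builder.py and may differ).

```python
import torch
import torch.nn.functional as F


def silu_glu(x: torch.Tensor, gate_is_front: bool = False) -> torch.Tensor:
  """Hedged sketch of a SiLU-gated GLU.

  `gate_is_front` picks whether the first or the second half of the split
  input acts as the gate.
  """
  a, b = x.chunk(2, dim=-1)
  gate, value = (a, b) if gate_is_front else (b, a)
  return F.silu(gate) * value


y = silu_glu(torch.randn(1, 8), gate_is_front=True)  # -> shape (1, 4)
```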
ai_edge_torch/generative/quantize/example.py CHANGED
@@ -25,9 +25,9 @@ def main():
   config = gemma.get_fake_model_config()
   model = gemma.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, 10)
+  input_pos = torch.arange(0, 10, dtype=torch.int)

   # Create a quantization recipe to be applied to the model
   quant_config = quant_recipes.full_int8_dynamic_recipe()
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -42,15 +42,9 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_toy_model_with_kv_cache(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
+  def _test_model_with_kv_cache(self, config, pytorch_model):
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
+        [10], dtype=torch.int
     )
     kv = kv_cache.KVCache.from_model_config(config)

@@ -83,58 +77,32 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_toy_model_with_kv_cache_with_hlfb(self):
+  def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
-    config.enable_hlfb = True
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
-    )
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.convert(
-        pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
-    )
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
-    )
-
-    self.assertTrue(
-        test_utils.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            tokens,
-            input_pos,
-            kv,
-            signature_name="serving_default",
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    self._test_model_with_kv_cache(config, pytorch_model)

   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_tiny_llama_multisig(self):
-    config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLlama(config).eval()
+  def test_toy_model_with_kv_cache_with_hlfb(self):
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = True
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    self._test_model_with_kv_cache(config, pytorch_model)

+  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
+    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
-    prefill_input_pos = torch.arange(0, seq_len)
+    prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)

     # decode
-    decode_token = torch.tensor([[1]], dtype=torch.long)
-    decode_input_pos = torch.tensor([5], dtype=torch.int64)
+    decode_token = torch.tensor([[1]], dtype=torch.int)
+    decode_input_pos = torch.tensor([5], dtype=torch.int)

     kv = kv_cache.KVCache.from_model_config(config)

@@ -171,8 +139,8 @@ class TestModelConversion(googletest.TestCase):
             prefill_input_pos,
             kv,
             signature_name="prefill",
-            atol=1e-5,
-            rtol=1e-5,
+            atol=atol,
+            rtol=atol,
         )
     )

@@ -184,11 +152,20 @@ class TestModelConversion(googletest.TestCase):
             decode_input_pos,
             kv,
             signature_name="decode",
-            atol=1e-5,
-            rtol=1e-5,
+            atol=atol,
+            rtol=atol,
         )
     )

+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_tiny_llama_multisig(self):
+    config = tiny_llama.get_fake_model_config()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
+

 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -19,7 +19,9 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
@@ -43,28 +45,22 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_gemma(self):
-    config = gemma.get_fake_model_config()
-    model = gemma.Gemma(config)
-
+  def _test_model(self, config, model, signature_name, atol, rtol):
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
+    input_pos = torch.arange(0, 10, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)

-    edge_model = ai_edge_torch.convert(
+    edge_model = ai_edge_torch.signature(
+        signature_name,
         model,
         sample_kwargs={
            "tokens": tokens,
            "input_pos": input_pos,
            "kv_cache": kv,
        },
-    )
+    ).convert()
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
@@ -76,9 +72,9 @@ class TestModelConversion(googletest.TestCase):
             tokens,
             input_pos,
             kv,
-            signature_name="serving_default",
-            atol=1e-2,
-            rtol=1e-5,
+            signature_name=signature_name,
+            atol=atol,
+            rtol=rtol,
         )
     )

@@ -86,42 +82,21 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_gemma2(self):
-    config = gemma2.get_fake_model_config()
-    model = gemma2.Gemma2(config)
-    model.eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    prefill_tokens[0, :4] = idx
-    prefill_input_pos = torch.arange(0, 10)
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.signature(
-        "prefill",
-        model,
-        sample_kwargs={
-            "tokens": prefill_tokens,
-            "input_pos": prefill_input_pos,
-            "kv_cache": kv,
-        },
-    ).convert()
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+  def test_gemma(self):
+    config = gemma.get_fake_model_config()
+    pytorch_model = gemma.Gemma(config).eval()
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )

-    self.assertTrue(
-        test_utils.compare_tflite_torch(
-            edge_model,
-            model,
-            prefill_tokens,
-            prefill_input_pos,
-            kv,
-            signature_name="prefill",
-            atol=1e-1,
-            rtol=1e-3,
-        )
-    )
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_gemma2(self):
+    config = gemma2.get_fake_model_config()
+    pytorch_model = gemma2.Gemma2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)

   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -130,37 +105,27 @@ class TestModelConversion(googletest.TestCase):
   def test_phi2(self):
     config = phi2.get_fake_model_config()
     pytorch_model = phi2.Phi2(config).eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.convert(
-        pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
-    )
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )

-    self.assertTrue(
-        test_utils.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            tokens,
-            input_pos,
-            kv,
-            signature_name="serving_default",
-            atol=1e-3,
-            rtol=1e-3,
-        )
-    )
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_smollm(self):
+    config = smollm.get_fake_model_config()
+    pytorch_model = smollm.SmolLM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_openelm(self):
+    config = openelm.get_fake_model_config()
+    pytorch_model = openelm.OpenELM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)


 if __name__ == "__main__":
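
The refactored `_test_model` helper above converts through the named-signature API rather than a bare `ai_edge_torch.convert()` call. A hedged sketch of that pattern outside the test harness, assembled from pieces that appear in these tests (the toy model and dtypes come from the diffs above; the signature name is arbitrary):

```python
import ai_edge_torch
import torch

from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
from ai_edge_torch.generative.layers import kv_cache

# Build the sample inputs the signature will be traced with.
config = toy_model_with_kv_cache.get_model_config()
model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
tokens = torch.tensor([[1]], dtype=torch.int)
input_pos = torch.tensor([10], dtype=torch.int)
kv = kv_cache.KVCache.from_model_config(config)

# Convert under an explicit signature name instead of "serving_default".
edge_model = ai_edge_torch.signature(
    "decode",
    model,
    sample_kwargs={"tokens": tokens, "input_pos": input_pos, "kv_cache": kv},
).convert()
```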
ai_edge_torch/generative/test/test_quantize.py CHANGED
@@ -115,8 +115,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)

     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
@@ -131,8 +131,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_weight_sharing(self):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)

     quant_config = quant_recipes.full_int8_dynamic_recipe()
     quantized_model = ai_edge_torch.convert(
@@ -149,7 +149,7 @@ class TestQuantizeConvert(parameterized.TestCase):
       self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )

ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -101,6 +101,8 @@ class ModelLoader:
     attn_value_proj: str = None
     attn_fused_qkv_proj: str = None
     attn_output_proj: str = None
+    attn_query_norm: str = None
+    attn_key_norm: str = None

     ff_up_proj: str = None
     ff_down_proj: str = None
@@ -323,6 +325,17 @@ class ModelLoader:
         )
     )

+    if self._names.attn_query_norm is not None:
+      attn_query_norm_name = self._names.attn_query_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.query_norm.weight"] = state.pop(
+          f"{attn_query_norm_name}.weight"
+      )
+    if self._names.attn_key_norm is not None:
+      attn_key_norm_name = self._names.attn_key_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.key_norm.weight"] = state.pop(
+          f"{attn_key_norm_name}.weight"
+      )
+
     o_name = self._names.attn_output_proj.format(idx)
     converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
         state.pop(f"{o_name}.weight")
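
A hedged illustration of how a model's tensor-name mapping might populate the two new fields, following the same per-layer `{}`-indexed convention the loader uses for the other attention tensors. The `TensorNames` class path and the checkpoint key patterns below are assumptions for illustration, not taken from this diff:

```python
from ai_edge_torch.generative.utilities import loader as loading_utils

# Hypothetical checkpoint layout; only attn_query_norm / attn_key_norm exercise
# the fields added above. "{}" is filled with the layer index by the loader.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    attn_query_norm="transformer.layers.{}.attn.q_norm",
    attn_key_norm="transformer.layers.{}.attn.k_norm",
)
```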
ai_edge_torch/odml_torch/export.py CHANGED
@@ -223,6 +223,41 @@ class MlirLowered:
     return tf_integration.mlir_to_flatbuffer(self)


+# TODO(b/331481564) Make this a ai_edge_torch FX pass.
+def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
+  """Convert internal constant aten ops' output from int64 to int32.
+
+  Int32 generally has better performance and compatibility than int64 in
+  runtime. This pass converts aten op where the output(s) are int64 constant
+  tensors to return int32 constant tensors.
+
+  Args:
+    exported_program: The exported program to apply the pass.
+  """
+
+  def in_i32(x: int):
+    return -2147483648 <= x <= 2147483647
+
+  def rewrite_arange(node: torch.fx.Node):
+    tensor_meta = node.meta.get("tensor_meta", None)
+    if not tensor_meta:
+      return
+
+    start, end = node.args[:2]
+    if tensor_meta.dtype != torch.int64:
+      return
+    if not (in_i32(start) and in_i32(end)):
+      return
+    op = node.target
+    node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
+
+  graph_module = exported_program.graph_module
+  for node in graph_module.graph.nodes:
+
+    if node.target == torch.ops.aten.arange.start_step:
+      rewrite_arange(node)
+
+
 def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
@@ -231,6 +266,11 @@ def exported_program_to_mlir(
       lowerings.decompositions()
   )

+  _convert_i64_to_i32(exported_program)
+  exported_program = exported_program.run_decompositions(
+      lowerings.decompositions()
+  )
+
   with export_utils.create_ir_context() as context, ir.Location.unknown():

     module = ir.Module.create()
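
For context, a hedged example of the pattern `_convert_i64_to_i32` targets: `torch.arange` with Python int bounds produces an int64 constant tensor by default, and the pass retargets the node to int32 when the bounds fit in 32 bits. The module and shapes below are illustrative only:

```python
import torch


class Arange(torch.nn.Module):

  def forward(self, x):
    # Defaults to an int64 constant; the pass above rewrites the underlying
    # aten.arange.start_step node so it returns int32 instead.
    return x + torch.arange(0, 10, 1)


ep = torch.export.export(Arange().eval(), (torch.zeros(10, dtype=torch.int64),))
# Within exported_program_to_mlir, _convert_i64_to_i32(ep) retargets the arange
# node before the program is lowered to MLIR.
```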
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -202,3 +202,47 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
   x, y = utils.broadcast_args_if_needed(x, y)

   return stablehlo.divide(x, y)
+
+
+# Schema:
+#   - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
+#       start=None, SymInt? end=None, SymInt step=1) -> Tensor
+# Torch Reference:
+#   - https://pytorch.org/docs/stable/generated/torch.slice_scatter.html
+#   - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
+@lower(torch.ops.aten.slice_scatter)
+def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
+  start = start or 0
+  end = end or self.type.shape[dim]
+  if start < 0:
+    start = self.type.shape[dim] + start
+  if end < 0:
+    end = self.type.shape[dim] + end
+
+  end = start + step * math.ceil((end - start) / step) - (step - 1)
+
+  padding_low = start
+  padding_high = self.type.shape[dim] - end
+
+  rank = len(self.type.shape)
+  src = stablehlo.pad(
+      src,
+      utils.splat(0, src.type.element_type, []),
+      edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
+      edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
+      interior_padding=[step - 1 if i == dim else 0 for i in range(rank)],
+  )
+  pred = np.ones(self.type.shape, dtype=np.bool_)
+  pred[*[
+      slice(start, end, step) if i == dim else slice(None, None, None)
+      for i in range(rank)
+  ]] = False
+  pred = stablehlo.constant(
+      ir.DenseElementsAttr.get(
+          np.packbits(pred, bitorder="little"),
+          type=ir.IntegerType.get_signless(1),
+          shape=pred.shape,
+      )
+  )
+  out = stablehlo.select(pred, self, src)
+  return out
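
For reference, the PyTorch-level semantics this new StableHLO lowering reproduces: `src` values land at the sliced positions of `self` (here via interior-padded `src` and a select mask), and every other element of `self` is kept. A small runnable example of the op itself:

```python
import torch

base = torch.zeros(8)
src = torch.ones(2)
# Scatter src into positions 2 and 4 of base (start=2, end=6, step=2).
out = torch.slice_scatter(base, src, dim=0, start=2, end=6, step=2)
# out: tensor([0., 0., 1., 0., 1., 0., 0., 0.])
```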
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -203,7 +203,6 @@ lower_by_torch_xla2(torch.ops.aten.sin)
 lower_by_torch_xla2(torch.ops.aten.sinh)
 lower_by_torch_xla2(torch.ops.aten.slice)
 lower_by_torch_xla2(torch.ops.aten.slice_copy)
-lower_by_torch_xla2(torch.ops.aten.slice_scatter)
 lower_by_torch_xla2(torch.ops.aten.sort)
 lower_by_torch_xla2(torch.ops.aten.split)
 lower_by_torch_xla2(torch.ops.aten.split_copy)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================

-__version__ = "0.3.0.dev20240913"
+__version__ = "0.3.0.dev20240915"
{ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240913
+Version: 0.3.0.dev20240915
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI