ai-edge-torch-nightly 0.3.0.dev20250123__py3-none-any.whl → 0.3.0.dev20250125__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. ai_edge_torch/_config.py +9 -0
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +11 -8
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +22 -24
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -4
  5. ai_edge_torch/generative/examples/deepseek/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +80 -0
  7. ai_edge_torch/generative/examples/deepseek/deepseek.py +92 -0
  8. ai_edge_torch/generative/examples/deepseek/verify.py +70 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +3 -0
  10. ai_edge_torch/generative/layers/experimental/__init__.py +14 -0
  11. ai_edge_torch/generative/layers/experimental/attention.py +269 -0
  12. ai_edge_torch/generative/layers/experimental/kv_cache.py +314 -0
  13. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +97 -0
  14. ai_edge_torch/generative/layers/experimental/types.py +97 -0
  15. ai_edge_torch/generative/layers/kv_cache.py +2 -1
  16. ai_edge_torch/generative/layers/model_config.py +5 -1
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +11 -2
  18. ai_edge_torch/generative/utilities/bmm_4d.py +76 -0
  19. ai_edge_torch/generative/utilities/converter.py +18 -2
  20. ai_edge_torch/generative/utilities/model_builder.py +6 -1
  21. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -1
  22. ai_edge_torch/quantize/pt2e_quantizer_utils.py +22 -2
  23. ai_edge_torch/version.py +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA +1 -1
  25. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD +28 -18
  26. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/LICENSE +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/WHEEL +0 -0
  28. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/top_level.txt +0 -0
ai_edge_torch/_config.py CHANGED
@@ -65,5 +65,14 @@ class _Config:
   def enable_group_norm_composite(self, value: bool):
     os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
 
+  @property
+  def layout_optimize_partitioner(self) -> str:
+    """The algorithm to use for layout optimization."""
+    return os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", "DEFAULT")
+
+  @layout_optimize_partitioner.setter
+  def layout_optimize_partitioner(self, value: str):
+    os.environ["AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER"] = str(value).upper()
+
 
 config = _Config()
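
The new property is a thin wrapper over the same environment variable the pass previously read directly (see pass_body.py below): reads fall back to "DEFAULT", and the setter upper-cases the value before storing it. A minimal usage sketch:

import os
import ai_edge_torch

# The setter normalizes to upper case, so lower-case input is fine.
ai_edge_torch.config.layout_optimize_partitioner = "mincut"
assert os.environ["AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER"] == "MINCUT"
assert ai_edge_torch.config.layout_optimize_partitioner == "MINCUT"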
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py CHANGED
@@ -201,8 +201,14 @@ def _aten_group_norm_checker(node):
   return NHWCable(can_be=can_be, must_be=must_be)
 
 
-@nhwcable_node_checkers.register(aten.native_group_norm)
+@nhwcable_node_checkers.register(aten.native_group_norm.default)
 def _aten_native_group_norm_checker(node):
+  # aten.group_norm is removed from the decomp table, so aten.native_group_norm
+  # should never exist in the graph. However, torch 2.5.1 could ignore the
+  # decomp table updates, so still add this native_group_norm checker and
+  # rewriter to be safe.
+  # The checker and rewriter are the same as the ones for aten.group_norm.
+
   val = node.meta.get("val")
   if (
       not isinstance(val, (list, tuple))
@@ -210,13 +216,10 @@ def _aten_native_group_norm_checker(node):
       or not hasattr(val[0], "shape")
   ):
     return NHWCable(can_be=False, must_be=False)
-  if len(node.args) >= 3 and (
-      node.args[1] is not None or node.args[2] is not None
-  ):
-    # Disable NHWC rewriter due to precision issue with weight and bias.
-    # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
-    return NHWCable(can_be=False, must_be=False)
-  return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
+
+  can_be = len(val[0].shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
 
 
 # ==== Ops must be NCHW
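
For readers unfamiliar with this pass: each checker in the registry returns NHWCable(can_be, must_be), which tells the layout partitioner whether the node's output may, or must, be rewritten to NHWC. A minimal sketch of the registration pattern, using aten.relu purely for illustration (it is not part of this diff):

@nhwcable_node_checkers.register(aten.relu.default)
def _aten_relu_checker(node):
  val = node.meta.get("val")
  # Only 4D tensors are candidates for the NCHW -> NHWC rewrite.
  can_be = hasattr(val, "shape") and len(val.shape) == 4
  return NHWCable(can_be=can_be, must_be=False)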
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py CHANGED
@@ -391,34 +391,32 @@ def _aten_native_group_norm(node):
       eps: float,
       **kwargs,
   ):
-    input_reshaped = torch.reshape(
-        input,
-        [
-            batch_size,
-            flattened_inner_size,
-            num_groups,
-            num_channels // num_groups,
-        ],
-    )
-    reduction_dims = [1, 3]
-
-    biased_var, mean = torch.var_mean(
-        input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
     )
-    rstd = torch.rsqrt(biased_var + eps)
-
-    out = (input_reshaped - mean) * rstd
-    out = torch.reshape(out, input.shape)
 
-    if weight is not None:
-      out = out * weight
-    if bias is not None:
-      out = out + bias
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
 
-    mean = torch.squeeze(mean, reduction_dims)
-    rstd = torch.squeeze(rstd, reduction_dims)
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
 
-    return out, mean, rstd
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return (output, None, None)
 
 
   node.target = native_group_norm
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py CHANGED
@@ -18,6 +18,7 @@ import operator
 import os
 from typing import Union
 
+import ai_edge_torch
 from ai_edge_torch import fx_infra
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
@@ -261,10 +262,8 @@ class OptimizeLayoutTransposesPass(fx_infra.ExportedProgramPassBase):
     self.mark_const_nodes(exported_program)
 
     graph_module = exported_program.graph_module
-    partitioner = os.environ.get(
-        "AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None
-    )
-    if partitioner == "MINCUT":
+    partitioner = ai_edge_torch.config.layout_optimize_partitioner
+    if partitioner in ("MINCUT", "OPTIMAL"):
       graph_module = layout_partitioners.min_cut.partition(graph_module)
     elif partitioner == "GREEDY":
       graph_module = layout_partitioners.greedy.partition(graph_module)
ai_edge_torch/generative/examples/deepseek/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py ADDED
@@ -0,0 +1,80 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting DeepSeek R1 distilled models to tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/deepseek'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
+    '/tmp/',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'deepseek',
+    'The prefix of the output tflite model name.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+
+
+def main(_):
+  pytorch_model = deepseek.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
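
The script is flag-driven, but the same conversion can be run programmatically. A minimal sketch with a placeholder checkpoint path and a reduced set of prefill lengths:

from ai_edge_torch.generative.examples.deepseek import deepseek
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

model = deepseek.build_model('/path/to/deepseek', kv_cache_max_len=1280)
converter.convert_to_tflite(
    model,
    output_path='/tmp/',
    output_name_prefix='deepseek',
    prefill_seq_len=[64, 256],
    quantize=True,
    export_config=ExportConfig(),
)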
ai_edge_torch/generative/examples/deepseek/deepseek.py ADDED
@@ -0,0 +1,92 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building DeepSeek R1 distilled models."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
+
+TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
+
+
+class DeepSeekDistillQwen(model_builder.DecoderOnlyModel):
+  """A DeepSeek distilled model based on Qwen."""
+  pass
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a DeepSeek R1 distilled Qwen 1.5B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache.
+      Default is 1024.
+
+  Returns:
+    The model config for a DeepSeek R1 distilled model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=12,
+      head_dim=128,
+      num_query_groups=2,
+      rotary_base=10000,
+      rotary_percentage=1.0,
+      qkv_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=8960,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-06,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=151936,
+      num_layers=28,
+      max_seq_len=4096,
+      embedding_dim=1536,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      lm_head_share_weight_with_embedding=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # DeepSeek-R1-Distill-Qwen has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=DeepSeekDistillQwen,
+  )
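
get_fake_model_config shrinks the vocabulary, layer count, and feed-forward width so tests can exercise the architecture without real weights. A sketch of how a test might use it, assuming DecoderOnlyModel can be constructed directly from a ModelConfig (which is what build_decoder_only_model does internally):

from ai_edge_torch.generative.examples.deepseek import deepseek
from ai_edge_torch.generative.utilities import model_builder

config = deepseek.get_fake_model_config(kv_cache_max_len=128)
model = model_builder.DecoderOnlyModel(config)  # tiny, randomly initialized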
ai_edge_torch/generative/examples/deepseek/verify.py ADDED
@@ -0,0 +1,70 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored DeepSeek R1 distilled 1.5B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = deepseek.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py CHANGED
@@ -85,6 +85,7 @@ def convert_stable_diffusion_to_tflite(
       clip.TENSOR_NAMES,
   )
   loader.load(clip_model, strict=False)
+  clip_model.eval()
 
   diffusion_model = diffusion.Diffusion(
       diffusion.get_model_config(batch_size=2, device_type=_DEVICE_TYPE.value)
@@ -93,6 +94,7 @@ def convert_stable_diffusion_to_tflite(
       diffusion_ckpt_path, diffusion.TENSOR_NAMES
   )
   diffusion_loader.load(diffusion_model, strict=False)
+  diffusion_model.eval()
 
   decoder_model = decoder.Decoder(
       decoder.get_model_config(device_type=_DEVICE_TYPE.value)
@@ -101,6 +103,7 @@ def convert_stable_diffusion_to_tflite(
       decoder_ckpt_path, decoder.TENSOR_NAMES
   )
   decoder_loader.load(decoder_model, strict=False)
+  decoder_model.eval()
 
   # TODO(yichunk): enable image encoder conversion
   # if encoder_ckpt_path is not None:
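
The three added .eval() calls put the loaded modules into inference mode before conversion, so training-only behavior such as dropout does not leak into the exported graph. A generic illustration of the difference:

import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

drop.train()
print(drop(x))  # Elements randomly zeroed, survivors scaled by 1 / (1 - p).

drop.eval()
print(drop(x))  # Identity: tensor([1., 1., 1., 1.])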
ai_edge_torch/generative/layers/experimental/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================