ai-edge-torch-nightly 0.3.0.dev20250123__py3-none-any.whl → 0.3.0.dev20250125__py3-none-any.whl

Files changed (28)
  1. ai_edge_torch/_config.py +9 -0
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +11 -8
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +22 -24
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -4
  5. ai_edge_torch/generative/examples/deepseek/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +80 -0
  7. ai_edge_torch/generative/examples/deepseek/deepseek.py +92 -0
  8. ai_edge_torch/generative/examples/deepseek/verify.py +70 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +3 -0
  10. ai_edge_torch/generative/layers/experimental/__init__.py +14 -0
  11. ai_edge_torch/generative/layers/experimental/attention.py +269 -0
  12. ai_edge_torch/generative/layers/experimental/kv_cache.py +314 -0
  13. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +97 -0
  14. ai_edge_torch/generative/layers/experimental/types.py +97 -0
  15. ai_edge_torch/generative/layers/kv_cache.py +2 -1
  16. ai_edge_torch/generative/layers/model_config.py +5 -1
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +11 -2
  18. ai_edge_torch/generative/utilities/bmm_4d.py +76 -0
  19. ai_edge_torch/generative/utilities/converter.py +18 -2
  20. ai_edge_torch/generative/utilities/model_builder.py +6 -1
  21. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -1
  22. ai_edge_torch/quantize/pt2e_quantizer_utils.py +22 -2
  23. ai_edge_torch/version.py +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA +1 -1
  25. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD +28 -18
  26. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/LICENSE +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/WHEEL +0 -0
  28. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/top_level.txt +0 -0
ai_edge_torch/_config.py CHANGED
@@ -65,5 +65,14 @@ class _Config:
   def enable_group_norm_composite(self, value: bool):
     os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
 
+  @property
+  def layout_optimize_partitioner(self) -> str:
+    """The algorithm to use for layout optimization."""
+    return os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", "DEFAULT")
+
+  @layout_optimize_partitioner.setter
+  def layout_optimize_partitioner(self, value: str):
+    os.environ["AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER"] = str(value).upper()
+
 
 config = _Config()
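
The new property and setter expose the existing AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER environment variable through ai_edge_torch.config, so callers no longer have to set the variable by hand. A minimal sketch of how it might be used; the accepted values MINCUT, OPTIMAL, and GREEDY come from the pass_body.py hunk further down, and DEFAULT is the fallback seen above:

    import ai_edge_torch

    # The setter upper-cases the value and stores it in
    # os.environ["AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER"].
    ai_edge_torch.config.layout_optimize_partitioner = "mincut"
    assert ai_edge_torch.config.layout_optimize_partitioner == "MINCUT"

    # When unset, the property falls back to "DEFAULT", which leaves the
    # choice of partitioner to the pass itself.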
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py CHANGED
@@ -201,8 +201,14 @@ def _aten_group_norm_checker(node):
   return NHWCable(can_be=can_be, must_be=must_be)
 
 
-@nhwcable_node_checkers.register(aten.native_group_norm)
+@nhwcable_node_checkers.register(aten.native_group_norm.default)
 def _aten_native_group_norm_checker(node):
+  # aten.group_norm is removed from the decomp table, so aten.native_group_norm
+  # should never exist in the graph. However, torch 2.5.1 could ignore the
+  # decomp table updates, so still add this native_group_norm checker and
+  # rewriter to be safe.
+  # The checker and rewriter are the same as the ones for aten.group_norm.
+
   val = node.meta.get("val")
   if (
       not isinstance(val, (list, tuple))
@@ -210,13 +216,10 @@ def _aten_native_group_norm_checker(node):
       or not hasattr(val[0], "shape")
   ):
     return NHWCable(can_be=False, must_be=False)
-  if len(node.args) >= 3 and (
-      node.args[1] is not None or node.args[2] is not None
-  ):
-    # Disable NHWC rewriter due to precision issue with weight and bias.
-    # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
-    return NHWCable(can_be=False, must_be=False)
-  return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
+
+  can_be = len(val[0].shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
 
 
 # ==== Ops must be NCHW
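
The net effect of the rewrite is that any 4-D group norm can be laid out NHWC, and must be whenever the group-norm composite is enabled, which guarantees the composite built in layout_rewrite.py (next hunk) sees channel-last input. A standalone restatement of the decision logic, for clarity only:

    # Mirrors the checker above: can_be requires a 4-D output; must_be
    # additionally requires the group-norm composite to be enabled.
    def nhwc_decision(shape, composite_enabled: bool):
      can_be = len(shape) == 4
      must_be = can_be and composite_enabled
      return can_be, must_be

    assert nhwc_decision((1, 32, 16, 16), True) == (True, True)
    assert nhwc_decision((1, 32, 16, 16), False) == (True, False)
    assert nhwc_decision((1, 32), True) == (False, False)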
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py CHANGED
@@ -391,34 +391,32 @@ def _aten_native_group_norm(node):
       eps: float,
       **kwargs,
   ):
-    input_reshaped = torch.reshape(
-        input,
-        [
-            batch_size,
-            flattened_inner_size,
-            num_groups,
-            num_channels // num_groups,
-        ],
-    )
-    reduction_dims = [1, 3]
-
-    biased_var, mean = torch.var_mean(
-        input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
     )
-    rstd = torch.rsqrt(biased_var + eps)
-
-    out = (input_reshaped - mean) * rstd
-    out = torch.reshape(out, input.shape)
 
-    if weight is not None:
-      out = out * weight
-    if bias is not None:
-      out = out + bias
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
 
-    mean = torch.squeeze(mean, reduction_dims)
-    rstd = torch.squeeze(rstd, reduction_dims)
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
 
-    return out, mean, rstd
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return (output, None, None)
 
 
   node.target = native_group_norm
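The rewriter now re-emits group norm as aten.group_norm on NCHW data and, when the composite path is enabled, fences it between mark_inputs and mark_outputs so it lowers to a single odml.group_norm StableHLO composite that a delegate can match. The same high-level-function-boundary idiom can be used in user models; a minimal sketch, assuming StableHLOCompositeBuilder is importable from ai_edge_torch.hlfb as elsewhere in the package (the wrapper module itself is illustrative):

    import torch
    from ai_edge_torch.hlfb import StableHLOCompositeBuilder


    class GroupNormWrapper(torch.nn.Module):
      """Illustrative module that emits an odml.group_norm composite."""

      def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        super().__init__()
        self.norm = torch.nn.GroupNorm(num_groups, num_channels, eps=eps)

      def forward(self, x):
        builder = StableHLOCompositeBuilder(
            name="odml.group_norm",
            attr={"num_groups": self.norm.num_groups, "epsilon": self.norm.eps},
        )
        # Everything between mark_inputs and mark_outputs is fenced into
        # one composite op in the lowered StableHLO.
        x = builder.mark_inputs(x)
        y = self.norm(x)
        return builder.mark_outputs(y)
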
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py CHANGED
@@ -18,6 +18,7 @@ import operator
 import os
 from typing import Union
 
+import ai_edge_torch
 from ai_edge_torch import fx_infra
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
@@ -261,10 +262,8 @@ class OptimizeLayoutTransposesPass(fx_infra.ExportedProgramPassBase):
     self.mark_const_nodes(exported_program)
 
     graph_module = exported_program.graph_module
-    partitioner = os.environ.get(
-        "AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None
-    )
-    if partitioner == "MINCUT":
+    partitioner = ai_edge_torch.config.layout_optimize_partitioner
+    if partitioner in ("MINCUT", "OPTIMAL"):
       graph_module = layout_partitioners.min_cut.partition(graph_module)
     elif partitioner == "GREEDY":
       graph_module = layout_partitioners.greedy.partition(graph_module)
ai_edge_torch/generative/examples/deepseek/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py ADDED
@@ -0,0 +1,80 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting DeepSeek R1 distilled models to tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/deepseek'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
+    '/tmp/',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'deepseek',
+    'The prefix of the output tflite model name.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+
+
+def main(_):
+  pytorch_model = deepseek.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
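
The script is driven entirely by absl flags, but the same conversion can be run programmatically with only the calls shown above; a minimal sketch (paths and values are illustrative):

    from ai_edge_torch.generative.examples.deepseek import deepseek
    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    # Checkpoint path is illustrative; it should hold the downloaded
    # DeepSeek-R1-Distill-Qwen weights.
    model = deepseek.build_model(
        '/path/to/deepseek-checkpoint', kv_cache_max_len=1280
    )
    converter.convert_to_tflite(
        model,
        output_path='/tmp/',
        output_name_prefix='deepseek',
        prefill_seq_len=[8, 64, 128],
        quantize=True,
        export_config=ExportConfig(),
    )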
ai_edge_torch/generative/examples/deepseek/deepseek.py ADDED
@@ -0,0 +1,92 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building DeepSeek R1 distilled models."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
+
+TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
+
+
+class DeepSeekDistillQwen(model_builder.DecoderOnlyModel):
+  """A DeepSeek distilled model based on Qwen."""
+  pass
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a DeepSeek R1 distilled Qwen 2.5 1.5B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a DeepSeek R1 distilled Qwen model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=12,
+      head_dim=128,
+      num_query_groups=2,
+      rotary_base=10000,
+      rotary_percentage=1.0,
+      qkv_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=8960,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-06,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=151936,
+      num_layers=28,
+      max_seq_len=4096,
+      embedding_dim=1536,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      lm_head_share_weight_with_embedding=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # DeepSeek-R1-Distill-Qwen has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=DeepSeekDistillQwen,
+  )
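
get_fake_model_config shrinks the vocabulary, layer count, and feed-forward width so conversion tests run quickly without real weights. A sketch of how a test might use it, assuming a DecoderOnlyModel subclass can be instantiated directly from a config with randomly initialized weights, as the generative test suites do:

    from ai_edge_torch.generative.examples.deepseek import deepseek

    # Tiny variant: 128-token vocab, 2 layers, 64-wide feed forward.
    config = deepseek.get_fake_model_config(kv_cache_max_len=128)
    model = deepseek.DeepSeekDistillQwen(config)
    model.eval()  # Inference mode before any conversion or comparison.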
ai_edge_torch/generative/examples/deepseek/verify.py ADDED
@@ -0,0 +1,70 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored DeepSeek R1 distilled 1.5B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = deepseek.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py CHANGED
@@ -85,6 +85,7 @@ def convert_stable_diffusion_to_tflite(
       clip.TENSOR_NAMES,
   )
   loader.load(clip_model, strict=False)
+  clip_model.eval()
 
   diffusion_model = diffusion.Diffusion(
       diffusion.get_model_config(batch_size=2, device_type=_DEVICE_TYPE.value)
@@ -93,6 +94,7 @@ def convert_stable_diffusion_to_tflite(
       diffusion_ckpt_path, diffusion.TENSOR_NAMES
   )
   diffusion_loader.load(diffusion_model, strict=False)
+  diffusion_model.eval()
 
   decoder_model = decoder.Decoder(
       decoder.get_model_config(device_type=_DEVICE_TYPE.value)
@@ -101,6 +103,7 @@ def convert_stable_diffusion_to_tflite(
       decoder_ckpt_path, decoder.TENSOR_NAMES
   )
   decoder_loader.load(decoder_model, strict=False)
+  decoder_model.eval()
 
   # TODO(yichunk): enable image encoder conversion
   # if encoder_ckpt_path is not None:
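
The added .eval() calls matter because export captures whichever train/eval behavior is active, and modules such as dropout behave differently in training mode. A minimal illustration of the difference (the Dropout module here is only an example):

    import torch

    m = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
    x = torch.ones(1, 4)

    m.train()
    y_train = m(x)  # Dropout active: output is stochastic.

    m.eval()
    y_eval = m(x)  # Dropout is a no-op: deterministic, safe to export.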
ai_edge_torch/generative/layers/experimental/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================