ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240915__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
  11. ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
  16. ai_edge_torch/generative/examples/phi/phi2.py +2 -2
  17. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  18. ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
  19. ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
  20. ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
  21. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  22. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  23. ai_edge_torch/generative/examples/t5/t5.py +8 -8
  24. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  25. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
  26. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
  27. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  28. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  29. ai_edge_torch/generative/layers/attention.py +7 -0
  30. ai_edge_torch/generative/layers/builder.py +33 -11
  31. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  32. ai_edge_torch/generative/layers/kv_cache.py +4 -4
  33. ai_edge_torch/generative/layers/model_config.py +24 -15
  34. ai_edge_torch/generative/quantize/example.py +2 -2
  35. ai_edge_torch/generative/test/test_model_conversion.py +28 -51
  36. ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
  37. ai_edge_torch/generative/test/test_quantize.py +5 -5
  38. ai_edge_torch/generative/utilities/loader.py +13 -0
  39. ai_edge_torch/odml_torch/export.py +40 -0
  40. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  41. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  42. ai_edge_torch/version.py +1 -1
  43. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/METADATA +1 -1
  44. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/RECORD +48 -46
  45. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  46. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  47. /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py CHANGED
@@ -17,6 +17,7 @@ import logging
 import os
 from typing import Any, Optional
 
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 from ai_edge_torch import model
 from ai_edge_torch._convert import fx_passes
@@ -34,7 +35,7 @@ def _run_convert_passes(
   exported_program = generative_fx_passes.run_generative_passes(
       exported_program
   )
-  return fx_passes.run_passes(
+  return fx_pass_base.run_passes(
       exported_program,
       [
           fx_passes.BuildInterpolateCompositePass(),
ai_edge_torch/_convert/fx_passes/__init__.py CHANGED
@@ -15,44 +15,8 @@
 
 from typing import Sequence, Union
 
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
-from ai_edge_torch._convert.fx_passes._pass_base import FxPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import FxPassResult
-from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass  # NOQA
-from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass  # NOQA
-from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass
-from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass  # NOQA
-from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass  # NOQA
-from torch.export import ExportedProgram
-from torch.fx.passes.infra.pass_manager import pass_result_wrapper
-import torch.utils._pytree as pytree
-
-
-# TODO(cnchan): make a PassManager class.
-def run_passes(
-    exported_program: ExportedProgram,
-    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
-) -> ExportedProgram:
-  passes, _ = pytree.tree_flatten(passes)
-  for pass_ in passes:
-    if not isinstance(pass_, ExportedProgramPassBase):
-      pass_ = pass_result_wrapper(pass_)
-    if isinstance(pass_, ExportedProgramPassBase):
-      exported_program = pass_(exported_program).exported_program
-    else:
-      gm = exported_program.graph_module
-      gm, modified = pass_(gm)
-      if modified and gm is not exported_program.graph_module:
-        exported_program = ExportedProgram(
-            root=gm,
-            graph=gm.graph,
-            graph_signature=exported_program.graph_signature,
-            state_dict=exported_program.state_dict,
-            range_constraints=exported_program.range_constraints,
-            module_call_graph=exported_program.module_call_graph,
-            example_inputs=exported_program.example_inputs,
-            verifier=exported_program.verifier,
-            constants=exported_program.constants,
-        )
-  return exported_program
+from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
+from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
+from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
+from ai_edge_torch.fx_pass_base import CanonicalizePass
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py CHANGED
@@ -13,11 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 
-from functools import reduce
 from typing import Any, Callable
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 import torch
-from torch.fx.passes.infra import pass_base
 import torch.utils._pytree as pytree
 
 _composite_builders: dict[
@@ -277,7 +276,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
   node.target = embedding
 
 
-class BuildAtenCompositePass(pass_base.PassBase):
+class BuildAtenCompositePass(fx_pass_base.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -286,4 +285,4 @@ class BuildAtenCompositePass(pass_base.PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return pass_base.PassResult(graph_module, True)
+    return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED
@@ -16,8 +16,7 @@
 
 import functools
 
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch.hlfb import mark_pattern
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
@@ -103,7 +102,7 @@ def _get_interpolate_nearest2d_pattern():
   return pattern
 
 
-class BuildInterpolateCompositePass(ExportedProgramPassBase):
+class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
 
   def __init__(self):
     super().__init__()
@@ -124,4 +123,4 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py CHANGED
@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 import torch
-from torch.fx.passes.infra.pass_base import PassBase
-from torch.fx.passes.infra.pass_base import PassResult
 import torch.utils._pytree as pytree
 
 
@@ -62,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
   node.target = debuginfo_writer
 
 
-class InjectMlirDebuginfoPass(PassBase):
+class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -70,4 +69,4 @@ class InjectMlirDebuginfoPass(PassBase):
 
     graph_module.graph.lint()
    graph_module.recompile()
-    return PassResult(graph_module, True)
+    return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py CHANGED
@@ -18,8 +18,7 @@ import operator
 import os
 from typing import Union
 
-from ai_edge_torch._convert.fx_passes import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes import ExportedProgramPassResult
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
@@ -31,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
 TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
 
 
-class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
+class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
 
   def get_source_meta(self, node: torch.fx.Node):
     keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
@@ -94,7 +93,7 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
 
    q_args = input_q.args[1:]
    q_kwargs = input_q.kwargs
-   q_op, dq_op = self.get_paired_q_dq_ops(input_q.target)
+   q_op, dq_op = utils.get_paired_q_dq_ops(input_q.target)
    with graph.inserting_before(target):
      # Q and DQ inserted here may required updating the `axis` arg when they
      # are per_channel ops. However, instead of updating here, the nodes would
@@ -301,4 +300,4 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
     # Mark const node again for debugging
     self.mark_const_nodes(exported_program)
 
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/config.py CHANGED
@@ -21,4 +21,7 @@ import os
 
 @dataclasses.dataclass
 class Config:
-  use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "True") == "True"
+  use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "true").lower() in (
+      "1",
+      "true",
+  )
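
A note on this hunk (a standalone illustration, not code from the package): the old expression recognized only the exact string "True", so settings such as USE_TORCH_XLA=true or USE_TORCH_XLA=1 silently disabled torch-xla; the new expression lowercases the value and accepts the common truthy spellings.

    import os

    # Old check: only the exact string "True" enabled the flag.
    os.environ["USE_TORCH_XLA"] = "true"
    old = os.environ.get("USE_TORCH_XLA", "True") == "True"  # False
    # New check: "1" and any casing of "true" enable it.
    new = os.environ.get("USE_TORCH_XLA", "true").lower() in ("1", "true")  # True
    assert new and not old
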
ai_edge_torch/fx_pass_base.py ADDED
@@ -0,0 +1,101 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import abc
+import collections
+from typing import Sequence, Union
+
+import torch
+from torch.fx.passes.infra.pass_base import PassBase
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.fx.passes.infra.pass_manager import pass_result_wrapper
+import torch.utils._pytree as pytree
+
+FxPassBase = PassBase
+FxPassResult = PassResult
+ExportedProgramPassResult = collections.namedtuple(
+    "ExportedProgramPassResult", ["exported_program", "modified"]
+)
+
+
+class ExportedProgramPassBase(abc.ABC):
+
+  def __call__(
+      self, exported_program: torch.export.ExportedProgram
+  ) -> ExportedProgramPassResult:
+    self.requires(exported_program)
+    res = self.call(exported_program)
+    self.ensures(exported_program)
+    return res
+
+  @abc.abstractmethod
+  def call(
+      self, exported_program: torch.export.ExportedProgram
+  ) -> ExportedProgramPassResult:
+    pass
+
+  def requires(self, exported_program: torch.export.ExportedProgram) -> None:
+    pass
+
+  def ensures(self, exported_program: torch.export.ExportedProgram) -> None:
+    pass
+
+
+# TODO(cnchan): make a PassManager class.
+def run_passes(
+    exported_program: torch.export.ExportedProgram,
+    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
+) -> torch.export.ExportedProgram:
+  passes, _ = pytree.tree_flatten(passes)
+  for pass_ in passes:
+    if not isinstance(pass_, ExportedProgramPassBase):
+      pass_ = pass_result_wrapper(pass_)
+    if isinstance(pass_, ExportedProgramPassBase):
+      exported_program = pass_(exported_program).exported_program
+    else:
+      gm = exported_program.graph_module
+      gm, modified = pass_(gm)
+      if modified and gm is not exported_program.graph_module:
+        exported_program = torch.export.ExportedProgram(
+            root=gm,
+            graph=gm.graph,
+            graph_signature=exported_program.graph_signature,
+            state_dict=exported_program.state_dict,
+            range_constraints=exported_program.range_constraints,
+            module_call_graph=exported_program.module_call_graph,
+            example_inputs=exported_program.example_inputs,
+            verifier=exported_program.verifier,
+            constants=exported_program.constants,
+        )
+  return exported_program
+
+
+class CanonicalizePass(ExportedProgramPassBase):
+
+  # A dummy decomp table for running ExportedProgram.run_decompositions without
+  # any op decompositions but just aot_export_module. Due to the check in
+  # run_decompositions, if None or an empty dict is passed as decomp_table,
+  # it will run the default aten-coreaten decompositions. Therefore a non-empty
+  # dummy decomp table is needed.
+  # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
+  _DUMMY_DECOMP_TABLE = {
+      torch._ops.OperatorBase(): lambda: None,
+  }
+
+  def call(self, exported_program: torch.export.ExportedProgram):
+    exported_program = exported_program.run_decompositions(
+        self._DUMMY_DECOMP_TABLE
+    )
+    return ExportedProgramPassResult(exported_program, True)
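
The new ai_edge_torch/fx_pass_base.py consolidates what previously lived in _convert/fx_passes/_pass_base.py and canonicalize_pass.py (files 45 and 46 in the list above, both deleted). A minimal sketch of how it can be driven, assuming only the APIs visible in this diff (the toy module and input shape are illustrative):

    import torch
    from ai_edge_torch import fx_pass_base


    class Double(torch.nn.Module):  # toy module for illustration

      def forward(self, x):
        return x * 2


    ep = torch.export.export(Double(), (torch.rand(4),))
    # run_passes accepts a mix of ExportedProgramPassBase instances and plain
    # torch.fx PassBase instances; the latter are wrapped with
    # pass_result_wrapper, and a modified graph module is re-packed into a
    # new ExportedProgram.
    ep = fx_pass_base.run_passes(ep, [fx_pass_base.CanonicalizePass()])
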
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -47,10 +47,10 @@ def convert_gemma2_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
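
The torch.long → torch.int swap above recurs in the other conversion and example scripts in this diff. My reading, not stated in the diff itself: torch.long is an alias for torch.int64 and torch.int for torch.int32, so the tokens/input_pos tensors used to trace the signatures now feed 32-bit integer ids into the converted graph. A standalone check of the aliases:

    import torch

    assert torch.long == torch.int64 and torch.int == torch.int32
    tokens = torch.full((1, 512), 0, dtype=torch.int)
    input_pos = torch.arange(0, 512, dtype=torch.int)
    assert tokens.dtype == input_pos.dtype == torch.int32
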
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py CHANGED
@@ -47,10 +47,10 @@ def convert_gemma_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
ai_edge_torch/generative/examples/gemma/gemma.py CHANGED
@@ -203,9 +203,9 @@ def define_and_run_2b(checkpoint_path: str) -> None:
   kv_cache_max_len = 1024
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -280,9 +280,9 @@ def define_and_run_2b(checkpoint_path: str) -> None:
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
   )
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :9] = toks
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   out = model.forward(tokens, input_pos, kv)
   out_final = out["logits"][0, 8, :]
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py ADDED
@@ -0,0 +1,86 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting OpenELM model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+import torch
+
+
+def convert_openelm_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """Converts OpenELM model to multi-signature tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized. Defaults
+      to True.
+  """
+  pytorch_model = openelm.build_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  quant_suffix = 'q8' if quantize else 'f32'
+  edge_model.export(
+      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  )
+
+
+if __name__ == '__main__':
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
+  convert_openelm_to_tflite(path)
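
A usage sketch for the new converter (the checkpoint path is illustrative; per the code above, the output lands in /tmp with the quantization suffix and sizes encoded in the name):

    from ai_edge_torch.generative.examples.openelm import convert_to_tflite

    # Writes /tmp/openelm_q8_seq256_ekv512.tflite for these arguments.
    convert_to_tflite.convert_openelm_to_tflite(
        '/path/to/openelm',  # hypothetical checkpoint location
        prefill_seq_len=256,
        kv_cache_max_len=512,
        quantize=True,
    )
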
ai_edge_torch/generative/examples/openelm/openelm.py ADDED
@@ -0,0 +1,237 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building an OpenELM model."""
+
+import os
+import pathlib
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="transformer.layers.{}.ffn.proj_1",
+    ff_down_proj="transformer.layers.{}.ffn.proj_2",
+    attn_fused_qkv_proj="transformer.layers.{}.attn.qkv_proj",
+    attn_query_norm="transformer.layers.{}.attn.q_norm",
+    attn_key_norm="transformer.layers.{}.attn.k_norm",
+    attn_output_proj="transformer.layers.{}.attn.out_proj",
+    pre_attn_norm="transformer.layers.{}.attn_norm",
+    pre_ff_norm="transformer.layers.{}.ffn_norm",
+    embedding="transformer.token_embeddings",
+    final_norm="transformer.norm",
+    lm_head=None,
+)
+
+
+class OpenELM(nn.Module):
+  """An OpenELM model built from the Edge Generative API layers."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    # OpenELM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    # OpenELM has the same rotary_percentage and head_dim hyperparameters in
+    # every layer block. Use the first block.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for an OpenELM model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for an OpenELM model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
+  )
+  num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
+  num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
+
+  def make_divisible(v, d):
+    """Ensures that all layers have a channel number that is divisible by d."""
+    new_v = int(v + d / 2) // d * d
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+      new_v += d
+    return new_v
+
+  # The way to get intermediate size is from
+  # https://huggingface.co/apple/OpenELM-3B/blob/main/modeling_openelm.py
+  def get_intermediate_size(idx: int) -> int:
+    return make_divisible((0.5 + 0.1 * idx) * 3072, 256)
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    return cfg.TransformerBlockConfig(
+        attn_config=cfg.AttentionConfig(
+            num_heads=num_heads[idx],
+            head_dim=128,
+            num_query_groups=num_query_groups[idx],
+            rotary_percentage=1.0,
+            qkv_transpose_before_split=True,
+            query_norm_config=norm_config,
+            key_norm_config=norm_config,
+        ),
+        ff_config=cfg.FeedForwardConfig(
+            type=cfg.FeedForwardType.SEQUENTIAL,
+            activation=cfg.ActivationConfig(
+                cfg.ActivationType.SILU_GLU, gate_is_front=True
+            ),
+            intermediate_size=get_intermediate_size(idx),
+            pre_ff_norm_config=norm_config,
+        ),
+        pre_attention_norm_config=norm_config,
+    )
+
+  num_layers = 36
+  config = cfg.ModelConfig(
+      vocab_size=32000,
+      num_layers=num_layers,
+      max_seq_len=2048,
+      embedding_dim=3072,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
+      final_norm_config=norm_config,
+  )
+  return config
+
+
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.block_configs = config.block_configs[: config.num_layers]
+  for block_config in config.block_configs:
+    block_config.attn_config.num_heads = 3
+    block_config.attn_config.head_dim = 64
+    block_config.ff_config.intermediate_size = 128
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = OpenELM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs an OpenELM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/openelm"
+  )
+  define_and_run(input_checkpoint_path)
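
One detail worth making concrete: OpenELM varies head counts and feed-forward widths per layer, which is why the config builds a separate block config for each of the 36 layers. A standalone recomputation of the intermediate-size schedule defined by make_divisible and get_intermediate_size above:

    def make_divisible(v: float, d: int) -> int:
      # Round v to the nearest multiple of d, bumping up if rounding
      # went down by more than 10%.
      new_v = int(v + d / 2) // d * d
      if new_v < 0.9 * v:
        new_v += d
      return new_v

    # FFN width grows linearly with depth, (0.5 + 0.1 * idx) * 3072,
    # snapped to a multiple of 256.
    sizes = [make_divisible((0.5 + 0.1 * i) * 3072, 256) for i in range(36)]
    print(sizes[0], sizes[1], sizes[35])  # 1536 1792 12288
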