ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl

Files changed (50)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
  11. ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
  16. ai_edge_torch/generative/examples/phi/phi2.py +2 -2
  17. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  18. ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
  19. ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
  20. ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
  21. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  22. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  23. ai_edge_torch/generative/examples/t5/t5.py +8 -8
  24. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  25. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
  26. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
  27. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  28. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  29. ai_edge_torch/generative/layers/attention.py +7 -0
  30. ai_edge_torch/generative/layers/builder.py +33 -11
  31. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  32. ai_edge_torch/generative/layers/kv_cache.py +4 -4
  33. ai_edge_torch/generative/layers/model_config.py +24 -15
  34. ai_edge_torch/generative/quantize/example.py +2 -2
  35. ai_edge_torch/generative/test/test_model_conversion.py +28 -51
  36. ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
  37. ai_edge_torch/generative/test/test_quantize.py +5 -5
  38. ai_edge_torch/generative/utilities/loader.py +13 -0
  39. ai_edge_torch/odml_torch/export.py +40 -0
  40. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  41. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  42. ai_edge_torch/version.py +1 -1
  43. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
  44. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +48 -46
  45. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  46. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  47. /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py CHANGED
@@ -17,6 +17,7 @@ import logging
 import os
 from typing import Any, Optional
 
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 from ai_edge_torch import model
 from ai_edge_torch._convert import fx_passes
@@ -34,7 +35,7 @@ def _run_convert_passes(
   exported_program = generative_fx_passes.run_generative_passes(
       exported_program
   )
-  return fx_passes.run_passes(
+  return fx_pass_base.run_passes(
       exported_program,
       [
           fx_passes.BuildInterpolateCompositePass(),
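The pass runner itself is unchanged; only its home moves from _convert.fx_passes to the new top-level fx_pass_base module (added below). A minimal sketch of the new call pattern, using a toy module for illustration (the module and the pass selection here are illustrative, not taken from the package's own code):

    import torch

    from ai_edge_torch import fx_pass_base
    from ai_edge_torch._convert import fx_passes


    class TinyModel(torch.nn.Module):

      def forward(self, x):
        return torch.nn.functional.relu(x)


    # Export a toy module, then run convert passes through the relocated
    # entry point (fx_pass_base.run_passes) instead of fx_passes.run_passes.
    ep = torch.export.export(TinyModel(), (torch.randn(1, 4),))
    ep = fx_pass_base.run_passes(
        ep, [fx_passes.BuildAtenCompositePass(), fx_passes.CanonicalizePass()]
    )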
ai_edge_torch/_convert/fx_passes/__init__.py CHANGED
@@ -15,44 +15,8 @@
 
 from typing import Sequence, Union
 
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
-from ai_edge_torch._convert.fx_passes._pass_base import FxPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import FxPassResult
-from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass  # NOQA
-from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass  # NOQA
-from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass
-from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass  # NOQA
-from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass  # NOQA
-from torch.export import ExportedProgram
-from torch.fx.passes.infra.pass_manager import pass_result_wrapper
-import torch.utils._pytree as pytree
-
-
-# TODO(cnchan): make a PassManager class.
-def run_passes(
-    exported_program: ExportedProgram,
-    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
-) -> ExportedProgram:
-  passes, _ = pytree.tree_flatten(passes)
-  for pass_ in passes:
-    if not isinstance(pass_, ExportedProgramPassBase):
-      pass_ = pass_result_wrapper(pass_)
-    if isinstance(pass_, ExportedProgramPassBase):
-      exported_program = pass_(exported_program).exported_program
-    else:
-      gm = exported_program.graph_module
-      gm, modified = pass_(gm)
-      if modified and gm is not exported_program.graph_module:
-        exported_program = ExportedProgram(
-            root=gm,
-            graph=gm.graph,
-            graph_signature=exported_program.graph_signature,
-            state_dict=exported_program.state_dict,
-            range_constraints=exported_program.range_constraints,
-            module_call_graph=exported_program.module_call_graph,
-            example_inputs=exported_program.example_inputs,
-            verifier=exported_program.verifier,
-            constants=exported_program.constants,
-        )
-  return exported_program
+from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
+from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
+from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
+from ai_edge_torch.fx_pass_base import CanonicalizePass
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py CHANGED
@@ -13,11 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 
-from functools import reduce
 from typing import Any, Callable
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 import torch
-from torch.fx.passes.infra import pass_base
 import torch.utils._pytree as pytree
 
 _composite_builders: dict[
@@ -277,7 +276,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
   node.target = embedding
 
 
-class BuildAtenCompositePass(pass_base.PassBase):
+class BuildAtenCompositePass(fx_pass_base.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -286,4 +285,4 @@ class BuildAtenCompositePass(pass_base.PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return pass_base.PassResult(graph_module, True)
+    return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED
@@ -16,8 +16,7 @@
 
 import functools
 
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch.hlfb import mark_pattern
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
@@ -103,7 +102,7 @@ def _get_interpolate_nearest2d_pattern():
   return pattern
 
 
-class BuildInterpolateCompositePass(ExportedProgramPassBase):
+class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
 
   def __init__(self):
     super().__init__()
@@ -124,4 +123,4 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):
 
     graph_module.graph.lint()
    graph_module.recompile()
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py CHANGED
@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 import torch
-from torch.fx.passes.infra.pass_base import PassBase
-from torch.fx.passes.infra.pass_base import PassResult
 import torch.utils._pytree as pytree
 
 
@@ -62,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
   node.target = debuginfo_writer
 
 
-class InjectMlirDebuginfoPass(PassBase):
+class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -70,4 +69,4 @@ class InjectMlirDebuginfoPass(PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return PassResult(graph_module, True)
+    return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py CHANGED
@@ -18,8 +18,7 @@ import operator
 import os
 from typing import Union
 
-from ai_edge_torch._convert.fx_passes import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes import ExportedProgramPassResult
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
@@ -31,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
 TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
 
 
-class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
+class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
 
   def get_source_meta(self, node: torch.fx.Node):
    keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
@@ -94,7 +93,7 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
 
       q_args = input_q.args[1:]
       q_kwargs = input_q.kwargs
-      q_op, dq_op = self.get_paired_q_dq_ops(input_q.target)
+      q_op, dq_op = utils.get_paired_q_dq_ops(input_q.target)
       with graph.inserting_before(target):
         # Q and DQ inserted here may required updating the `axis` arg when they
         # are per_channel ops. However, instead of updating here, the nodes would
@@ -301,4 +300,4 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
     # Mark const node again for debugging
     self.mark_const_nodes(exported_program)
 
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/config.py CHANGED
@@ -21,4 +21,7 @@ import os
 
 @dataclasses.dataclass
 class Config:
-  use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "True") == "True"
+  use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "true").lower() in (
+      "1",
+      "true",
+  )
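Previously the flag matched only the exact string "True"; the rewrite lowercases the value and also accepts "1". A standalone sketch of the behavioral difference:

    import os

    os.environ["USE_TORCH_XLA"] = "TRUE"

    # Old behavior: exact match against "True", so "TRUE" was treated as False.
    old = os.environ.get("USE_TORCH_XLA", "True") == "True"
    # New behavior: case-insensitive membership test, so "TRUE", "true",
    # and "1" all enable torch_xla.
    new = os.environ.get("USE_TORCH_XLA", "true").lower() in ("1", "true")

    print(old, new)  # False True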
ai_edge_torch/fx_pass_base.py ADDED
@@ -0,0 +1,101 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import abc
+import collections
+from typing import Sequence, Union
+
+import torch
+from torch.fx.passes.infra.pass_base import PassBase
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.fx.passes.infra.pass_manager import pass_result_wrapper
+import torch.utils._pytree as pytree
+
+FxPassBase = PassBase
+FxPassResult = PassResult
+ExportedProgramPassResult = collections.namedtuple(
+    "ExportedProgramPassResult", ["exported_program", "modified"]
+)
+
+
+class ExportedProgramPassBase(abc.ABC):
+
+  def __call__(
+      self, exported_program: torch.export.ExportedProgram
+  ) -> ExportedProgramPassResult:
+    self.requires(exported_program)
+    res = self.call(exported_program)
+    self.ensures(exported_program)
+    return res
+
+  @abc.abstractmethod
+  def call(
+      self, exported_program: torch.export.ExportedProgram
+  ) -> ExportedProgramPassResult:
+    pass
+
+  def requires(self, exported_program: torch.export.ExportedProgram) -> None:
+    pass
+
+  def ensures(self, exported_program: torch.export.ExportedProgram) -> None:
+    pass
+
+
+# TODO(cnchan): make a PassManager class.
+def run_passes(
+    exported_program: torch.export.ExportedProgram,
+    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
+) -> torch.export.ExportedProgram:
+  passes, _ = pytree.tree_flatten(passes)
+  for pass_ in passes:
+    if not isinstance(pass_, ExportedProgramPassBase):
+      pass_ = pass_result_wrapper(pass_)
+    if isinstance(pass_, ExportedProgramPassBase):
+      exported_program = pass_(exported_program).exported_program
+    else:
+      gm = exported_program.graph_module
+      gm, modified = pass_(gm)
+      if modified and gm is not exported_program.graph_module:
+        exported_program = torch.export.ExportedProgram(
+            root=gm,
+            graph=gm.graph,
+            graph_signature=exported_program.graph_signature,
+            state_dict=exported_program.state_dict,
+            range_constraints=exported_program.range_constraints,
+            module_call_graph=exported_program.module_call_graph,
+            example_inputs=exported_program.example_inputs,
+            verifier=exported_program.verifier,
+            constants=exported_program.constants,
+        )
+  return exported_program
+
+
+class CanonicalizePass(ExportedProgramPassBase):
+
+  # A dummy decomp table for running ExportedProgram.run_decompositions without
+  # any op decompositions but just aot_export_module. Due to the check in
+  # run_decompositions, if None or an empty dict is passed as decomp_table,
+  # it will run the default aten-coreaten decompositions. Therefore a non-empty
+  # dummy decomp table is needed.
+  # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
+  _DUMMY_DECOMP_TABLE = {
+      torch._ops.OperatorBase(): lambda: None,
+  }
+
+  def call(self, exported_program: torch.export.ExportedProgram):
+    exported_program = exported_program.run_decompositions(
+        self._DUMMY_DECOMP_TABLE
+    )
+    return ExportedProgramPassResult(exported_program, True)
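Because run_passes accepts both plain torch.fx passes and ExportedProgramPassBase subclasses, downstream code can target either interface. A minimal sketch of a custom ExportedProgram pass written against the API above (the pass and the toy module are hypothetical, for illustration only):

    import torch

    from ai_edge_torch import fx_pass_base


    class NoopPass(fx_pass_base.ExportedProgramPassBase):
      """Hypothetical pass that inspects the graph and changes nothing."""

      def call(self, exported_program: torch.export.ExportedProgram):
        num_nodes = len(exported_program.graph_module.graph.nodes)
        print(f"graph has {num_nodes} nodes")
        # modified=False: run_passes keeps the program as-is.
        return fx_pass_base.ExportedProgramPassResult(exported_program, False)


    class TinyModel(torch.nn.Module):

      def forward(self, x):
        return x + 1


    ep = torch.export.export(TinyModel(), (torch.randn(2),))
    ep = fx_pass_base.run_passes(ep, [NoopPass()])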
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -47,10 +47,10 @@ def convert_gemma2_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
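The trace tensors move from 64-bit integers (torch.long, i.e. torch.int64) to 32-bit (torch.int, i.e. torch.int32), so the exported signatures take int32 inputs; the gemma, gemma.py, gemma2.py, and OpenELM hunks below make the same substitution. A quick standalone check of the resulting dtypes:

    import torch

    prefill_tokens = torch.full((1, 8), 0, dtype=torch.int)
    prefill_input_pos = torch.arange(0, 8, dtype=torch.int)

    # torch.int is an alias for torch.int32; torch.long is torch.int64.
    assert prefill_tokens.dtype == torch.int32
    assert prefill_input_pos.dtype == torch.int32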
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py CHANGED
@@ -47,10 +47,10 @@ def convert_gemma_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
ai_edge_torch/generative/examples/gemma/gemma.py CHANGED
@@ -203,9 +203,9 @@ def define_and_run_2b(checkpoint_path: str) -> None:
   kv_cache_max_len = 1024
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -280,9 +280,9 @@ def define_and_run_2b(checkpoint_path: str) -> None:
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
   )
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :9] = toks
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   out = model.forward(tokens, input_pos, kv)
   out_final = out["logits"][0, 8, :]
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py ADDED
@@ -0,0 +1,86 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting OpenELM model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+import torch
+
+
+def convert_openelm_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """Converts OpenELM model to multi-signature tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
+  """
+  pytorch_model = openelm.build_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  quant_suffix = 'q8' if quantize else 'f32'
+  edge_model.export(
+      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  )
+
+
+if __name__ == '__main__':
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
+  convert_openelm_to_tflite(path)
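As in the other conversion examples, the result is a single .tflite file with separate 'prefill' and 'decode' signatures sharing one KV cache layout. A hypothetical invocation with non-default sizes (the checkpoint path is a placeholder):

    # Writes /tmp/openelm_f32_seq256_ekv512.tflite per the export path above.
    convert_openelm_to_tflite(
        '/path/to/openelm',
        prefill_seq_len=256,
        kv_cache_max_len=512,
        quantize=False,
    )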
ai_edge_torch/generative/examples/openelm/openelm.py ADDED
@@ -0,0 +1,237 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building an OpenELM model."""
+
+import os
+import pathlib
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="transformer.layers.{}.ffn.proj_1",
+    ff_down_proj="transformer.layers.{}.ffn.proj_2",
+    attn_fused_qkv_proj="transformer.layers.{}.attn.qkv_proj",
+    attn_query_norm="transformer.layers.{}.attn.q_norm",
+    attn_key_norm="transformer.layers.{}.attn.k_norm",
+    attn_output_proj="transformer.layers.{}.attn.out_proj",
+    pre_attn_norm="transformer.layers.{}.attn_norm",
+    pre_ff_norm="transformer.layers.{}.ffn_norm",
+    embedding="transformer.token_embeddings",
+    final_norm="transformer.norm",
+    lm_head=None,
+)
+
+
+class OpenELM(nn.Module):
+  """An OpenELM model built from the Edge Generative API layers."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    # OpenELM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    # OpenELM has same hyper parameters for rotary_percentage and head_dim for
+    # each layer block. Use the first block.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for an OpenELM model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for an OpenELM model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
+  )
+  num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
+  num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
+
+  def make_divisible(v, d):
+    """Ensures that all layers have a channel number that is divisible by d."""
+    new_v = int(v + d / 2) // d * d
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+      new_v += d
+    return new_v
+
+  # The way to get intermediate size is from
+  # https://huggingface.co/apple/OpenELM-3B/blob/main/modeling_openelm.py
+  def get_intermediate_size(idx: int) -> int:
+    return make_divisible((0.5 + 0.1 * idx) * 3072, 256)
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    return cfg.TransformerBlockConfig(
+        attn_config=cfg.AttentionConfig(
+            num_heads=num_heads[idx],
+            head_dim=128,
+            num_query_groups=num_query_groups[idx],
+            rotary_percentage=1.0,
+            qkv_transpose_before_split=True,
+            query_norm_config=norm_config,
+            key_norm_config=norm_config,
+        ),
+        ff_config=cfg.FeedForwardConfig(
+            type=cfg.FeedForwardType.SEQUENTIAL,
+            activation=cfg.ActivationConfig(
+                cfg.ActivationType.SILU_GLU, gate_is_front=True
+            ),
+            intermediate_size=get_intermediate_size(idx),
+            pre_ff_norm_config=norm_config,
+        ),
+        pre_attention_norm_config=norm_config,
+    )
+
+  num_layers = 36
+  config = cfg.ModelConfig(
+      vocab_size=32000,
+      num_layers=num_layers,
+      max_seq_len=2048,
+      embedding_dim=3072,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
+      final_norm_config=norm_config,
+  )
+  return config
+
+
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.block_configs = config.block_configs[: config.num_layers]
+  for block_config in config.block_configs:
+    block_config.attn_config.num_heads = 3
+    block_config.attn_config.head_dim = 64
+    block_config.ff_config.intermediate_size = 128
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = OpenELM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs an OpenELM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/openelm"
+  )
+  define_and_run(input_checkpoint_path)
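A note on the variable-width FFN schedule above: get_intermediate_size rounds (0.5 + 0.1 * idx) * 3072 to a multiple of 256, backing off upward if rounding down loses more than 10%. Working a few layers by hand (values follow directly from the code above):

    def make_divisible(v, d):
      new_v = int(v + d / 2) // d * d
      if new_v < 0.9 * v:
        new_v += d
      return new_v

    # idx 0:  (0.5 + 0.0) * 3072 = 1536.0  -> 1536 (already a multiple of 256)
    # idx 1:  (0.5 + 0.1) * 3072 = 1843.2  -> 1792 (rounds down, within 10%)
    # idx 35: (0.5 + 3.5) * 3072 = 12288.0 -> 12288
    for idx in (0, 1, 35):
      print(idx, make_divisible((0.5 + 0.1 * idx) * 3072, 256))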