ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240915__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/phi/phi2.py +2 -2
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
- ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +8 -8
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +7 -0
- ai_edge_torch/generative/layers/builder.py +33 -11
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +4 -4
- ai_edge_torch/generative/layers/model_config.py +24 -15
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion.py +28 -51
- ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/utilities/loader.py +13 -0
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/RECORD +48 -46
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/top_level.txt +0 -0
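The headline change in this release is structural: the FX pass infrastructure moves out of `ai_edge_torch._convert.fx_passes` (the private `_pass_base.py` and `canonicalize_pass.py` modules are deleted) into a new top-level `ai_edge_torch.fx_pass_base` module, the `smallm` example is renamed to `smollm`, and a new OpenELM example is added. As a rough migration sketch (the subclass below is hypothetical, not from the diff), code that previously imported the pass bases from the private package would now write:

    # New import location as of 0.3.0.dev20240915:
    from ai_edge_torch import fx_pass_base

    # Hypothetical pass, for illustration only.
    class MyCleanupPass(fx_pass_base.ExportedProgramPassBase):

      def call(self, exported_program):
        # Mutate exported_program.graph_module in place, then report success.
        return fx_pass_base.ExportedProgramPassResult(exported_program, True)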
ai_edge_torch/_convert/conversion.py
CHANGED
@@ -17,6 +17,7 @@ import logging
 import os
 from typing import Any, Optional
 
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 from ai_edge_torch import model
 from ai_edge_torch._convert import fx_passes
@@ -34,7 +35,7 @@ def _run_convert_passes(
   exported_program = generative_fx_passes.run_generative_passes(
       exported_program
   )
-  return fx_passes.run_passes(
+  return fx_pass_base.run_passes(
       exported_program,
       [
           fx_passes.BuildInterpolateCompositePass(),
ai_edge_torch/_convert/fx_passes/__init__.py
CHANGED
@@ -15,44 +15,8 @@
 
 from typing import Sequence, Union
 
-from ai_edge_torch._convert.fx_passes.
-from ai_edge_torch._convert.fx_passes.
-from ai_edge_torch._convert.fx_passes.
-from ai_edge_torch._convert.fx_passes.
-from ai_edge_torch.
-from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass  # NOQA
-from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass
-from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass  # NOQA
-from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass  # NOQA
-from torch.export import ExportedProgram
-from torch.fx.passes.infra.pass_manager import pass_result_wrapper
-import torch.utils._pytree as pytree
-
-
-# TODO(cnchan): make a PassManager class.
-def run_passes(
-    exported_program: ExportedProgram,
-    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
-) -> ExportedProgram:
-  passes, _ = pytree.tree_flatten(passes)
-  for pass_ in passes:
-    if not isinstance(pass_, ExportedProgramPassBase):
-      pass_ = pass_result_wrapper(pass_)
-    if isinstance(pass_, ExportedProgramPassBase):
-      exported_program = pass_(exported_program).exported_program
-    else:
-      gm = exported_program.graph_module
-      gm, modified = pass_(gm)
-      if modified and gm is not exported_program.graph_module:
-        exported_program = ExportedProgram(
-            root=gm,
-            graph=gm.graph,
-            graph_signature=exported_program.graph_signature,
-            state_dict=exported_program.state_dict,
-            range_constraints=exported_program.range_constraints,
-            module_call_graph=exported_program.module_call_graph,
-            example_inputs=exported_program.example_inputs,
-            verifier=exported_program.verifier,
-            constants=exported_program.constants,
-        )
-  return exported_program
+from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
+from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
+from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
+from ai_edge_torch.fx_pass_base import CanonicalizePass
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
CHANGED
@@ -13,11 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 
-from functools import reduce
 from typing import Any, Callable
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 import torch
-from torch.fx.passes.infra import pass_base
 import torch.utils._pytree as pytree
 
 _composite_builders: dict[
@@ -277,7 +276,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
   node.target = embedding
 
 
-class BuildAtenCompositePass(pass_base.PassBase):
+class BuildAtenCompositePass(fx_pass_base.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -286,4 +285,4 @@ class BuildAtenCompositePass(pass_base.PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return pass_base.PassResult(graph_module, True)
+    return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py
CHANGED
@@ -16,8 +16,7 @@
 
 import functools
 
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch.hlfb import mark_pattern
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
@@ -103,7 +102,7 @@ def _get_interpolate_nearest2d_pattern():
   return pattern
 
 
-class BuildInterpolateCompositePass(ExportedProgramPassBase):
+class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
 
   def __init__(self):
    super().__init__()
@@ -124,4 +123,4 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py
CHANGED
@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 import torch
-from torch.fx.passes.infra.pass_base import PassBase
-from torch.fx.passes.infra.pass_base import PassResult
 import torch.utils._pytree as pytree
 
 
@@ -62,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
   node.target = debuginfo_writer
 
 
-class InjectMlirDebuginfoPass(PassBase):
+class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -70,4 +69,4 @@ class InjectMlirDebuginfoPass(PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return PassResult(graph_module, True)
+    return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py
CHANGED
@@ -18,8 +18,7 @@ import operator
 import os
 from typing import Union
 
-from ai_edge_torch._convert.fx_passes import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes import ExportedProgramPassResult
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
@@ -31,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
 TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
 
 
-class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
+class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
 
   def get_source_meta(self, node: torch.fx.Node):
     keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
@@ -94,7 +93,7 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
 
     q_args = input_q.args[1:]
     q_kwargs = input_q.kwargs
-    q_op, dq_op =
+    q_op, dq_op = utils.get_paired_q_dq_ops(input_q.target)
     with graph.inserting_before(target):
       # Q and DQ inserted here may required updating the `axis` arg when they
      # are per_channel ops. However, instead of updating here, the nodes would
@@ -301,4 +300,4 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
     # Mark const node again for debugging
     self.mark_const_nodes(exported_program)
 
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/config.py
CHANGED
(diff content not captured)

ai_edge_torch/fx_pass_base.py
ADDED
@@ -0,0 +1,101 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import abc
+import collections
+from typing import Sequence, Union
+
+import torch
+from torch.fx.passes.infra.pass_base import PassBase
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.fx.passes.infra.pass_manager import pass_result_wrapper
+import torch.utils._pytree as pytree
+
+FxPassBase = PassBase
+FxPassResult = PassResult
+ExportedProgramPassResult = collections.namedtuple(
+    "ExportedProgramPassResult", ["exported_program", "modified"]
+)
+
+
+class ExportedProgramPassBase(abc.ABC):
+
+  def __call__(
+      self, exported_program: torch.export.ExportedProgram
+  ) -> ExportedProgramPassResult:
+    self.requires(exported_program)
+    res = self.call(exported_program)
+    self.ensures(exported_program)
+    return res
+
+  @abc.abstractmethod
+  def call(
+      self, exported_program: torch.export.ExportedProgram
+  ) -> ExportedProgramPassResult:
+    pass
+
+  def requires(self, exported_program: torch.export.ExportedProgram) -> None:
+    pass
+
+  def ensures(self, exported_program: torch.export.ExportedProgram) -> None:
+    pass
+
+
+# TODO(cnchan): make a PassManager class.
+def run_passes(
+    exported_program: torch.export.ExportedProgram,
+    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
+) -> torch.export.ExportedProgram:
+  passes, _ = pytree.tree_flatten(passes)
+  for pass_ in passes:
+    if not isinstance(pass_, ExportedProgramPassBase):
+      pass_ = pass_result_wrapper(pass_)
+    if isinstance(pass_, ExportedProgramPassBase):
+      exported_program = pass_(exported_program).exported_program
+    else:
+      gm = exported_program.graph_module
+      gm, modified = pass_(gm)
+      if modified and gm is not exported_program.graph_module:
+        exported_program = torch.export.ExportedProgram(
+            root=gm,
+            graph=gm.graph,
+            graph_signature=exported_program.graph_signature,
+            state_dict=exported_program.state_dict,
+            range_constraints=exported_program.range_constraints,
+            module_call_graph=exported_program.module_call_graph,
+            example_inputs=exported_program.example_inputs,
+            verifier=exported_program.verifier,
+            constants=exported_program.constants,
+        )
+  return exported_program
+
+
+class CanonicalizePass(ExportedProgramPassBase):
+
+  # A dummy decomp table for running ExportedProgram.run_decompositions without
+  # any op decompositions but just aot_export_module. Due to the check in
+  # run_decompositions, if None or an empty dict is passed as decomp_table,
+  # it will run the default aten-coreaten decompositions. Therefore a non-empty
+  # dummy decomp table is needed.
+  # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
+  _DUMMY_DECOMP_TABLE = {
+      torch._ops.OperatorBase(): lambda: None,
+  }
+
+  def call(self, exported_program: torch.export.ExportedProgram):
+    exported_program = exported_program.run_decompositions(
+        self._DUMMY_DECOMP_TABLE
+    )
+    return ExportedProgramPassResult(exported_program, True)
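`run_passes` accepts a mix of `ExportedProgramPassBase` instances and plain FX passes; the latter are wrapped with `pass_result_wrapper`, and the `ExportedProgram` is rebuilt whenever a graph-module pass reports a modification. A minimal usage sketch, assuming a trivial stand-in model (hypothetical, not part of the diff):

    import torch
    from torch import nn

    from ai_edge_torch import fx_pass_base
    from ai_edge_torch._convert import fx_passes


    class TinyModel(nn.Module):  # stand-in model for illustration

      def forward(self, x):
        return nn.functional.relu(x)


    ep = torch.export.export(TinyModel(), (torch.randn(1, 3, 8, 8),))
    ep = fx_pass_base.run_passes(
        ep,
        [
            fx_passes.BuildAtenCompositePass(),  # plain FX pass, auto-wrapped
            fx_pass_base.CanonicalizePass(),  # ExportedProgram-level pass
        ],
    )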
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
CHANGED
@@ -47,10 +47,10 @@ def convert_gemma2_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.
-  decode_input_pos = torch.tensor([0], dtype=torch.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
CHANGED
@@ -47,10 +47,10 @@ def convert_gemma_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.
-  decode_input_pos = torch.tensor([0], dtype=torch.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
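The recurring `dtype=torch.int` additions above (and in the hunks that follow) pin the tracing tensors to int32. Without an explicit dtype, `torch.full(..., 0)`, `torch.arange`, and `torch.tensor([0])` all infer int64, so the signatures would presumably be traced with 64-bit token and position inputs. A quick check:

    import torch

    print(torch.arange(0, 8).dtype)                   # torch.int64 (default)
    print(torch.arange(0, 8, dtype=torch.int).dtype)  # torch.int32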
ai_edge_torch/generative/examples/gemma/gemma.py
CHANGED
@@ -203,9 +203,9 @@ def define_and_run_2b(checkpoint_path: str) -> None:
   kv_cache_max_len = 1024
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED
@@ -280,9 +280,9 @@ def define_and_run_2b(checkpoint_path: str) -> None:
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
   )
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :9] = toks
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   out = model.forward(tokens, input_pos, kv)
   out_final = out["logits"][0, 8, :]
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
ADDED
@@ -0,0 +1,86 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting OpenELM model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+import torch
+
+
+def convert_openelm_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """Converts OpenELM model to multi-signature tflite model.
+
+  Args:
+      checkpoint_path (str): The filepath to the model checkpoint, or directory
+        holding the checkpoint.
+      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+        Defaults to 512.
+      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+        including both prefill and decode. Defaults to 1024.
+      quantize (bool, optional): Whether the model should be quanized. Defaults
+        to True.
+  """
+  pytorch_model = openelm.build_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  quant_suffix = 'q8' if quantize else 'f32'
+  edge_model.export(
+      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  )
+
+
+if __name__ == '__main__':
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
+  convert_openelm_to_tflite(path)
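Both signatures trace the same `pytorch_model` against one shared `KVCache` layout; only the token and position shapes differ between prefill (a batch of `prefill_seq_len` tokens) and decode (a single token). A hypothetical invocation with non-default settings (the checkpoint path is illustrative):

    from ai_edge_torch.generative.examples.openelm import convert_to_tflite

    convert_to_tflite.convert_openelm_to_tflite(
        "/path/to/openelm",  # illustrative checkpoint directory
        prefill_seq_len=256,
        kv_cache_max_len=512,
        quantize=False,  # export float32 instead of dynamic int8
    )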
ai_edge_torch/generative/examples/openelm/openelm.py
ADDED
@@ -0,0 +1,237 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building an OpenELM model."""
+
+import os
+import pathlib
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="transformer.layers.{}.ffn.proj_1",
+    ff_down_proj="transformer.layers.{}.ffn.proj_2",
+    attn_fused_qkv_proj="transformer.layers.{}.attn.qkv_proj",
+    attn_query_norm="transformer.layers.{}.attn.q_norm",
+    attn_key_norm="transformer.layers.{}.attn.k_norm",
+    attn_output_proj="transformer.layers.{}.attn.out_proj",
+    pre_attn_norm="transformer.layers.{}.attn_norm",
+    pre_ff_norm="transformer.layers.{}.ffn_norm",
+    embedding="transformer.token_embeddings",
+    final_norm="transformer.norm",
+    lm_head=None,
+)
+
+
+class OpenELM(nn.Module):
+  """An OpenELM model built from the Edge Generative API layers."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    # OpenELM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    # OpenELM has same hyper parameters for rotary_percentage and head_dim for
+    # each layer block. Use the first block.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for an OpenELM model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for an OpenELM model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
+  )
+  num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
+  num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
+
+  def make_divisible(v, d):
+    """Ensures that all layers have a channel number that is divisible by d."""
+    new_v = int(v + d / 2) // d * d
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+      new_v += d
+    return new_v
+
+  # The way to get intermediate size is from
+  # https://huggingface.co/apple/OpenELM-3B/blob/main/modeling_openelm.py
+  def get_intermediate_size(idx: int) -> int:
+    return make_divisible((0.5 + 0.1 * idx) * 3072, 256)
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    return cfg.TransformerBlockConfig(
+        attn_config=cfg.AttentionConfig(
+            num_heads=num_heads[idx],
+            head_dim=128,
+            num_query_groups=num_query_groups[idx],
+            rotary_percentage=1.0,
+            qkv_transpose_before_split=True,
+            query_norm_config=norm_config,
+            key_norm_config=norm_config,
+        ),
+        ff_config=cfg.FeedForwardConfig(
+            type=cfg.FeedForwardType.SEQUENTIAL,
+            activation=cfg.ActivationConfig(
+                cfg.ActivationType.SILU_GLU, gate_is_front=True
+            ),
+            intermediate_size=get_intermediate_size(idx),
+            pre_ff_norm_config=norm_config,
+        ),
+        pre_attention_norm_config=norm_config,
+    )
+
+  num_layers = 36
+  config = cfg.ModelConfig(
+      vocab_size=32000,
+      num_layers=num_layers,
+      max_seq_len=2048,
+      embedding_dim=3072,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
+      final_norm_config=norm_config,
+  )
+  return config
+
+
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.block_configs = config.block_configs[: config.num_layers]
+  for block_config in config.block_configs:
+    block_config.attn_config.num_heads = 3
+    block_config.attn_config.head_dim = 64
+    block_config.ff_config.intermediate_size = 128
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = OpenELM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs an OpenELM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/openelm"
+  )
+  define_and_run(input_checkpoint_path)
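`get_intermediate_size` reproduces OpenELM's layer-wise FFN widening: the width multiplier grows linearly from 0.5 at block 0 to 4.0 at block 35, and each product is rounded to a multiple of 256 by `make_divisible`. A worked check of the formula, evaluated by hand from the code above:

    def make_divisible(v, d):
      new_v = int(v + d / 2) // d * d
      if new_v < 0.9 * v:
        new_v += d
      return new_v

    print([make_divisible((0.5 + 0.1 * i) * 3072, 256) for i in (0, 1, 35)])
    # [1536, 1792, 12288] -- i.e. 0.5 * 3072, ~0.6 * 3072 rounded, 4.0 * 3072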