ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
- ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
- ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
- ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
- ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +43 -30
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +84 -73
- ai_edge_torch/generative/layers/builder.py +38 -14
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +61 -33
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +77 -62
- ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +28 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py
CHANGED
@@ -17,6 +17,7 @@ import logging
 import os
 from typing import Any, Optional
 
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 from ai_edge_torch import model
 from ai_edge_torch._convert import fx_passes
@@ -34,7 +35,7 @@ def _run_convert_passes(
   exported_program = generative_fx_passes.run_generative_passes(
       exported_program
   )
-  return fx_passes.run_passes(
+  return fx_pass_base.run_passes(
       exported_program,
       [
           fx_passes.BuildInterpolateCompositePass(),
ai_edge_torch/_convert/fx_passes/__init__.py
CHANGED
@@ -15,44 +15,8 @@
 
 from typing import Sequence, Union
 
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
-from ai_edge_torch._convert.fx_passes._pass_base import FxPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import FxPassResult
-from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass  # NOQA
-from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass  # NOQA
-from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass
-from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass  # NOQA
-from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass  # NOQA
-from torch.export import ExportedProgram
-from torch.fx.passes.infra.pass_manager import pass_result_wrapper
-import torch.utils._pytree as pytree
-
-
-# TODO(cnchan): make a PassManager class.
-def run_passes(
-    exported_program: ExportedProgram,
-    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
-) -> ExportedProgram:
-  passes, _ = pytree.tree_flatten(passes)
-  for pass_ in passes:
-    if not isinstance(pass_, ExportedProgramPassBase):
-      pass_ = pass_result_wrapper(pass_)
-    if isinstance(pass_, ExportedProgramPassBase):
-      exported_program = pass_(exported_program).exported_program
-    else:
-      gm = exported_program.graph_module
-      gm, modified = pass_(gm)
-      if modified and gm is not exported_program.graph_module:
-        exported_program = ExportedProgram(
-            root=gm,
-            graph=gm.graph,
-            graph_signature=exported_program.graph_signature,
-            state_dict=exported_program.state_dict,
-            range_constraints=exported_program.range_constraints,
-            module_call_graph=exported_program.module_call_graph,
-            example_inputs=exported_program.example_inputs,
-            verifier=exported_program.verifier,
-            constants=exported_program.constants,
-        )
-  return exported_program
+from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
+from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
+from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
+from ai_edge_torch.fx_pass_base import CanonicalizePass
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
CHANGED
@@ -13,11 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 
-from functools import reduce
 from typing import Any, Callable
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 import torch
-from torch.fx.passes.infra import pass_base
 import torch.utils._pytree as pytree
 
 _composite_builders: dict[
@@ -277,7 +276,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
   node.target = embedding
 
 
-class BuildAtenCompositePass(pass_base.PassBase):
+class BuildAtenCompositePass(fx_pass_base.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -286,4 +285,4 @@ class BuildAtenCompositePass(pass_base.PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return pass_base.PassResult(graph_module, True)
+    return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py
CHANGED
@@ -16,8 +16,7 @@
 
 import functools
 
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch.hlfb import mark_pattern
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
@@ -103,7 +102,7 @@ def _get_interpolate_nearest2d_pattern():
   return pattern
 
 
-class BuildInterpolateCompositePass(ExportedProgramPassBase):
+class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
 
   def __init__(self):
     super().__init__()
@@ -124,4 +123,4 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py
CHANGED
@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 import torch
-from torch.fx.passes.infra.pass_base import PassBase
-from torch.fx.passes.infra.pass_base import PassResult
 import torch.utils._pytree as pytree
 
 
@@ -62,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
   node.target = debuginfo_writer
 
 
-class InjectMlirDebuginfoPass(PassBase):
+class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -70,4 +69,4 @@ class InjectMlirDebuginfoPass(PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return PassResult(graph_module, True)
+    return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py
CHANGED
@@ -18,8 +18,7 @@ import operator
 import os
 from typing import Union
 
-from ai_edge_torch._convert.fx_passes import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes import ExportedProgramPassResult
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
@@ -31,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
 TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
 
 
-class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
+class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
 
   def get_source_meta(self, node: torch.fx.Node):
     keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
@@ -94,7 +93,7 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
 
     q_args = input_q.args[1:]
     q_kwargs = input_q.kwargs
-    q_op, dq_op =
+    q_op, dq_op = utils.get_paired_q_dq_ops(input_q.target)
     with graph.inserting_before(target):
       # Q and DQ inserted here may required updating the `axis` arg when they
      # are per_channel ops. However, instead of updating here, the nodes would
@@ -301,4 +300,4 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
     # Mark const node again for debugging
     self.mark_const_nodes(exported_program)
 
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/config.py
CHANGED

ai_edge_torch/fx_pass_base.py
ADDED
@@ -0,0 +1,101 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import abc
+import collections
+from typing import Sequence, Union
+
+import torch
+from torch.fx.passes.infra.pass_base import PassBase
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.fx.passes.infra.pass_manager import pass_result_wrapper
+import torch.utils._pytree as pytree
+
+FxPassBase = PassBase
+FxPassResult = PassResult
+ExportedProgramPassResult = collections.namedtuple(
+    "ExportedProgramPassResult", ["exported_program", "modified"]
+)
+
+
+class ExportedProgramPassBase(abc.ABC):
+
+  def __call__(
+      self, exported_program: torch.export.ExportedProgram
+  ) -> ExportedProgramPassResult:
+    self.requires(exported_program)
+    res = self.call(exported_program)
+    self.ensures(exported_program)
+    return res
+
+  @abc.abstractmethod
+  def call(
+      self, exported_program: torch.export.ExportedProgram
+  ) -> ExportedProgramPassResult:
+    pass
+
+  def requires(self, exported_program: torch.export.ExportedProgram) -> None:
+    pass
+
+  def ensures(self, exported_program: torch.export.ExportedProgram) -> None:
+    pass
+
+
+# TODO(cnchan): make a PassManager class.
+def run_passes(
+    exported_program: torch.export.ExportedProgram,
+    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
+) -> torch.export.ExportedProgram:
+  passes, _ = pytree.tree_flatten(passes)
+  for pass_ in passes:
+    if not isinstance(pass_, ExportedProgramPassBase):
+      pass_ = pass_result_wrapper(pass_)
+    if isinstance(pass_, ExportedProgramPassBase):
+      exported_program = pass_(exported_program).exported_program
+    else:
+      gm = exported_program.graph_module
+      gm, modified = pass_(gm)
+      if modified and gm is not exported_program.graph_module:
+        exported_program = torch.export.ExportedProgram(
+            root=gm,
+            graph=gm.graph,
+            graph_signature=exported_program.graph_signature,
+            state_dict=exported_program.state_dict,
+            range_constraints=exported_program.range_constraints,
+            module_call_graph=exported_program.module_call_graph,
+            example_inputs=exported_program.example_inputs,
+            verifier=exported_program.verifier,
+            constants=exported_program.constants,
+        )
+  return exported_program
+
+
+class CanonicalizePass(ExportedProgramPassBase):
+
+  # A dummy decomp table for running ExportedProgram.run_decompositions without
+  # any op decompositions but just aot_export_module. Due to the check in
+  # run_decompositions, if None or an empty dict is passed as decomp_table,
+  # it will run the default aten-coreaten decompositions. Therefore a non-empty
+  # dummy decomp table is needed.
+  # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
+  _DUMMY_DECOMP_TABLE = {
+      torch._ops.OperatorBase(): lambda: None,
+  }
+
+  def call(self, exported_program: torch.export.ExportedProgram):
+    exported_program = exported_program.run_decompositions(
+        self._DUMMY_DECOMP_TABLE
+    )
+    return ExportedProgramPassResult(exported_program, True)
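The new fx_pass_base module consolidates what previously lived in _convert/fx_passes/_pass_base.py and canonicalize_pass.py: an FX-level PassBase, an ExportedProgramPassBase, and the run_passes driver that accepts a mix of both. A minimal sketch of how a downstream pass plugs into this API; the no-op pass and toy module below are hypothetical, not part of the package:

import torch
from ai_edge_torch import fx_pass_base


class NoopProgramPass(fx_pass_base.ExportedProgramPassBase):
  """Hypothetical pass that inspects the program and changes nothing."""

  def call(self, exported_program: torch.export.ExportedProgram):
    # A real pass would rewrite exported_program.graph_module here.
    return fx_pass_base.ExportedProgramPassResult(exported_program, False)


class AddOne(torch.nn.Module):

  def forward(self, x):
    return x + 1


ep = torch.export.export(AddOne(), (torch.randn(2),))
# run_passes wraps plain FX passes with pass_result_wrapper and threads the
# ExportedProgram through ExportedProgramPassBase instances in order.
ep = fx_pass_base.run_passes(
    ep, [fx_pass_base.CanonicalizePass(), NoopProgramPass()]
)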
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
CHANGED
@@ -13,55 +13,74 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting a Gemma2 model to multi-signature tflite model."""
+
 import os
-
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def
+def convert_gemma2_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
-  tflite model.
+  """Converts a Gemma2 2B model to multi-signature tflite model.
 
   Args:
-      checkpoint_path (str): The filepath to the model checkpoint, or directory
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
       including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = gemma2.build_2b_model(
      checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.
-  decode_input_pos = torch.tensor([0], dtype=torch.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
      ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
      )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
+  quant_suffix = 'q8' if quantize else 'f32'
  edge_model.export(
-      f'/tmp/
+      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
  )


if __name__ == '__main__':
-
-
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
+  convert_gemma2_to_tflite(path)
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
CHANGED
@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting a Gemma model to multi-signature tflite model."""
+
 import os
-
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -44,24 +47,40 @@ def convert_gemma_to_tflite(
      checkpoint_path, kv_cache_max_len=kv_cache_max_len
  )
  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.
-  decode_input_pos = torch.tensor([0], dtype=torch.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
      )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
+  quant_suffix = 'q8' if quantize else 'f32'
  edge_model.export(
-      f'/tmp/
+      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
  )


if __name__ == '__main__':
-
-  convert_gemma_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(path)
ai_edge_torch/generative/examples/gemma/gemma.py
CHANGED
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a Gemma model."""
 
 import os
-
+import pathlib
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -48,7 +50,6 @@ class Gemma(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
@@ -60,18 +61,20 @@ class Gemma(nn.Module):
     )
     # Gemma re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
+    # Gemma has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -84,16 +87,22 @@ class Gemma(nn.Module):
     )
     self.config = config
 
-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -102,15 +111,20 @@ class Gemma(nn.Module):
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)
 
-
-
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-    return
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -139,18 +153,20 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       epsilon=1e-6,
       zero_centered=True,
   )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=256000,
      num_layers=18,
      max_seq_len=8192,
      embedding_dim=2048,
      kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
      final_norm_config=norm_config,
-      parallel_residual=False,
      lm_head_use_bias=False,
      enable_hlfb=True,
  )
@@ -159,7 +175,8 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config = get_model_config_2b(kv_cache_max_len)
-  config.ff_config.intermediate_size = 128
+  # Gemma has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
@@ -170,32 +187,35 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  #
+  # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
   loader.load(model, strict=False)
   model.eval()
   return model
 
 
-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma 2B model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
 
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
-
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
   assert torch.allclose(
-      gemma_goldens,
+      gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )
 
 
 if __name__ == "__main__":
-
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
+  )
+  define_and_run_2b(input_checkpoint_path)
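Taken together, the gemma.py changes replace the old in-model k/v cache handling with a functional kv_utils.KVCache that is passed into forward and returned, updated, alongside the logits. A minimal decode-loop sketch against this interface; the checkpoint path and start-token id are placeholders, and greedy argmax decoding is used only for illustration:

import torch
from ai_edge_torch.generative.examples.gemma import gemma
from ai_edge_torch.generative.layers import kv_cache as kv_utils

model = gemma.build_2b_model("/path/to/gemma-2b", kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(model.config)

token = torch.tensor([[2]], dtype=torch.int)  # placeholder start-token id
for pos in range(8):
  input_pos = torch.tensor([pos], dtype=torch.int)
  output = model.forward(token, input_pos, kv)
  kv = output["kv_cache"]  # thread the updated cache into the next step
  token = output["logits"][:, -1:].argmax(dim=-1).to(torch.int)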