ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (68)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
  11. ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
  13. ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
  16. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
  17. ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
  18. ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
  20. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  21. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  22. ai_edge_torch/generative/examples/t5/t5.py +43 -30
  23. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  24. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  25. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
  26. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
  27. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
  28. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  29. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  30. ai_edge_torch/generative/layers/attention.py +84 -73
  31. ai_edge_torch/generative/layers/builder.py +38 -14
  32. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  33. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  34. ai_edge_torch/generative/layers/model_config.py +61 -33
  35. ai_edge_torch/generative/layers/normalization.py +158 -0
  36. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  37. ai_edge_torch/generative/quantize/example.py +2 -2
  38. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  39. ai_edge_torch/generative/test/test_loader.py +1 -1
  40. ai_edge_torch/generative/test/test_model_conversion.py +77 -62
  41. ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
  42. ai_edge_torch/generative/test/test_quantize.py +5 -5
  43. ai_edge_torch/generative/test/utils.py +54 -0
  44. ai_edge_torch/generative/utilities/loader.py +28 -15
  45. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  46. ai_edge_torch/odml_torch/export.py +40 -0
  47. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  48. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  49. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  50. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  51. ai_edge_torch/version.py +1 -1
  52. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
  53. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
  54. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  55. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  56. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  57. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  58. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  59. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  60. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  61. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  62. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  63. /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
  64. /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
  65. /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
  66. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
  67. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
  68. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py CHANGED
@@ -17,6 +17,7 @@ import logging
  import os
  from typing import Any, Optional

+ from ai_edge_torch import fx_pass_base
  from ai_edge_torch import lowertools
  from ai_edge_torch import model
  from ai_edge_torch._convert import fx_passes
@@ -34,7 +35,7 @@ def _run_convert_passes(
    exported_program = generative_fx_passes.run_generative_passes(
        exported_program
    )
-   return fx_passes.run_passes(
+   return fx_pass_base.run_passes(
        exported_program,
        [
            fx_passes.BuildInterpolateCompositePass(),
ai_edge_torch/_convert/fx_passes/__init__.py CHANGED
@@ -15,44 +15,8 @@

  from typing import Sequence, Union

- from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
- from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
- from ai_edge_torch._convert.fx_passes._pass_base import FxPassBase
- from ai_edge_torch._convert.fx_passes._pass_base import FxPassResult
- from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass  # NOQA
- from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass  # NOQA
- from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass
- from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass  # NOQA
- from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass  # NOQA
- from torch.export import ExportedProgram
- from torch.fx.passes.infra.pass_manager import pass_result_wrapper
- import torch.utils._pytree as pytree
-
-
- # TODO(cnchan): make a PassManager class.
- def run_passes(
-     exported_program: ExportedProgram,
-     passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
- ) -> ExportedProgram:
-   passes, _ = pytree.tree_flatten(passes)
-   for pass_ in passes:
-     if not isinstance(pass_, ExportedProgramPassBase):
-       pass_ = pass_result_wrapper(pass_)
-     if isinstance(pass_, ExportedProgramPassBase):
-       exported_program = pass_(exported_program).exported_program
-     else:
-       gm = exported_program.graph_module
-       gm, modified = pass_(gm)
-       if modified and gm is not exported_program.graph_module:
-         exported_program = ExportedProgram(
-             root=gm,
-             graph=gm.graph,
-             graph_signature=exported_program.graph_signature,
-             state_dict=exported_program.state_dict,
-             range_constraints=exported_program.range_constraints,
-             module_call_graph=exported_program.module_call_graph,
-             example_inputs=exported_program.example_inputs,
-             verifier=exported_program.verifier,
-             constants=exported_program.constants,
-         )
-   return exported_program
+ from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
+ from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
+ from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
+ from ai_edge_torch.fx_pass_base import CanonicalizePass
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py CHANGED
@@ -13,11 +13,10 @@
  # limitations under the License.
  # ==============================================================================

- from functools import reduce
  from typing import Any, Callable
+ from ai_edge_torch import fx_pass_base
  from ai_edge_torch import lowertools
  import torch
- from torch.fx.passes.infra import pass_base
  import torch.utils._pytree as pytree

  _composite_builders: dict[
@@ -277,7 +276,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
    node.target = embedding


- class BuildAtenCompositePass(pass_base.PassBase):
+ class BuildAtenCompositePass(fx_pass_base.PassBase):

    def call(self, graph_module: torch.fx.GraphModule):
      for node in graph_module.graph.nodes:
@@ -286,4 +285,4 @@ class BuildAtenCompositePass(pass_base.PassBase):

      graph_module.graph.lint()
      graph_module.recompile()
-     return pass_base.PassResult(graph_module, True)
+     return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED
@@ -16,8 +16,7 @@

  import functools

- from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
- from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+ from ai_edge_torch import fx_pass_base
  from ai_edge_torch.hlfb import mark_pattern
  from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
  import torch
@@ -103,7 +102,7 @@ def _get_interpolate_nearest2d_pattern():
    return pattern


- class BuildInterpolateCompositePass(ExportedProgramPassBase):
+ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):

    def __init__(self):
      super().__init__()
@@ -124,4 +123,4 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):

      graph_module.graph.lint()
      graph_module.recompile()
-     return ExportedProgramPassResult(exported_program, True)
+     return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py CHANGED
@@ -13,10 +13,9 @@
  # limitations under the License.
  # ==============================================================================

+ from ai_edge_torch import fx_pass_base
  from ai_edge_torch import lowertools
  import torch
- from torch.fx.passes.infra.pass_base import PassBase
- from torch.fx.passes.infra.pass_base import PassResult
  import torch.utils._pytree as pytree


@@ -62,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
    node.target = debuginfo_writer


- class InjectMlirDebuginfoPass(PassBase):
+ class InjectMlirDebuginfoPass(fx_pass_base.PassBase):

    def call(self, graph_module: torch.fx.GraphModule):
      for node in graph_module.graph.nodes:
@@ -70,4 +69,4 @@ class InjectMlirDebuginfoPass(PassBase):

      graph_module.graph.lint()
      graph_module.recompile()
-     return PassResult(graph_module, True)
+     return fx_pass_base.PassResult(graph_module, True)
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py CHANGED
@@ -18,8 +18,7 @@ import operator
  import os
  from typing import Union

- from ai_edge_torch._convert.fx_passes import ExportedProgramPassBase
- from ai_edge_torch._convert.fx_passes import ExportedProgramPassResult
+ from ai_edge_torch import fx_pass_base
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
@@ -31,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
  TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]


- class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
+ class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):

    def get_source_meta(self, node: torch.fx.Node):
      keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
@@ -94,7 +93,7 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):

      q_args = input_q.args[1:]
      q_kwargs = input_q.kwargs
-     q_op, dq_op = self.get_paired_q_dq_ops(input_q.target)
+     q_op, dq_op = utils.get_paired_q_dq_ops(input_q.target)
      with graph.inserting_before(target):
        # Q and DQ inserted here may required updating the `axis` arg when they
        # are per_channel ops. However, instead of updating here, the nodes would
@@ -301,4 +300,4 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
      # Mark const node again for debugging
      self.mark_const_nodes(exported_program)

-     return ExportedProgramPassResult(exported_program, True)
+     return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/config.py CHANGED
@@ -21,4 +21,7 @@ import os

  @dataclasses.dataclass
  class Config:
-   use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "True") == "True"
+   use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "true").lower() in (
+       "1",
+       "true",
+   )
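Note: the config.py change above makes USE_TORCH_XLA parsing case-insensitive and additionally accepts "1"; previously only the exact string "True" enabled the torch_xla path. A small standalone sketch of the resulting behavior (the helper name is illustrative, not part of the package):

import os

def _use_torch_xla() -> bool:
  # Mirrors the new default expression in ai_edge_torch/config.py.
  return os.environ.get("USE_TORCH_XLA", "true").lower() in ("1", "true")

os.environ["USE_TORCH_XLA"] = "TRUE"
assert _use_torch_xla()      # now enabled; the old "TRUE" != "True" check failed
os.environ["USE_TORCH_XLA"] = "0"
assert not _use_torch_xla()  # anything other than "1"/"true" disables it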
ai_edge_torch/fx_pass_base.py ADDED
@@ -0,0 +1,101 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import abc
+ import collections
+ from typing import Sequence, Union
+
+ import torch
+ from torch.fx.passes.infra.pass_base import PassBase
+ from torch.fx.passes.infra.pass_base import PassResult
+ from torch.fx.passes.infra.pass_manager import pass_result_wrapper
+ import torch.utils._pytree as pytree
+
+ FxPassBase = PassBase
+ FxPassResult = PassResult
+ ExportedProgramPassResult = collections.namedtuple(
+     "ExportedProgramPassResult", ["exported_program", "modified"]
+ )
+
+
+ class ExportedProgramPassBase(abc.ABC):
+
+   def __call__(
+       self, exported_program: torch.export.ExportedProgram
+   ) -> ExportedProgramPassResult:
+     self.requires(exported_program)
+     res = self.call(exported_program)
+     self.ensures(exported_program)
+     return res
+
+   @abc.abstractmethod
+   def call(
+       self, exported_program: torch.export.ExportedProgram
+   ) -> ExportedProgramPassResult:
+     pass
+
+   def requires(self, exported_program: torch.export.ExportedProgram) -> None:
+     pass
+
+   def ensures(self, exported_program: torch.export.ExportedProgram) -> None:
+     pass
+
+
+ # TODO(cnchan): make a PassManager class.
+ def run_passes(
+     exported_program: torch.export.ExportedProgram,
+     passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
+ ) -> torch.export.ExportedProgram:
+   passes, _ = pytree.tree_flatten(passes)
+   for pass_ in passes:
+     if not isinstance(pass_, ExportedProgramPassBase):
+       pass_ = pass_result_wrapper(pass_)
+     if isinstance(pass_, ExportedProgramPassBase):
+       exported_program = pass_(exported_program).exported_program
+     else:
+       gm = exported_program.graph_module
+       gm, modified = pass_(gm)
+       if modified and gm is not exported_program.graph_module:
+         exported_program = torch.export.ExportedProgram(
+             root=gm,
+             graph=gm.graph,
+             graph_signature=exported_program.graph_signature,
+             state_dict=exported_program.state_dict,
+             range_constraints=exported_program.range_constraints,
+             module_call_graph=exported_program.module_call_graph,
+             example_inputs=exported_program.example_inputs,
+             verifier=exported_program.verifier,
+             constants=exported_program.constants,
+         )
+   return exported_program
+
+
+ class CanonicalizePass(ExportedProgramPassBase):
+
+   # A dummy decomp table for running ExportedProgram.run_decompositions without
+   # any op decompositions but just aot_export_module. Due to the check in
+   # run_decompositions, if None or an empty dict is passed as decomp_table,
+   # it will run the default aten-coreaten decompositions. Therefore a non-empty
+   # dummy decomp table is needed.
+   # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
+   _DUMMY_DECOMP_TABLE = {
+       torch._ops.OperatorBase(): lambda: None,
+   }
+
+   def call(self, exported_program: torch.export.ExportedProgram):
+     exported_program = exported_program.run_decompositions(
+         self._DUMMY_DECOMP_TABLE
+     )
+     return ExportedProgramPassResult(exported_program, True)
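The new ai_edge_torch/fx_pass_base.py consolidates the pass infrastructure that previously lived under ai_edge_torch/_convert/fx_passes (_pass_base.py and canonicalize_pass.py, both removed in this release). A minimal sketch of defining a graph-module pass and running it together with CanonicalizePass through run_passes; _NoOpPass and _Toy are illustrative names, not part of the package:

import torch
from ai_edge_torch import fx_pass_base

class _NoOpPass(fx_pass_base.FxPassBase):
  """Illustrative pass: a real pass would rewrite graph_module.graph here."""

  def call(self, graph_module: torch.fx.GraphModule):
    return fx_pass_base.FxPassResult(graph_module, False)

class _Toy(torch.nn.Module):
  def forward(self, x):
    return torch.nn.functional.relu(x)

ep = torch.export.export(_Toy(), (torch.randn(1, 4),))
# run_passes accepts a mix of ExportedProgramPassBase and torch.fx PassBase passes.
ep = fx_pass_base.run_passes(ep, [fx_pass_base.CanonicalizePass(), _NoOpPass()])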
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -13,55 +13,74 @@
  # limitations under the License.
  # ==============================================================================

+ """Example of converting a Gemma2 model to multi-signature tflite model."""
+
  import os
- from pathlib import Path
+ import pathlib

  import ai_edge_torch
  from ai_edge_torch.generative.examples.gemma import gemma2
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch


- def convert_gemma_to_tflite(
+ def convert_gemma2_to_tflite(
      checkpoint_path: str,
      prefill_seq_len: int = 512,
      kv_cache_max_len: int = 1024,
      quantize: bool = True,
  ):
-   """Converting a Gemma 2 2B model to multi-signature
-   tflite model.
+   """Converts a Gemma2 2B model to multi-signature tflite model.

    Args:
-     checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+     checkpoint_path (str): The filepath to the model checkpoint, or directory
+       holding the checkpoint.
      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
        Defaults to 512.
      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
        including both prefill and decode. Defaults to 1024.
-     quantize (bool, optional): Whether the model should be quanized.
-       Defaults to True.
+     quantize (bool, optional): Whether the model should be quanized. Defaults
+       to True.
    """
    pytorch_model = gemma2.build_2b_model(
        checkpoint_path, kv_cache_max_len=kv_cache_max_len
    )
    # Tensors used to trace the model graph during conversion.
-   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-   prefill_input_pos = torch.arange(0, prefill_seq_len)
-   decode_token = torch.tensor([[0]], dtype=torch.long)
-   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+   prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+   decode_token = torch.tensor([[0]], dtype=torch.int)
+   decode_input_pos = torch.tensor([0], dtype=torch.int)
+   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
        ai_edge_torch.signature(
-           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+           'prefill',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': prefill_tokens,
+               'input_pos': prefill_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
        )
-       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
        .convert(quant_config=quant_config)
    )
+   quant_suffix = 'q8' if quantize else 'f32'
    edge_model.export(
-       f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+       f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )


  if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
-   convert_gemma_to_tflite(checkpoint_path)
+   path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
+   convert_gemma2_to_tflite(path)
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py CHANGED
@@ -13,11 +13,14 @@
  # limitations under the License.
  # ==============================================================================

+ """Example of converting a Gemma model to multi-signature tflite model."""
+
  import os
- from pathlib import Path
+ import pathlib

  import ai_edge_torch
  from ai_edge_torch.generative.examples.gemma import gemma
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch

@@ -44,24 +47,40 @@ def convert_gemma_to_tflite(
        checkpoint_path, kv_cache_max_len=kv_cache_max_len
    )
    # Tensors used to trace the model graph during conversion.
-   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-   prefill_input_pos = torch.arange(0, prefill_seq_len)
-   decode_token = torch.tensor([[0]], dtype=torch.long)
-   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+   prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+   decode_token = torch.tensor([[0]], dtype=torch.int)
+   decode_input_pos = torch.tensor([0], dtype=torch.int)
+   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
        ai_edge_torch.signature(
-           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+           'prefill',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': prefill_tokens,
+               'input_pos': prefill_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
        )
-       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
        .convert(quant_config=quant_config)
    )
+   quant_suffix = 'q8' if quantize else 'f32'
    edge_model.export(
-       f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+       f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )


  if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-   convert_gemma_to_tflite(checkpoint_path)
+   path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
+   convert_gemma_to_tflite(path)
ai_edge_torch/generative/examples/gemma/gemma.py CHANGED
@@ -12,13 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Example of building a Gemma model.
+
+ """Example of building a Gemma model."""

  import os
- from pathlib import Path
+ import pathlib

  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -48,7 +50,6 @@ class Gemma(nn.Module):
    def __init__(self, config: cfg.ModelConfig):
      super().__init__()

-     self.config = config
      # Construct model layers.
      self.tok_embedding = nn.Embedding(
          config.vocab_size, config.embedding_dim, padding_idx=0
@@ -60,18 +61,20 @@ class Gemma(nn.Module):
      )
      # Gemma re-uses the embedding as the head projection layer.
      self.lm_head.weight.data = self.tok_embedding.weight.data
+     # Gemma has only one block config.
+     block_config = config.block_config(0)
      self.transformer_blocks = nn.ModuleList(
-         attention.TransformerBlock(config) for _ in range(config.num_layers)
+         attention.TransformerBlock(block_config, config)
+         for _ in range(config.num_layers)
      )
      self.final_norm = builder.build_norm(
          config.embedding_dim,
          config.final_norm_config,
      )
+     attn_config = block_config.attn_config
      self.rope_cache = attn_utils.build_rope_cache(
          size=config.kv_cache_max,
-         dim=int(
-             config.attn_config.rotary_percentage * config.attn_config.head_dim
-         ),
+         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
          base=10_000,
          condense_ratio=1,
          dtype=torch.float32,
@@ -84,16 +87,22 @@ class Gemma(nn.Module):
      )
      self.config = config

-   # The model's forward function takes in additional k/v cache tensors
-   # and returns the updated k/v cache tensors to the caller.
-   # This can be eliminated if we handle k/v cache updates inside the model itself.
    @torch.inference_mode
-   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-     _, seq_len = idx.size()
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[torch.Tensor, kv_utils.KVCache]:
+     _, seq_len = tokens.size()
      assert self.config.max_seq_len >= seq_len, (
          f"Cannot forward sequence of length {seq_len}, max seq length is only"
          f" {self.config.max_seq_len}"
      )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )

      cos, sin = self.rope_cache
      cos = cos.index_select(0, input_pos)
@@ -102,15 +111,20 @@ class Gemma(nn.Module):
      mask = mask[:, :, :, : self.config.kv_cache_max]

      # token embeddings of shape (b, t, n_embd)
-     x = self.tok_embedding(idx)
+     x = self.tok_embedding(tokens)
      x = x * (self.config.embedding_dim**0.5)

-     for _, block in enumerate(self.transformer_blocks):
-       x = block(x, (cos, sin), mask, input_pos)
+     updated_kv_entires = []
+     for i, block in enumerate(self.transformer_blocks):
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entires.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

      x = self.final_norm(x)
-     res = self.lm_head(x)  # (b, t, vocab_size)
-     return res
+     logits = self.lm_head(x)  # (b, t, vocab_size)
+     return {"logits": logits, "kv_cache": updated_kv_cache}


  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -139,18 +153,20 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
      epsilon=1e-6,
      zero_centered=True,
  )
+ block_config = cfg.TransformerBlockConfig(
+     attn_config=attn_config,
+     ff_config=ff_config,
+     pre_attention_norm_config=norm_config,
+     post_attention_norm_config=norm_config,
+ )
  config = cfg.ModelConfig(
      vocab_size=256000,
      num_layers=18,
      max_seq_len=8192,
      embedding_dim=2048,
      kv_cache_max_len=kv_cache_max_len,
-     attn_config=attn_config,
-     ff_config=ff_config,
-     pre_attention_norm_config=norm_config,
-     post_attention_norm_config=norm_config,
+     block_configs=block_config,
      final_norm_config=norm_config,
-     parallel_residual=False,
      lm_head_use_bias=False,
      enable_hlfb=True,
  )
@@ -159,7 +175,8 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:

  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
    config = get_model_config_2b(kv_cache_max_len)
-   config.ff_config.intermediate_size = 128
+   # Gemma has only one block config.
+   config.block_config(0).ff_config.intermediate_size = 128
    config.vocab_size = 128
    config.num_layers = 2
    config.max_seq_len = 2 * kv_cache_max_len
@@ -170,32 +187,35 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
    config = get_model_config_2b(**kwargs)
    model = Gemma(config)
    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-   # since embedding and lm-head use the same weight, we need to set strict
+   # Since embedding and lm-head use the same weight, we need to set strict
    # to False.
    loader.load(model, strict=False)
    model.eval()
    return model


- def define_and_run_2b() -> None:
+ def define_and_run_2b(checkpoint_path: str) -> None:
    """Instantiates and runs a Gemma 2B model."""

-   current_dir = Path(__file__).parent.resolve()
+   current_dir = pathlib.Path(__file__).parent.resolve()
    gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")

    kv_cache_max_len = 1024
-   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
    model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
    tokens[0, :4] = idx
-   input_pos = torch.arange(0, kv_cache_max_len)
-   lm_logits = model.forward(tokens, input_pos)
+   input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+   kv = kv_utils.KVCache.from_model_config(model.config)
+   output = model.forward(tokens, input_pos, kv)
    print("comparing with goldens..")
    assert torch.allclose(
-       gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+       gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
    )


  if __name__ == "__main__":
-   define_and_run_2b()
+   input_checkpoint_path = os.path.join(
+       pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
+   )
+   define_and_run_2b(input_checkpoint_path)
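As the gemma.py diff above shows, the model's forward pass now takes keyword-style inputs plus an external KVCache and returns a dict with the logits and the updated cache, instead of handling k/v state internally. A minimal usage sketch under that new contract (the checkpoint path is a placeholder, not taken from the diff):

import torch
from ai_edge_torch.generative.examples.gemma import gemma
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Placeholder path; point it at a real Gemma 2B checkpoint directory.
model = gemma.build_2b_model("/path/to/gemma-2b", kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(model.config)

tokens = torch.zeros((1, 8), dtype=torch.int)
input_pos = torch.arange(0, 8, dtype=torch.int)
output = model.forward(tokens, input_pos, kv)
# The cache is returned rather than mutated in place.
logits, kv = output["logits"], output["kv_cache"]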