ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/phi/phi2.py +2 -2
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
- ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +8 -8
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +7 -0
- ai_edge_torch/generative/layers/builder.py +33 -11
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +4 -4
- ai_edge_torch/generative/layers/model_config.py +24 -15
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion.py +28 -51
- ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/utilities/loader.py +13 -0
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +48 -46
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py CHANGED

```diff
@@ -17,6 +17,7 @@ import logging
 import os
 from typing import Any, Optional
 
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 from ai_edge_torch import model
 from ai_edge_torch._convert import fx_passes
@@ -34,7 +35,7 @@ def _run_convert_passes(
   exported_program = generative_fx_passes.run_generative_passes(
       exported_program
   )
-  return fx_passes.run_passes(
+  return fx_pass_base.run_passes(
       exported_program,
       [
           fx_passes.BuildInterpolateCompositePass(),
```
ai_edge_torch/_convert/fx_passes/__init__.py CHANGED

```diff
@@ -15,44 +15,8 @@
 
 from typing import Sequence, Union
 
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
-from ai_edge_torch._convert.fx_passes._pass_base import FxPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import FxPassResult
-from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass  # NOQA
-from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass  # NOQA
-from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass
-from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass  # NOQA
-from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass  # NOQA
-from torch.export import ExportedProgram
-from torch.fx.passes.infra.pass_manager import pass_result_wrapper
-import torch.utils._pytree as pytree
-
-
-# TODO(cnchan): make a PassManager class.
-def run_passes(
-    exported_program: ExportedProgram,
-    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
-) -> ExportedProgram:
-  passes, _ = pytree.tree_flatten(passes)
-  for pass_ in passes:
-    if not isinstance(pass_, ExportedProgramPassBase):
-      pass_ = pass_result_wrapper(pass_)
-    if isinstance(pass_, ExportedProgramPassBase):
-      exported_program = pass_(exported_program).exported_program
-    else:
-      gm = exported_program.graph_module
-      gm, modified = pass_(gm)
-      if modified and gm is not exported_program.graph_module:
-        exported_program = ExportedProgram(
-            root=gm,
-            graph=gm.graph,
-            graph_signature=exported_program.graph_signature,
-            state_dict=exported_program.state_dict,
-            range_constraints=exported_program.range_constraints,
-            module_call_graph=exported_program.module_call_graph,
-            example_inputs=exported_program.example_inputs,
-            verifier=exported_program.verifier,
-            constants=exported_program.constants,
-        )
-  return exported_program
+from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
+from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
+from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
+from ai_edge_torch.fx_pass_base import CanonicalizePass
```
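With `run_passes`, `CanonicalizePass`, and the `_pass_base` aliases removed, `fx_passes` now only re-exports the concrete passes. A short sketch of the resulting import surface, using only the names shown in the diffs above:

```python
# Pass infrastructure now comes from its new home...
from ai_edge_torch.fx_pass_base import CanonicalizePass
from ai_edge_torch.fx_pass_base import run_passes

# ...while the concrete passes keep their existing import path.
from ai_edge_torch._convert.fx_passes import BuildAtenCompositePass
from ai_edge_torch._convert.fx_passes import OptimizeLayoutTransposesPass
```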
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py CHANGED

```diff
@@ -13,11 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 
-from functools import reduce
 from typing import Any, Callable
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 import torch
-from torch.fx.passes.infra import pass_base
 import torch.utils._pytree as pytree
 
 _composite_builders: dict[
@@ -277,7 +276,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
   node.target = embedding
 
 
-class BuildAtenCompositePass(pass_base.PassBase):
+class BuildAtenCompositePass(fx_pass_base.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -286,4 +285,4 @@ class BuildAtenCompositePass(pass_base.PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return pass_base.PassResult(graph_module, True)
+    return fx_pass_base.PassResult(graph_module, True)
```
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED

```diff
@@ -16,8 +16,7 @@
 
 import functools
 
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch.hlfb import mark_pattern
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
@@ -103,7 +102,7 @@ def _get_interpolate_nearest2d_pattern():
   return pattern
 
 
-class BuildInterpolateCompositePass(ExportedProgramPassBase):
+class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
 
   def __init__(self):
     super().__init__()
@@ -124,4 +123,4 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
```
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py CHANGED

```diff
@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 import torch
-from torch.fx.passes.infra.pass_base import PassBase
-from torch.fx.passes.infra.pass_base import PassResult
 import torch.utils._pytree as pytree
 
 
@@ -62,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
   node.target = debuginfo_writer
 
 
-class InjectMlirDebuginfoPass(PassBase):
+class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
@@ -70,4 +69,4 @@ class InjectMlirDebuginfoPass(PassBase):
 
     graph_module.graph.lint()
     graph_module.recompile()
-    return PassResult(graph_module, True)
+    return fx_pass_base.PassResult(graph_module, True)
```
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py CHANGED

```diff
@@ -18,8 +18,7 @@ import operator
 import os
 from typing import Union
 
-from ai_edge_torch._convert.fx_passes import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes import ExportedProgramPassResult
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
@@ -31,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
 TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
 
 
-class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
+class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
 
   def get_source_meta(self, node: torch.fx.Node):
     keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
@@ -94,7 +93,7 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
 
     q_args = input_q.args[1:]
     q_kwargs = input_q.kwargs
-    q_op, dq_op = …
+    q_op, dq_op = utils.get_paired_q_dq_ops(input_q.target)
     with graph.inserting_before(target):
       # Q and DQ inserted here may required updating the `axis` arg when they
       # are per_channel ops. However, instead of updating here, the nodes would
@@ -301,4 +300,4 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
     # Mark const node again for debugging
     self.mark_const_nodes(exported_program)
 
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
```
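Each of the passes above now subclasses the shared bases from `ai_edge_torch.fx_pass_base` (defined below) instead of `torch.fx.passes.infra` or the deleted `_pass_base` module. A custom ExportedProgram-level pass follows the same contract; a minimal sketch (the `NoopPass` name is hypothetical, not part of the package):

```python
import torch
from ai_edge_torch import fx_pass_base


class NoopPass(fx_pass_base.ExportedProgramPassBase):
  """A do-nothing pass showing the ExportedProgramPassBase contract."""

  def call(self, exported_program: torch.export.ExportedProgram):
    # Rewrite exported_program.graph_module here, then report whether
    # anything changed via the `modified` flag.
    return fx_pass_base.ExportedProgramPassResult(exported_program, False)
```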
    
ai_edge_torch/fx_pass_base.py ADDED
```diff
@@ -0,0 +1,101 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import abc
+import collections
+from typing import Sequence, Union
+
+import torch
+from torch.fx.passes.infra.pass_base import PassBase
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.fx.passes.infra.pass_manager import pass_result_wrapper
+import torch.utils._pytree as pytree
+
+FxPassBase = PassBase
+FxPassResult = PassResult
+ExportedProgramPassResult = collections.namedtuple(
+    "ExportedProgramPassResult", ["exported_program", "modified"]
+)
+
+
+class ExportedProgramPassBase(abc.ABC):
+
+  def __call__(
+      self, exported_program: torch.export.ExportedProgram
+  ) -> ExportedProgramPassResult:
+    self.requires(exported_program)
+    res = self.call(exported_program)
+    self.ensures(exported_program)
+    return res
+
+  @abc.abstractmethod
+  def call(
+      self, exported_program: torch.export.ExportedProgram
+  ) -> ExportedProgramPassResult:
+    pass
+
+  def requires(self, exported_program: torch.export.ExportedProgram) -> None:
+    pass
+
+  def ensures(self, exported_program: torch.export.ExportedProgram) -> None:
+    pass
+
+
+# TODO(cnchan): make a PassManager class.
+def run_passes(
+    exported_program: torch.export.ExportedProgram,
+    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
+) -> torch.export.ExportedProgram:
+  passes, _ = pytree.tree_flatten(passes)
+  for pass_ in passes:
+    if not isinstance(pass_, ExportedProgramPassBase):
+      pass_ = pass_result_wrapper(pass_)
+    if isinstance(pass_, ExportedProgramPassBase):
+      exported_program = pass_(exported_program).exported_program
+    else:
+      gm = exported_program.graph_module
+      gm, modified = pass_(gm)
+      if modified and gm is not exported_program.graph_module:
+        exported_program = torch.export.ExportedProgram(
+            root=gm,
+            graph=gm.graph,
+            graph_signature=exported_program.graph_signature,
+            state_dict=exported_program.state_dict,
+            range_constraints=exported_program.range_constraints,
+            module_call_graph=exported_program.module_call_graph,
+            example_inputs=exported_program.example_inputs,
+            verifier=exported_program.verifier,
+            constants=exported_program.constants,
+        )
+  return exported_program
+
+
+class CanonicalizePass(ExportedProgramPassBase):
+
+  # A dummy decomp table for running ExportedProgram.run_decompositions without
+  # any op decompositions but just aot_export_module. Due to the check in
+  # run_decompositions, if None or an empty dict is passed as decomp_table,
+  # it will run the default aten-coreaten decompositions. Therefore a non-empty
+  # dummy decomp table is needed.
+  # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
+  _DUMMY_DECOMP_TABLE = {
+      torch._ops.OperatorBase(): lambda: None,
+  }
+
+  def call(self, exported_program: torch.export.ExportedProgram):
+    exported_program = exported_program.run_decompositions(
+        self._DUMMY_DECOMP_TABLE
+    )
+    return ExportedProgramPassResult(exported_program, True)
```
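A minimal sketch of the new entry point (the `AddOne` module is illustrative; assumes a build of `ai_edge_torch` that contains this file):

```python
import torch
from ai_edge_torch import fx_pass_base


class AddOne(torch.nn.Module):

  def forward(self, x):
    return x + 1


ep = torch.export.export(AddOne(), (torch.randn(2),))
# run_passes accepts a mix of ExportedProgram-level passes and plain FX
# graph-module passes; FX passes are wrapped and, if they modify the graph,
# re-attached to a fresh ExportedProgram.
ep = fx_pass_base.run_passes(ep, [fx_pass_base.CanonicalizePass()])
```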
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED

```diff
@@ -47,10 +47,10 @@ def convert_gemma2_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.long)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
```
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py CHANGED

```diff
@@ -47,10 +47,10 @@ def convert_gemma_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.long)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
```
ai_edge_torch/generative/examples/gemma/gemma.py CHANGED

```diff
@@ -203,9 +203,9 @@ def define_and_run_2b(checkpoint_path: str) -> None:
   kv_cache_max_len = 1024
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
```
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED

```diff
@@ -280,9 +280,9 @@ def define_and_run_2b(checkpoint_path: str) -> None:
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
   )
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :9] = toks
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   out = model.forward(tokens, input_pos, kv)
   out_final = out["logits"][0, 8, :]
```
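The recurring `dtype` change in these files switches the tracing tensors to `torch.int` (int32) and pins `torch.arange` to the same dtype instead of its int64 default, so all converted signatures take int32 token and position inputs. A quick standalone check of the resulting dtypes:

```python
import torch

prefill_seq_len = 512
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)

# torch.int is an alias for int32; arange over Python ints defaults to int64.
assert prefill_tokens.dtype == torch.int32
assert prefill_input_pos.dtype == torch.int32
assert torch.arange(0, 4).dtype == torch.int64
```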
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py ADDED

```diff
@@ -0,0 +1,86 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting OpenELM model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+import torch
+
+
+def convert_openelm_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """Converts OpenELM model to multi-signature tflite model.
+
+  Args:
+      checkpoint_path (str): The filepath to the model checkpoint, or directory
+        holding the checkpoint.
+      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+        Defaults to 512.
+      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+        including both prefill and decode. Defaults to 1024.
+      quantize (bool, optional): Whether the model should be quanized. Defaults
+        to True.
+  """
+  pytorch_model = openelm.build_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  quant_suffix = 'q8' if quantize else 'f32'
+  edge_model.export(
+      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  )
+
+
+if __name__ == '__main__':
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
+  convert_openelm_to_tflite(path)
```
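Based on the defaults above, a conversion run might look like the following sketch (the checkpoint path is illustrative; a real OpenELM checkpoint must exist at the given location):

```python
from ai_edge_torch.generative.examples.openelm import convert_to_tflite

# Default arguments: writes /tmp/openelm_q8_seq512_ekv1024.tflite.
convert_to_tflite.convert_openelm_to_tflite('/path/to/openelm')

# Float32 variant with a shorter prefill window:
# writes /tmp/openelm_f32_seq256_ekv1024.tflite.
convert_to_tflite.convert_openelm_to_tflite(
    '/path/to/openelm', prefill_seq_len=256, quantize=False
)
```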
ai_edge_torch/generative/examples/openelm/openelm.py ADDED

```diff
@@ -0,0 +1,237 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building an OpenELM model."""
+
+import os
+import pathlib
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="transformer.layers.{}.ffn.proj_1",
+    ff_down_proj="transformer.layers.{}.ffn.proj_2",
+    attn_fused_qkv_proj="transformer.layers.{}.attn.qkv_proj",
+    attn_query_norm="transformer.layers.{}.attn.q_norm",
+    attn_key_norm="transformer.layers.{}.attn.k_norm",
+    attn_output_proj="transformer.layers.{}.attn.out_proj",
+    pre_attn_norm="transformer.layers.{}.attn_norm",
+    pre_ff_norm="transformer.layers.{}.ffn_norm",
+    embedding="transformer.token_embeddings",
+    final_norm="transformer.norm",
+    lm_head=None,
+)
+
+
+class OpenELM(nn.Module):
+  """An OpenELM model built from the Edge Generative API layers."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    # OpenELM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    # OpenELM has same hyper parameters for rotary_percentage and head_dim for
+    # each layer block. Use the first block.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for an OpenELM model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for an OpenELM model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
+  )
+  num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
+  num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
+
+  def make_divisible(v, d):
+    """Ensures that all layers have a channel number that is divisible by d."""
+    new_v = int(v + d / 2) // d * d
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+      new_v += d
+    return new_v
+
+  # The way to get intermediate size is from
+  # https://huggingface.co/apple/OpenELM-3B/blob/main/modeling_openelm.py
+  def get_intermediate_size(idx: int) -> int:
+    return make_divisible((0.5 + 0.1 * idx) * 3072, 256)
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    return cfg.TransformerBlockConfig(
+        attn_config=cfg.AttentionConfig(
+            num_heads=num_heads[idx],
+            head_dim=128,
+            num_query_groups=num_query_groups[idx],
+            rotary_percentage=1.0,
+            qkv_transpose_before_split=True,
+            query_norm_config=norm_config,
+            key_norm_config=norm_config,
+        ),
+        ff_config=cfg.FeedForwardConfig(
+            type=cfg.FeedForwardType.SEQUENTIAL,
+            activation=cfg.ActivationConfig(
+                cfg.ActivationType.SILU_GLU, gate_is_front=True
+            ),
+            intermediate_size=get_intermediate_size(idx),
+            pre_ff_norm_config=norm_config,
+        ),
+        pre_attention_norm_config=norm_config,
+    )
+
+  num_layers = 36
+  config = cfg.ModelConfig(
+      vocab_size=32000,
+      num_layers=num_layers,
+      max_seq_len=2048,
+      embedding_dim=3072,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
+      final_norm_config=norm_config,
+  )
+  return config
+
+
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.block_configs = config.block_configs[: config.num_layers]
+  for block_config in config.block_configs:
+    block_config.attn_config.num_heads = 3
+    block_config.attn_config.head_dim = 64
+    block_config.ff_config.intermediate_size = 128
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = OpenELM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs an OpenELM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/openelm"
+  )
+  define_and_run(input_checkpoint_path)
```
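OpenELM scales the FFN width per layer instead of using one global intermediate size: `get_intermediate_size` grows linearly with depth, and `make_divisible` snaps the result to a multiple of 256 without rounding down by more than 10%. Worked values for a few layers, using a standalone copy of the helpers above:

```python
def make_divisible(v, d):
  # Round to the nearest multiple of d, then bump up if the round-down
  # lost more than 10% of the original value.
  new_v = int(v + d / 2) // d * d
  if new_v < 0.9 * v:
    new_v += d
  return new_v


def get_intermediate_size(idx: int) -> int:
  return make_divisible((0.5 + 0.1 * idx) * 3072, 256)


# Layer 0: 0.5 * 3072 = 1536.0, already a multiple of 256 -> 1536.
assert get_intermediate_size(0) == 1536
# Layer 1: 0.6 * 3072 = 1843.2, rounds to 1792 (within 10%) -> 1792.
assert get_intermediate_size(1) == 1792
# Layer 35: 4.0 * 3072 = 12288.0 -> 12288.
assert get_intermediate_size(35) == 12288
```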