ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/phi/phi2.py +2 -2
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
- ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +8 -8
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +7 -0
- ai_edge_torch/generative/layers/builder.py +33 -11
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +4 -4
- ai_edge_torch/generative/layers/model_config.py +24 -15
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion.py +28 -51
- ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/utilities/loader.py +13 -0
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +48 -46
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/phi/convert_to_tflite.py

@@ -47,10 +47,10 @@ def convert_phi2_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.long)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_cache.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
ai_edge_torch/generative/examples/phi/phi2.py

@@ -192,9 +192,9 @@ def define_and_run(checkpoint_path: str) -> None:
   kv_cache_max_len = 1024
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
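Note: this int64 → int32 switch repeats in every example converter below. Token and position tensors are now created as torch.int (int32) rather than PyTorch's default int64, presumably to match the 32-bit integer signatures the TFLite runtime favors. A minimal sketch of the convention (illustrative, not part of the diff):

```python
import torch

prefill_seq_len = 512
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)

assert prefill_tokens.dtype == torch.int32  # torch.int is an alias for int32
assert prefill_input_pos.dtype == torch.int32
# Without the explicit dtype, integer arange/full default to int64:
assert torch.arange(0, prefill_seq_len).dtype == torch.int64
```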
ai_edge_torch/generative/examples/smollm/__init__.py (new file)

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py

@@ -13,25 +13,25 @@
 # limitations under the License.
 # ==============================================================================

-"""Example of converting SmallM model to multi-signature tflite model."""
+"""Example of converting SmolLM model to multi-signature tflite model."""

 import os
 import pathlib

 import ai_edge_torch
-from ai_edge_torch.generative.examples.smallm import smallm
+from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch


-def convert_smallm_to_tflite(
+def convert_smollm_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """Converts SmallM model to multi-signature tflite model.
+  """Converts SmolLM model to multi-signature tflite model.

   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
@@ -43,14 +43,14 @@ def convert_smallm_to_tflite(
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
   """
-  pytorch_model = smallm.build_model(
+  pytorch_model = smollm.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.long)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
@@ -77,10 +77,10 @@ def convert_smallm_to_tflite(
   )
   quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+      f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
-  convert_smallm_to_tflite(path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
+  convert_smollm_to_tflite(path)
ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py}

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================

-"""Example of building a SmallM model."""
+"""Example of building a SmolLM model."""

 import copy
 import os
@@ -28,32 +28,32 @@ import torch
 from torch import nn

 TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# SmallM re-uses the embedding as the head projection layer.
+# SmolLM re-uses the embedding as the head projection layer.
 TENSOR_NAMES.lm_head = None


-class SmallM(tiny_llama.TinyLlama):
-  """A SmallM model built from the Edge Generative API layers.
+class SmolLM(tiny_llama.TinyLlama):
+  """A SmolLM model built from the Edge Generative API layers.

-  SmallM shares the same architecture as TinyLlama, but with different model
+  SmolLM shares the same architecture as TinyLlama, but with different model
   sizes.
   """

   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
-    # SmallM re-uses the embedding as the head projection layer.
+    # SmolLM re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data


 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a SmallM 135M model.
+  """Returns the model config for a SmolLM 135M model.

   Args:
     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
       is 1024.

   Returns:
-    The model config for a SmallM model.
+    The model config for a SmolLM model.
   """
   attn_config = cfg.AttentionConfig(
       num_heads=9,
@@ -86,9 +86,18 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
 def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config(**kwargs)
-  model = SmallM(config)
+  model = SmolLM(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
@@ -98,25 +107,25 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:


 def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a SmallM model."""
+  """Instantiates and runs a SmolLM model."""

   current_dir = pathlib.Path(__file__).parent.resolve()
-  smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
+  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
   kv_cache_max_len = 1024
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   output = model.forward(tokens, input_pos, kv)
   assert torch.allclose(
-      smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
   )


 if __name__ == "__main__":
   input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/smallm"
+      pathlib.Path.home(), "Downloads/llm_data/smollm"
   )
   define_and_run(input_checkpoint_path)
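Note: SmolLM stores no separate lm_head tensor (TENSOR_NAMES.lm_head = None) and instead points the head at the embedding table, which is why build_model loads the checkpoint with strict set to False. A self-contained sketch of the weight-tying pattern (illustrative sizes):

```python
import torch
from torch import nn

vocab_size, dim = 128, 16
tok_embedding = nn.Embedding(vocab_size, dim)
lm_head = nn.Linear(dim, vocab_size, bias=False)

# Share storage rather than copy: both modules now read the same weights.
lm_head.weight.data = tok_embedding.weight.data
assert lm_head.weight.data_ptr() == tok_embedding.weight.data_ptr()
```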
ai_edge_torch/generative/examples/stable_diffusion/clip.py

@@ -76,7 +76,7 @@ class CLIP(nn.Module):

   @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-    tokens = tokens.type(torch.long)
+    tokens = tokens.type(torch.int)

     state = self.tok_embedding(tokens) + self.tok_embedding_position
     for layer in self.transformer_blocks:
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -94,7 +94,7 @@ def convert_stable_diffusion_to_tflite(
   n_tokens = 77
   timestamp = 0
   len_prompt = 1
-  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
+  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
   input_image = torch.full(
       (1, 3, image_height, image_width), 0, dtype=torch.float32
   )
ai_edge_torch/generative/examples/t5/convert_to_tflite.py

@@ -29,24 +29,24 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):

   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.long
+      prompt_e_token, dtype=torch.int
   )
-  prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.long
+      prompt_d_token, dtype=torch.int
   )
-  prefill_d_input_pos = torch.arange(0, seq_len)
+  prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)

   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.long)
-  decode_d_token = torch.tensor([[1]], dtype=torch.long)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.long)
+  decode_token = torch.tensor([[1]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)

   # Pad mask for self attention only on "real" tokens.
   # Pad with `-inf` for any tokens indices that aren't desired.
@@ -81,24 +81,24 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):

   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.long
+      prompt_e_token, dtype=torch.int
   )
-  prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.long
+      prompt_d_token, dtype=torch.int
   )
-  prefill_d_input_pos = torch.arange(0, seq_len)
+  prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)

   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.long)
-  decode_d_token = torch.tensor([[1]], dtype=torch.long)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.long)
+  decode_token = torch.tensor([[1]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)

   # Pad mask for self attention only on "real" tokens.
   # Pad with `-inf` for any tokens indices that aren't desired.
ai_edge_torch/generative/examples/t5/t5.py

@@ -601,12 +601,12 @@ def define_and_run_t5(checkpoint_path: str) -> None:
   model = build_t5_model(checkpoint_path)

   idx = get_sample_encoder_input_ids()
-  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
-  input_pos = torch.arange(0, 512)
+  input_pos = torch.arange(0, 512, dtype=torch.int)

-  decode_d_token = torch.tensor([[0]], dtype=torch.long)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.long)
+  decode_d_token = torch.tensor([[0]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
   pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
   pad_mask[77:] = float("-inf")
   lm_logits = model.forward(
@@ -633,12 +633,12 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
   )
   idx = get_sample_encoder_input_ids()

-  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
-  input_pos = torch.arange(0, 512)
+  input_pos = torch.arange(0, 512, dtype=torch.int)

-  decode_d_token = torch.tensor([[0]], dtype=torch.long)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.long)
+  decode_d_token = torch.tensor([[0]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
   pad_mask = torch.zeros(
       [t5_encoder_model.config.kv_cache_max], dtype=torch.float32
   )
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -124,13 +124,13 @@ def get_model_config() -> cfg.ModelConfig:


 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
-  input_pos = torch.arange(0, 100)
+  tokens = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+  input_pos = torch.arange(0, 100, dtype=torch.int)
   return tokens, input_pos


 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  tokens = torch.tensor([[1]], dtype=torch.long)
+  tokens = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10])
   return tokens, input_pos

ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -47,10 +47,10 @@ def convert_tiny_llama_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.long)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -189,9 +189,9 @@ def define_and_run(checkpoint_path: str) -> None:
   kv_cache_max_len = 1024
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   output = model.forward(tokens, input_pos, kv)
   assert torch.allclose(
ai_edge_torch/generative/fx_passes/__init__.py

@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch._convert.fx_passes import CanonicalizePass
-from ai_edge_torch._convert.fx_passes import run_passes
-from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
+from ai_edge_torch import fx_pass_base
+from ai_edge_torch.fx_pass_base import CanonicalizePass
+from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
 import torch


 def run_generative_passes(
     exported_program: torch.export.ExportedProgram,
 ) -> torch.export.ExportedProgram:
-  return run_passes(
+  return fx_pass_base.run_passes(
       exported_program,
       [
           RemoveSDPACompositeZeroMaskPass(),
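Note: the pass plumbing that lived under ai_edge_torch._convert.fx_passes now sits in the new top-level ai_edge_torch/fx_pass_base.py (+101 in the file list above). A hedged sketch of the consolidated entry point, assuming CanonicalizePass takes no constructor arguments as the truncated pass list suggests:

```python
import torch
from ai_edge_torch import fx_pass_base
from ai_edge_torch.fx_pass_base import CanonicalizePass

def run_my_passes(
    exported_program: torch.export.ExportedProgram,
) -> torch.export.ExportedProgram:
  # Mirrors run_generative_passes above; the exact pass list is an assumption.
  return fx_pass_base.run_passes(exported_program, [CanonicalizePass()])
```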
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py

@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult
 import torch


-class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
+class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):

   def is_zero_tensor_node(self, node: torch.fx.Node):
     return node.target == torch.ops.aten.zeros.default
@@ -48,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):

     exported_program.graph_module.graph.lint()
     exported_program.graph_module.recompile()
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/generative/layers/attention.py

@@ -160,6 +160,10 @@ class CausalSelfAttention(nn.Module):
     self.output_projection = nn.Linear(
         output_shape, dim, bias=config.output_proj_use_bias
     )
+    self.query_norm = builder.build_norm(
+        config.head_dim, config.query_norm_config
+    )
+    self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
     self.config = config
     self.enable_hlfb = enable_hlfb
     self.sdpa_func = (
@@ -224,6 +228,9 @@ class CausalSelfAttention(nn.Module):
           dim=-1,
       )

+    q = self.query_norm(q)
+    k = self.key_norm(k)
+
     q = q.reshape(B, T, -1, self.config.head_dim)
     k = k.reshape(B, T, -1, self.config.head_dim)
     v = v.reshape(B, T, -1, self.config.head_dim)
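Note: if the default norm type in query_norm_config/key_norm_config is NONE, build_norm should return an identity, leaving existing models unchanged; the hooks appear to support architectures that normalize queries and keys, such as the OpenELM example added in this release. An illustrative RMS-style sketch of what a per-head q/k norm computes (an assumption for illustration, not the library's build_norm):

```python
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
  # Normalize over the trailing head_dim axis, as a per-head q/k norm would.
  return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

B, T, n_heads, head_dim = 1, 4, 2, 8
q = torch.randn(B, T, n_heads, head_dim)
k = torch.randn(B, T, n_heads, head_dim)
q, k = rms_norm(q), rms_norm(k)  # keeps attention logits well-scaled
```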
ai_edge_torch/generative/layers/builder.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 # Builder class for individual components.
+from typing import Callable
+
 import ai_edge_torch.generative.layers.feed_forward as feed_forward
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.normalization as normalization
@@ -21,20 +23,34 @@ from torch import nn
 import torch.nn.functional as F


-class GeGLU(nn.Module):
-  """GeGLU is an activation function which is a variant of GELU.
+def build_glu(
+    act: Callable[[torch.Tensor], torch.Tensor], gate_is_front: bool = False
+) -> Callable[[torch.Tensor], torch.Tensor]:
+  """Builds an activation function with GLU (Gated Linear Unit).
+
+  If gate_is_front is True,
+    f(x) = act(x) * y
+  otherwise,
+    f(x) = x * act(y),
+  where x is the first half of the input and y is the second half of the input.

-  GeGLU(x) = (xW+b) * GELU(xV+c)
-  See: https://arxiv.org/abs/2002.05202v1
+  Args:
+    act (Callable[[torch.Tensor], torch.Tensor]): activation function to apply
+      to the gate.
+    gate_is_front: whether the gate is in front half of the input. Other part is
+      the output in GLU.
+
+  Returns:
+    A callable activation function with GLU.
   """

-  def __init__(self, d_in: int, d_out: int):
-    super().__init__()
-    self.proj = nn.Linear(d_in, d_out * 2)
+  def _glu(x):
+    x, y = x.chunk(2, dim=-1)
+    if gate_is_front:
+      return act(x) * y
+    return x * act(y)

-  def forward(self, x: torch.Tensor):
-    x, gate = self.proj(x).chunk(2, dim=-1)
-    return x * F.gelu(gate)
+  return _glu


 def build_norm(dim: int, config: cfg.NormalizationConfig):
@@ -99,6 +115,10 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
       hidden_dim=config.intermediate_size,
       activation=activation,
       use_bias=config.use_bias,
+      use_glu=(
+          config.activation.type == cfg.ActivationType.GE_GLU
+          or config.activation.type == cfg.ActivationType.SILU_GLU
+      ),
       pre_ff_norm=pre_ff_norm,
       post_ff_norm=post_ff_norm,
   )
@@ -129,8 +149,10 @@ def get_activation(config: cfg.ActivationConfig):
     # See: https://github.com/hendrycks/GELUs
     return lambda x: x * F.sigmoid(1.702 * x)
   elif config.type == cfg.ActivationType.GE_GLU:
-    return GeGLU(config.dim_in, config.dim_out)
+    return build_glu(F.gelu, config.gate_is_front)
   elif config.type == cfg.ActivationType.RELU:
     return F.relu
+  elif config.type == cfg.ActivationType.SILU_GLU:
+    return build_glu(F.silu, config.gate_is_front)
   else:
     raise ValueError("Unsupported activation type.")
         | 
| @@ -30,18 +30,27 @@ class SequentialFeedForward(nn.Module): | |
| 30 30 | 
             
                  hidden_dim: int,
         | 
| 31 31 | 
             
                  activation: Callable[[torch.Tensor], torch.Tensor],
         | 
| 32 32 | 
             
                  use_bias=False,
         | 
| 33 | 
            +
                  use_glu=False,
         | 
| 33 34 | 
             
                  pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
         | 
| 34 35 | 
             
                  post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
         | 
| 35 36 | 
             
              ):
         | 
| 36 37 | 
             
                """Init function for feedforward layer.
         | 
| 37 38 |  | 
| 38 | 
            -
                Args: | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 39 | 
            +
                Args:
         | 
| 40 | 
            +
                  dim (int): embedding size.
         | 
| 41 | 
            +
                  hidden_dim (int): hidden dim size of the feedforward layer.
         | 
| 42 | 
            +
                  activation (Callable): activation function used in this block.
         | 
| 43 | 
            +
                  use_bias (Boolean): whether to use bias. Default is false.
         | 
| 44 | 
            +
                  use_glu (Boolean): whether to use glu in activation. Default is false.
         | 
| 45 | 
            +
                  pre_ff_norm (Callable): pre feedforward norm. Default is None.
         | 
| 46 | 
            +
                  post_ff_norm (Callable): post feedforward norm. Default is None.
         | 
| 41 47 | 
             
                """
         | 
| 42 48 | 
             
                super().__init__()
         | 
| 43 49 | 
             
                self.act = activation
         | 
| 44 | 
            -
                 | 
| 50 | 
            +
                if use_glu:
         | 
| 51 | 
            +
                  self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
         | 
| 52 | 
            +
                else:
         | 
| 53 | 
            +
                  self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
         | 
| 45 54 | 
             
                self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
         | 
| 46 55 | 
             
                self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
         | 
| 47 56 | 
             
                self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
         | 
| @@ -72,18 +81,27 @@ class GatedFeedForward(nn.Module): | |
| 72 81 | 
             
                  hidden_dim: int,
         | 
| 73 82 | 
             
                  activation: Callable[[torch.Tensor], torch.Tensor],
         | 
| 74 83 | 
             
                  use_bias=False,
         | 
| 84 | 
            +
                  use_glu=False,
         | 
| 75 85 | 
             
                  pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
         | 
| 76 86 | 
             
                  post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
         | 
| 77 87 | 
             
              ):
         | 
| 78 88 | 
             
                """Init function for feedforward layer.
         | 
| 79 89 |  | 
| 80 | 
            -
                Args: | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 90 | 
            +
                Args:
         | 
| 91 | 
            +
                  dim (int): embedding size.
         | 
| 92 | 
            +
                  hidden_dim (int): hidden dim size of the feedforward layer.
         | 
| 93 | 
            +
                  activation (Callable): activation function used in this block.
         | 
| 94 | 
            +
                  use_bias (Boolean): whether to use bias. Default is false.
         | 
| 95 | 
            +
                  use_glu (Boolean): whether to use glu in activation. Default is false.
         | 
| 96 | 
            +
                  pre_ff_norm (Callable): pre feedforward norm. Default is None.
         | 
| 97 | 
            +
                  post_ff_norm (Callable): post feedforward norm. Default is None.
         | 
| 83 98 | 
             
                """
         | 
| 84 99 | 
             
                super().__init__()
         | 
| 85 100 | 
             
                self.act = activation
         | 
| 86 | 
            -
                 | 
| 101 | 
            +
                if use_glu:
         | 
| 102 | 
            +
                  self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
         | 
| 103 | 
            +
                else:
         | 
| 104 | 
            +
                  self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
         | 
| 87 105 | 
             
                self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
         | 
| 88 106 | 
             
                self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
         | 
| 89 107 | 
             
                self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
         | 
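Note: a GLU consumes two halves of its input and emits one, so with use_glu=True the w1 projection must produce 2 * hidden_dim features for the activation to hand hidden_dim features on to w2. A self-contained shape check (illustrative sizes):

```python
import torch
from torch import nn
import torch.nn.functional as F

dim, hidden_dim = 16, 32
w1 = nn.Linear(dim, hidden_dim * 2)  # use_glu=True: emits [x | y]
w2 = nn.Linear(hidden_dim, dim)

h = w1(torch.randn(1, dim))          # (1, 2 * hidden_dim)
x, y = h.chunk(2, dim=-1)            # the GLU halves it back to hidden_dim
out = w2(x * F.silu(y))              # matches w2's input width
assert out.shape == (1, dim)
```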
ai_edge_torch/generative/layers/kv_cache.py

@@ -172,8 +172,8 @@ def _update_kv_base_impl(
     v_slice: torch.Tensor,
 ) -> KVCacheEntry:
   """Update the cache buffer without High Level Function Boundary annotation."""
-  k = cache.k_cache.index_copy(1, input_pos, k_slice)
-  v = cache.v_cache.index_copy(1, input_pos, v_slice)
+  k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
+  v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
   updated_cache = KVCacheEntry(k, v)
   return updated_cache

@@ -189,7 +189,7 @@ def _update_kv_hlfb_impl(
   k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
       cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
   )
-  k = k_cache.index_copy(1, input_pos, k_slice)
-  v = v_cache.index_copy(1, input_pos, v_slice)
+  k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
+  v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
   k, v = builder.mark_outputs(k, v)
   return KVCacheEntry(k, v)
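Note: the new casts exist because index_copy requires int64 indices, while callers now trace with int32 positions (see the dtype changes above). A minimal reproduction of the constraint (illustrative shapes):

```python
import torch

k_cache = torch.zeros(1, 8, 2, 4)  # (batch, kv_len, n_heads, head_dim)
k_slice = torch.ones(1, 2, 2, 4)
input_pos = torch.arange(0, 2, dtype=torch.int)  # int32, as the model now traces

# k_cache.index_copy(1, input_pos, k_slice) would raise, since index_copy
# expects a LongTensor index; hence the cast at the point of use.
k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
assert k[:, :2].eq(1).all()
```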