ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240915__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (50)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
  11. ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
  16. ai_edge_torch/generative/examples/phi/phi2.py +2 -2
  17. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  18. ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
  19. ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
  20. ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
  21. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  22. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  23. ai_edge_torch/generative/examples/t5/t5.py +8 -8
  24. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  25. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
  26. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
  27. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  28. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  29. ai_edge_torch/generative/layers/attention.py +7 -0
  30. ai_edge_torch/generative/layers/builder.py +33 -11
  31. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  32. ai_edge_torch/generative/layers/kv_cache.py +4 -4
  33. ai_edge_torch/generative/layers/model_config.py +24 -15
  34. ai_edge_torch/generative/quantize/example.py +2 -2
  35. ai_edge_torch/generative/test/test_model_conversion.py +28 -51
  36. ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
  37. ai_edge_torch/generative/test/test_quantize.py +5 -5
  38. ai_edge_torch/generative/utilities/loader.py +13 -0
  39. ai_edge_torch/odml_torch/export.py +40 -0
  40. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  41. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  42. ai_edge_torch/version.py +1 -1
  43. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/METADATA +1 -1
  44. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/RECORD +48 -46
  45. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  46. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  47. ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/top_level.txt +0 -0
@@ -47,10 +47,10 @@ def convert_phi2_to_tflite(
47
47
  checkpoint_path, kv_cache_max_len=kv_cache_max_len
48
48
  )
49
49
  # Tensors used to trace the model graph during conversion.
50
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
51
- prefill_input_pos = torch.arange(0, prefill_seq_len)
52
- decode_token = torch.tensor([[0]], dtype=torch.long)
53
- decode_input_pos = torch.tensor([0], dtype=torch.int64)
50
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
51
+ prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
52
+ decode_token = torch.tensor([[0]], dtype=torch.int)
53
+ decode_input_pos = torch.tensor([0], dtype=torch.int)
54
54
  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
55
55
 
56
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
@@ -192,9 +192,9 @@ def define_and_run(checkpoint_path: str) -> None:
192
192
  kv_cache_max_len = 1024
193
193
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
194
194
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
195
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
195
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
196
196
  tokens[0, :4] = idx
197
- input_pos = torch.arange(0, kv_cache_max_len)
197
+ input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
198
198
  kv = kv_utils.KVCache.from_model_config(model.config)
199
199
  output = model.forward(tokens, input_pos, kv)
200
200
  print("comparing with goldens..")
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -13,25 +13,25 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- """Example of converting SmalLM model to multi-signature tflite model."""
16
+ """Example of converting SmolLM model to multi-signature tflite model."""
17
17
 
18
18
  import os
19
19
  import pathlib
20
20
 
21
21
  import ai_edge_torch
22
- from ai_edge_torch.generative.examples.smallm import smallm
22
+ from ai_edge_torch.generative.examples.smollm import smollm
23
23
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
24
24
  from ai_edge_torch.generative.quantize import quant_recipes
25
25
  import torch
26
26
 
27
27
 
28
- def convert_smallm_to_tflite(
28
+ def convert_smollm_to_tflite(
29
29
  checkpoint_path: str,
30
30
  prefill_seq_len: int = 512,
31
31
  kv_cache_max_len: int = 1024,
32
32
  quantize: bool = True,
33
33
  ):
34
- """Converts SmalLM model to multi-signature tflite model.
34
+ """Converts SmolLM model to multi-signature tflite model.
35
35
 
36
36
  Args:
37
37
  checkpoint_path (str): The filepath to the model checkpoint, or directory
@@ -43,14 +43,14 @@ def convert_smallm_to_tflite(
43
43
  quantize (bool, optional): Whether the model should be quantized. Defaults
44
44
  to True.
45
45
  """
46
- pytorch_model = smallm.build_model(
46
+ pytorch_model = smollm.build_model(
47
47
  checkpoint_path, kv_cache_max_len=kv_cache_max_len
48
48
  )
49
49
  # Tensors used to trace the model graph during conversion.
50
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
51
- prefill_input_pos = torch.arange(0, prefill_seq_len)
52
- decode_token = torch.tensor([[0]], dtype=torch.long)
53
- decode_input_pos = torch.tensor([0], dtype=torch.int64)
50
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
51
+ prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
52
+ decode_token = torch.tensor([[0]], dtype=torch.int)
53
+ decode_input_pos = torch.tensor([0], dtype=torch.int)
54
54
  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
55
55
 
56
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
@@ -77,10 +77,10 @@ def convert_smallm_to_tflite(
77
77
  )
78
78
  quant_suffix = 'q8' if quantize else 'f32'
79
79
  edge_model.export(
80
- f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
80
+ f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
81
81
  )
82
82
 
83
83
 
84
84
  if __name__ == '__main__':
85
- path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
86
- convert_smallm_to_tflite(path)
85
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
86
+ convert_smollm_to_tflite(path)
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- """Example of building a SmalLM model."""
16
+ """Example of building a SmolLM model."""
17
17
 
18
18
  import copy
19
19
  import os
@@ -28,32 +28,32 @@ import torch
28
28
  from torch import nn
29
29
 
30
30
  TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
31
- # SmalLM re-uses the embedding as the head projection layer.
31
+ # SmolLM re-uses the embedding as the head projection layer.
32
32
  TENSOR_NAMES.lm_head = None
33
33
 
34
34
 
35
- class SmalLM(tiny_llama.TinyLlama):
36
- """A SmalLM model built from the Edge Generative API layers.
35
+ class SmolLM(tiny_llama.TinyLlama):
36
+ """A SmolLM model built from the Edge Generative API layers.
37
37
 
38
- SmalLM shares the same architecture as TinyLlama, but with different model
38
+ SmolLM shares the same architecture as TinyLlama, but with different model
39
39
  sizes.
40
40
  """
41
41
 
42
42
  def __init__(self, config: cfg.ModelConfig):
43
43
  super().__init__(config)
44
- # SmalLM re-uses the embedding as the head projection layer.
44
+ # SmolLM re-uses the embedding as the head projection layer.
45
45
  self.lm_head.weight.data = self.tok_embedding.weight.data
46
46
 
47
47
 
48
48
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
49
- """Returns the model config for a SmalLM 135M model.
49
+ """Returns the model config for a SmolLM 135M model.
50
50
 
51
51
  Args:
52
52
  kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
53
53
  is 1024.
54
54
 
55
55
  Returns:
56
- The model config for a SmalLM model.
56
+ The model config for a SmolLM model.
57
57
  """
58
58
  attn_config = cfg.AttentionConfig(
59
59
  num_heads=9,
@@ -86,9 +86,18 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
86
86
  return config
87
87
 
88
88
 
89
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
90
+ config = get_model_config(**kwargs)
91
+ config.vocab_size = 128
92
+ config.num_layers = 2
93
+ # SmolLM has only one block config.
94
+ config.block_config(0).ff_config.intermediate_size = 64
95
+ return config
96
+
97
+
89
98
  def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
90
99
  config = get_model_config(**kwargs)
91
- model = SmalLM(config)
100
+ model = SmolLM(config)
92
101
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
93
102
  # Since embedding and lm-head use the same weight, we need to set strict
94
103
  # to False.
@@ -98,25 +107,25 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
98
107
 
99
108
 
100
109
  def define_and_run(checkpoint_path: str) -> None:
101
- """Instantiates and runs a SmalLM model."""
110
+ """Instantiates and runs a SmolLM model."""
102
111
 
103
112
  current_dir = pathlib.Path(__file__).parent.resolve()
104
- smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
113
+ smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
105
114
  kv_cache_max_len = 1024
106
115
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
107
116
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
108
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
117
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
109
118
  tokens[0, :4] = idx
110
- input_pos = torch.arange(0, kv_cache_max_len)
119
+ input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
111
120
  kv = kv_utils.KVCache.from_model_config(model.config)
112
121
  output = model.forward(tokens, input_pos, kv)
113
122
  assert torch.allclose(
114
- smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
123
+ smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
115
124
  )
116
125
 
117
126
 
118
127
  if __name__ == "__main__":
119
128
  input_checkpoint_path = os.path.join(
120
- pathlib.Path.home(), "Downloads/llm_data/smallm"
129
+ pathlib.Path.home(), "Downloads/llm_data/smollm"
121
130
  )
122
131
  define_and_run(input_checkpoint_path)
@@ -76,7 +76,7 @@ class CLIP(nn.Module):
76
76
 
77
77
  @torch.inference_mode
78
78
  def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
79
- tokens = tokens.type(torch.long)
79
+ tokens = tokens.type(torch.int)
80
80
 
81
81
  state = self.tok_embedding(tokens) + self.tok_embedding_position
82
82
  for layer in self.transformer_blocks:
@@ -94,7 +94,7 @@ def convert_stable_diffusion_to_tflite(
94
94
  n_tokens = 77
95
95
  timestamp = 0
96
96
  len_prompt = 1
97
- prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
97
+ prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
98
98
  input_image = torch.full(
99
99
  (1, 3, image_height, image_width), 0, dtype=torch.float32
100
100
  )
@@ -29,24 +29,24 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):
29
29
 
30
30
  # encoder
31
31
  seq_len = 512
32
- prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
32
+ prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
33
33
  prompt_e_token = [1, 2, 3, 4, 5, 6]
34
34
  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
35
- prompt_e_token, dtype=torch.long
35
+ prompt_e_token, dtype=torch.int
36
36
  )
37
- prefill_e_input_pos = torch.arange(0, seq_len)
38
- prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
37
+ prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
38
+ prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
39
39
  prompt_d_token = [1, 2, 3, 4, 5, 6]
40
40
  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
41
- prompt_d_token, dtype=torch.long
41
+ prompt_d_token, dtype=torch.int
42
42
  )
43
- prefill_d_input_pos = torch.arange(0, seq_len)
43
+ prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)
44
44
 
45
45
  # decoder
46
- decode_token = torch.tensor([[1]], dtype=torch.long)
47
- decode_input_pos = torch.tensor([0], dtype=torch.int64)
48
- decode_d_token = torch.tensor([[1]], dtype=torch.long)
49
- decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
46
+ decode_token = torch.tensor([[1]], dtype=torch.int)
47
+ decode_input_pos = torch.tensor([0], dtype=torch.int)
48
+ decode_d_token = torch.tensor([[1]], dtype=torch.int)
49
+ decode_d_input_pos = torch.tensor([0], dtype=torch.int)
50
50
 
51
51
  # Pad mask for self attention only on "real" tokens.
52
52
  # Pad with `-inf` for any tokens indices that aren't desired.
@@ -81,24 +81,24 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
81
81
 
82
82
  # encoder
83
83
  seq_len = 512
84
- prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
84
+ prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
85
85
  prompt_e_token = [1, 2, 3, 4, 5, 6]
86
86
  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
87
- prompt_e_token, dtype=torch.long
87
+ prompt_e_token, dtype=torch.int
88
88
  )
89
- prefill_e_input_pos = torch.arange(0, seq_len)
90
- prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
89
+ prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
90
+ prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
91
91
  prompt_d_token = [1, 2, 3, 4, 5, 6]
92
92
  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
93
- prompt_d_token, dtype=torch.long
93
+ prompt_d_token, dtype=torch.int
94
94
  )
95
- prefill_d_input_pos = torch.arange(0, seq_len)
95
+ prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)
96
96
 
97
97
  # decoder
98
- decode_token = torch.tensor([[1]], dtype=torch.long)
99
- decode_input_pos = torch.tensor([0], dtype=torch.int64)
100
- decode_d_token = torch.tensor([[1]], dtype=torch.long)
101
- decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
98
+ decode_token = torch.tensor([[1]], dtype=torch.int)
99
+ decode_input_pos = torch.tensor([0], dtype=torch.int)
100
+ decode_d_token = torch.tensor([[1]], dtype=torch.int)
101
+ decode_d_input_pos = torch.tensor([0], dtype=torch.int)
102
102
 
103
103
  # Pad mask for self attention only on "real" tokens.
104
104
  # Pad with `-inf` for any tokens indices that aren't desired.
@@ -601,12 +601,12 @@ def define_and_run_t5(checkpoint_path: str) -> None:
601
601
  model = build_t5_model(checkpoint_path)
602
602
 
603
603
  idx = get_sample_encoder_input_ids()
604
- tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
604
+ tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
605
605
  tokens[0, :77] = idx
606
- input_pos = torch.arange(0, 512)
606
+ input_pos = torch.arange(0, 512, dtype=torch.int)
607
607
 
608
- decode_d_token = torch.tensor([[0]], dtype=torch.int64)
609
- decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
608
+ decode_d_token = torch.tensor([[0]], dtype=torch.int)
609
+ decode_d_input_pos = torch.tensor([0], dtype=torch.int)
610
610
  pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
611
611
  pad_mask[77:] = float("-inf")
612
612
  lm_logits = model.forward(
@@ -633,12 +633,12 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
633
633
  )
634
634
  idx = get_sample_encoder_input_ids()
635
635
 
636
- tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
636
+ tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
637
637
  tokens[0, :77] = idx
638
- input_pos = torch.arange(0, 512)
638
+ input_pos = torch.arange(0, 512, dtype=torch.int)
639
639
 
640
- decode_d_token = torch.tensor([[0]], dtype=torch.int64)
641
- decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
640
+ decode_d_token = torch.tensor([[0]], dtype=torch.int)
641
+ decode_d_input_pos = torch.tensor([0], dtype=torch.int)
642
642
  pad_mask = torch.zeros(
643
643
  [t5_encoder_model.config.kv_cache_max], dtype=torch.float32
644
644
  )
@@ -124,13 +124,13 @@ def get_model_config() -> cfg.ModelConfig:
124
124
 
125
125
 
126
126
  def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
127
- tokens = torch.unsqueeze(torch.arange(0, 100), 0)
128
- input_pos = torch.arange(0, 100)
127
+ tokens = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
128
+ input_pos = torch.arange(0, 100, dtype=torch.int)
129
129
  return tokens, input_pos
130
130
 
131
131
 
132
132
  def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
133
- tokens = torch.tensor([[1]], dtype=torch.long)
133
+ tokens = torch.tensor([[1]], dtype=torch.int)
134
134
  input_pos = torch.tensor([10])
135
135
  return tokens, input_pos
136
136
 
@@ -47,10 +47,10 @@ def convert_tiny_llama_to_tflite(
47
47
  checkpoint_path, kv_cache_max_len=kv_cache_max_len
48
48
  )
49
49
  # Tensors used to trace the model graph during conversion.
50
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
51
- prefill_input_pos = torch.arange(0, prefill_seq_len)
52
- decode_token = torch.tensor([[0]], dtype=torch.long)
53
- decode_input_pos = torch.tensor([0], dtype=torch.int64)
50
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
51
+ prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
52
+ decode_token = torch.tensor([[0]], dtype=torch.int)
53
+ decode_input_pos = torch.tensor([0], dtype=torch.int)
54
54
  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
55
55
 
56
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
@@ -189,9 +189,9 @@ def define_and_run(checkpoint_path: str) -> None:
189
189
  kv_cache_max_len = 1024
190
190
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
191
191
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
192
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
192
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
193
193
  tokens[0, :4] = idx
194
- input_pos = torch.arange(0, kv_cache_max_len)
194
+ input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
195
195
  kv = kv_utils.KVCache.from_model_config(model.config)
196
196
  output = model.forward(tokens, input_pos, kv)
197
197
  assert torch.allclose(
@@ -12,16 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from ai_edge_torch._convert.fx_passes import CanonicalizePass
16
- from ai_edge_torch._convert.fx_passes import run_passes
17
- from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass # NOQA
15
+ from ai_edge_torch import fx_pass_base
16
+ from ai_edge_torch.fx_pass_base import CanonicalizePass
17
+ from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
18
18
  import torch
19
19
 
20
20
 
21
21
  def run_generative_passes(
22
22
  exported_program: torch.export.ExportedProgram,
23
23
  ) -> torch.export.ExportedProgram:
24
- return run_passes(
24
+ return fx_pass_base.run_passes(
25
25
  exported_program,
26
26
  [
27
27
  RemoveSDPACompositeZeroMaskPass(),
@@ -12,13 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from ai_edge_torch import fx_pass_base
15
16
  from ai_edge_torch import lowertools
16
- from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
17
- from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult
18
17
  import torch
19
18
 
20
19
 
21
- class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
20
+ class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
22
21
 
23
22
  def is_zero_tensor_node(self, node: torch.fx.Node):
24
23
  return node.target == torch.ops.aten.zeros.default
@@ -48,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
48
47
 
49
48
  exported_program.graph_module.graph.lint()
50
49
  exported_program.graph_module.recompile()
51
- return ExportedProgramPassResult(exported_program, True)
50
+ return fx_pass_base.ExportedProgramPassResult(exported_program, True)
@@ -160,6 +160,10 @@ class CausalSelfAttention(nn.Module):
160
160
  self.output_projection = nn.Linear(
161
161
  output_shape, dim, bias=config.output_proj_use_bias
162
162
  )
163
+ self.query_norm = builder.build_norm(
164
+ config.head_dim, config.query_norm_config
165
+ )
166
+ self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
163
167
  self.config = config
164
168
  self.enable_hlfb = enable_hlfb
165
169
  self.sdpa_func = (
@@ -224,6 +228,9 @@ class CausalSelfAttention(nn.Module):
224
228
  dim=-1,
225
229
  )
226
230
 
231
+ q = self.query_norm(q)
232
+ k = self.key_norm(k)
233
+
227
234
  q = q.reshape(B, T, -1, self.config.head_dim)
228
235
  k = k.reshape(B, T, -1, self.config.head_dim)
229
236
  v = v.reshape(B, T, -1, self.config.head_dim)
@@ -13,6 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  # Builder class for individual components.
16
+ from typing import Callable
17
+
16
18
  import ai_edge_torch.generative.layers.feed_forward as feed_forward
17
19
  import ai_edge_torch.generative.layers.model_config as cfg
18
20
  import ai_edge_torch.generative.layers.normalization as normalization
@@ -21,20 +23,34 @@ from torch import nn
21
23
  import torch.nn.functional as F
22
24
 
23
25
 
24
- class GeGLU(nn.Module):
25
- """GeGLU is an activation function which is a variant of GELU.
26
+ def build_glu(
27
+ act: Callable[[torch.Tensor], torch.Tensor], gate_is_front: bool = False
28
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
29
+ """Builds an activation function with GLU (Gated Linear Unit).
30
+
31
+ If gate_is_front is True,
32
+ f(x) = act(x) * y
33
+ otherwise,
34
+ f(x) = x * act(y),
35
+ where x is the first half of the input and y is the second half of the input.
26
36
 
27
- GeGLU(x) = (xW+b) * GELU(xV+c)
28
- See: https://arxiv.org/abs/2002.05202v1
37
+ Args:
38
+ act (Callable[[torch.Tensor], torch.Tensor]): activation function to apply
39
+ to the gate.
40
+ gate_is_front: whether the gate is in front half of the input. Other part is
41
+ the output in GLU.
42
+
43
+ Returns:
44
+ A callable activation function with GLU.
29
45
  """
30
46
 
31
- def __init__(self, d_in: int, d_out: int):
32
- super().__init__()
33
- self.proj = nn.Linear(d_in, d_out * 2)
47
+ def _glu(x):
48
+ x, y = x.chunk(2, dim=-1)
49
+ if gate_is_front:
50
+ return act(x) * y
51
+ return x * act(y)
34
52
 
35
- def forward(self, x: torch.Tensor):
36
- x, gate = self.proj(x).chunk(2, dim=-1)
37
- return x * F.gelu(gate)
53
+ return _glu
38
54
 
39
55
 
40
56
  def build_norm(dim: int, config: cfg.NormalizationConfig):
@@ -99,6 +115,10 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
99
115
  hidden_dim=config.intermediate_size,
100
116
  activation=activation,
101
117
  use_bias=config.use_bias,
118
+ use_glu=(
119
+ config.activation.type == cfg.ActivationType.GE_GLU
120
+ or config.activation.type == cfg.ActivationType.SILU_GLU
121
+ ),
102
122
  pre_ff_norm=pre_ff_norm,
103
123
  post_ff_norm=post_ff_norm,
104
124
  )
@@ -129,8 +149,10 @@ def get_activation(config: cfg.ActivationConfig):
129
149
  # See: https://github.com/hendrycks/GELUs
130
150
  return lambda x: x * F.sigmoid(1.702 * x)
131
151
  elif config.type == cfg.ActivationType.GE_GLU:
132
- return GeGLU(config.dim_in, config.dim_out)
152
+ return build_glu(F.gelu, config.gate_is_front)
133
153
  elif config.type == cfg.ActivationType.RELU:
134
154
  return F.relu
155
+ elif config.type == cfg.ActivationType.SILU_GLU:
156
+ return build_glu(F.silu, config.gate_is_front)
135
157
  else:
136
158
  raise ValueError("Unsupported activation type.")
@@ -30,18 +30,27 @@ class SequentialFeedForward(nn.Module):
30
30
  hidden_dim: int,
31
31
  activation: Callable[[torch.Tensor], torch.Tensor],
32
32
  use_bias=False,
33
+ use_glu=False,
33
34
  pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
34
35
  post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
35
36
  ):
36
37
  """Init function for feedforward layer.
37
38
 
38
- Args: dim(int): embedding size. hidden_dim(int): hidden dim size of the
39
- feedforward layer. activation(Callable): activation function used in this
40
- block. use_bias(Boolean): whether to use bias. Default is false.
39
+ Args:
40
+ dim (int): embedding size.
41
+ hidden_dim (int): hidden dim size of the feedforward layer.
42
+ activation (Callable): activation function used in this block.
43
+ use_bias (Boolean): whether to use bias. Default is false.
44
+ use_glu (Boolean): whether to use glu in activation. Default is false.
45
+ pre_ff_norm (Callable): pre feedforward norm. Default is None.
46
+ post_ff_norm (Callable): post feedforward norm. Default is None.
41
47
  """
42
48
  super().__init__()
43
49
  self.act = activation
44
- self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
50
+ if use_glu:
51
+ self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
52
+ else:
53
+ self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
45
54
  self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
46
55
  self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
47
56
  self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
@@ -72,18 +81,27 @@ class GatedFeedForward(nn.Module):
72
81
  hidden_dim: int,
73
82
  activation: Callable[[torch.Tensor], torch.Tensor],
74
83
  use_bias=False,
84
+ use_glu=False,
75
85
  pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
76
86
  post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
77
87
  ):
78
88
  """Init function for feedforward layer.
79
89
 
80
- Args: dim(int): embedding size. hidden_dim(int): hidden dim size of the
81
- feedforward layer. activation(Callable): activation function used in this
82
- block. use_bias(Boolean): whether to use bias. Default is false.
90
+ Args:
91
+ dim (int): embedding size.
92
+ hidden_dim (int): hidden dim size of the feedforward layer.
93
+ activation (Callable): activation function used in this block.
94
+ use_bias (Boolean): whether to use bias. Default is false.
95
+ use_glu (Boolean): whether to use glu in activation. Default is false.
96
+ pre_ff_norm (Callable): pre feedforward norm. Default is None.
97
+ post_ff_norm (Callable): post feedforward norm. Default is None.
83
98
  """
84
99
  super().__init__()
85
100
  self.act = activation
86
- self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
101
+ if use_glu:
102
+ self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
103
+ else:
104
+ self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
87
105
  self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
88
106
  self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
89
107
  self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
@@ -172,8 +172,8 @@ def _update_kv_base_impl(
172
172
  v_slice: torch.Tensor,
173
173
  ) -> KVCacheEntry:
174
174
  """Update the cache buffer without High Level Function Boundary annotation."""
175
- k = cache.k_cache.index_copy(1, input_pos, k_slice)
176
- v = cache.v_cache.index_copy(1, input_pos, v_slice)
175
+ k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
176
+ v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
177
177
  updated_cache = KVCacheEntry(k, v)
178
178
  return updated_cache
179
179
 
@@ -189,7 +189,7 @@ def _update_kv_hlfb_impl(
189
189
  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
190
190
  cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
191
191
  )
192
- k = k_cache.index_copy(1, input_pos, k_slice)
193
- v = v_cache.index_copy(1, input_pos, v_slice)
192
+ k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
193
+ v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
194
194
  k, v = builder.mark_outputs(k, v)
195
195
  return KVCacheEntry(k, v)