ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240915__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
  11. ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
  16. ai_edge_torch/generative/examples/phi/phi2.py +2 -2
  17. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  18. ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
  19. ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
  20. ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
  21. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  22. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  23. ai_edge_torch/generative/examples/t5/t5.py +8 -8
  24. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  25. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
  26. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
  27. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  28. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  29. ai_edge_torch/generative/layers/attention.py +7 -0
  30. ai_edge_torch/generative/layers/builder.py +33 -11
  31. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  32. ai_edge_torch/generative/layers/kv_cache.py +4 -4
  33. ai_edge_torch/generative/layers/model_config.py +24 -15
  34. ai_edge_torch/generative/quantize/example.py +2 -2
  35. ai_edge_torch/generative/test/test_model_conversion.py +28 -51
  36. ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
  37. ai_edge_torch/generative/test/test_quantize.py +5 -5
  38. ai_edge_torch/generative/utilities/loader.py +13 -0
  39. ai_edge_torch/odml_torch/export.py +40 -0
  40. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  41. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  42. ai_edge_torch/version.py +1 -1
  43. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/METADATA +1 -1
  44. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/RECORD +48 -46
  45. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  46. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  47. /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/top_level.txt +0 -0
@@ -47,10 +47,10 @@ def convert_phi2_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_cache.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
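Several converters in this release switch their trace tensors from torch.long / torch.int64 to torch.int (int32). A minimal sketch of the resulting trace inputs, using an illustrative prefill length rather than the converters' actual arguments:

    import torch

    prefill_seq_len = 512  # illustrative; the real converters take this as a parameter

    # int32 token ids and positions for the prefill signature.
    prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
    prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)

    # Single int32 token and position for the decode signature.
    decode_token = torch.tensor([[0]], dtype=torch.int)
    decode_input_pos = torch.tensor([0], dtype=torch.int)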
@@ -192,9 +192,9 @@ def define_and_run(checkpoint_path: str) -> None:
   kv_cache_max_len = 1024
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
@@ -13,25 +13,25 @@
 # limitations under the License.
 # ==============================================================================

-"""Example of converting SmalLM model to multi-signature tflite model."""
+"""Example of converting SmolLM model to multi-signature tflite model."""

 import os
 import pathlib

 import ai_edge_torch
-from ai_edge_torch.generative.examples.smallm import smallm
+from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch


-def convert_smallm_to_tflite(
+def convert_smollm_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """Converts SmalLM model to multi-signature tflite model.
+  """Converts SmolLM model to multi-signature tflite model.

   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
@@ -43,14 +43,14 @@ def convert_smallm_to_tflite(
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
   """
-  pytorch_model = smallm.build_model(
+  pytorch_model = smollm.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
@@ -77,10 +77,10 @@ def convert_smallm_to_tflite(
   )
   quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+      f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
-  convert_smallm_to_tflite(path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
+  convert_smollm_to_tflite(path)
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================

-"""Example of building a SmalLM model."""
+"""Example of building a SmolLM model."""

 import copy
 import os
@@ -28,32 +28,32 @@ import torch
 from torch import nn

 TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# SmalLM re-uses the embedding as the head projection layer.
+# SmolLM re-uses the embedding as the head projection layer.
 TENSOR_NAMES.lm_head = None


-class SmalLM(tiny_llama.TinyLlama):
-  """A SmalLM model built from the Edge Generative API layers.
+class SmolLM(tiny_llama.TinyLlama):
+  """A SmolLM model built from the Edge Generative API layers.

-  SmalLM shares the same architecture as TinyLlama, but with different model
+  SmolLM shares the same architecture as TinyLlama, but with different model
   sizes.
   """

   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
-    # SmalLM re-uses the embedding as the head projection layer.
+    # SmolLM re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data


 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a SmalLM 135M model.
+  """Returns the model config for a SmolLM 135M model.

   Args:
     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
       is 1024.

   Returns:
-    The model config for a SmalLM model.
+    The model config for a SmolLM model.
   """
   attn_config = cfg.AttentionConfig(
       num_heads=9,
@@ -86,9 +86,18 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
 def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config(**kwargs)
-  model = SmalLM(config)
+  model = SmolLM(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
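The new get_fake_model_config shrinks the vocabulary, layer count, and feed-forward width so a SmolLM can be built quickly without real weights. A rough sketch of that kind of usage; constructing SmolLM directly from the fake config is an assumption about how tests consume it:

    from ai_edge_torch.generative.examples.smollm import smollm

    # Tiny config: vocab_size=128, num_layers=2, intermediate_size=64.
    config = smollm.get_fake_model_config(kv_cache_max_len=128)
    model = smollm.SmolLM(config)  # randomly initialized; no checkpoint needed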
@@ -98,25 +107,25 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:


 def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a SmalLM model."""
+  """Instantiates and runs a SmolLM model."""

   current_dir = pathlib.Path(__file__).parent.resolve()
-  smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
+  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
   kv_cache_max_len = 1024
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   output = model.forward(tokens, input_pos, kv)
   assert torch.allclose(
-      smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
   )


 if __name__ == "__main__":
   input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/smallm"
+      pathlib.Path.home(), "Downloads/llm_data/smollm"
   )
   define_and_run(input_checkpoint_path)
@@ -76,7 +76,7 @@ class CLIP(nn.Module):

   @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-    tokens = tokens.type(torch.long)
+    tokens = tokens.type(torch.int)

     state = self.tok_embedding(tokens) + self.tok_embedding_position
     for layer in self.transformer_blocks:
@@ -94,7 +94,7 @@ def convert_stable_diffusion_to_tflite(
   n_tokens = 77
   timestamp = 0
   len_prompt = 1
-  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
+  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
   input_image = torch.full(
       (1, 3, image_height, image_width), 0, dtype=torch.float32
   )
@@ -29,24 +29,24 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):

   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.long
+      prompt_e_token, dtype=torch.int
   )
-  prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.long
+      prompt_d_token, dtype=torch.int
   )
-  prefill_d_input_pos = torch.arange(0, seq_len)
+  prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)

   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  decode_d_token = torch.tensor([[1]], dtype=torch.long)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_token = torch.tensor([[1]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)

   # Pad mask for self attention only on "real" tokens.
   # Pad with `-inf` for any tokens indices that aren't desired.
@@ -81,24 +81,24 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):

   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.long
+      prompt_e_token, dtype=torch.int
   )
-  prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.long
+      prompt_d_token, dtype=torch.int
   )
-  prefill_d_input_pos = torch.arange(0, seq_len)
+  prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)

   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  decode_d_token = torch.tensor([[1]], dtype=torch.long)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_token = torch.tensor([[1]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)

   # Pad mask for self attention only on "real" tokens.
   # Pad with `-inf` for any tokens indices that aren't desired.
@@ -601,12 +601,12 @@ def define_and_run_t5(checkpoint_path: str) -> None:
   model = build_t5_model(checkpoint_path)

   idx = get_sample_encoder_input_ids()
-  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
-  input_pos = torch.arange(0, 512)
+  input_pos = torch.arange(0, 512, dtype=torch.int)

-  decode_d_token = torch.tensor([[0]], dtype=torch.int64)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_d_token = torch.tensor([[0]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
   pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
   pad_mask[77:] = float("-inf")
   lm_logits = model.forward(
@@ -633,12 +633,12 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
   )
   idx = get_sample_encoder_input_ids()

-  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
-  input_pos = torch.arange(0, 512)
+  input_pos = torch.arange(0, 512, dtype=torch.int)

-  decode_d_token = torch.tensor([[0]], dtype=torch.int64)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_d_token = torch.tensor([[0]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
   pad_mask = torch.zeros(
       [t5_encoder_model.config.kv_cache_max], dtype=torch.float32
   )
@@ -124,13 +124,13 @@ def get_model_config() -> cfg.ModelConfig:


 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
-  input_pos = torch.arange(0, 100)
+  tokens = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+  input_pos = torch.arange(0, 100, dtype=torch.int)
   return tokens, input_pos


 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  tokens = torch.tensor([[1]], dtype=torch.long)
+  tokens = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10])
   return tokens, input_pos

@@ -47,10 +47,10 @@ def convert_tiny_llama_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
@@ -189,9 +189,9 @@ def define_and_run(checkpoint_path: str) -> None:
   kv_cache_max_len = 1024
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
   kv = kv_utils.KVCache.from_model_config(model.config)
   output = model.forward(tokens, input_pos, kv)
   assert torch.allclose(
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch._convert.fx_passes import CanonicalizePass
-from ai_edge_torch._convert.fx_passes import run_passes
-from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass  # NOQA
+from ai_edge_torch import fx_pass_base
+from ai_edge_torch.fx_pass_base import CanonicalizePass
+from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
 import torch


 def run_generative_passes(
     exported_program: torch.export.ExportedProgram,
 ) -> torch.export.ExportedProgram:
-  return run_passes(
+  return fx_pass_base.run_passes(
       exported_program,
       [
           RemoveSDPACompositeZeroMaskPass(),
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult
 import torch


-class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
+class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):

   def is_zero_tensor_node(self, node: torch.fx.Node):
     return node.target == torch.ops.aten.zeros.default
@@ -48,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):

     exported_program.graph_module.graph.lint()
     exported_program.graph_module.recompile()
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
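The pass infrastructure now lives in the top-level ai_edge_torch.fx_pass_base module instead of ai_edge_torch._convert.fx_passes._pass_base. A minimal sketch of a custom pass against that interface; the call() method name and the (exported_program, modified) result signature are assumptions inferred from the usage visible in this diff:

    from ai_edge_torch import fx_pass_base
    import torch


    class NoOpPass(fx_pass_base.ExportedProgramPassBase):
      """Illustrative pass that inspects the graph and reports no changes."""

      def call(self, exported_program: torch.export.ExportedProgram):
        # A real pass would rewrite exported_program.graph_module here, then
        # lint() and recompile() it as RemoveSDPACompositeZeroMaskPass does.
        return fx_pass_base.ExportedProgramPassResult(exported_program, False)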
@@ -160,6 +160,10 @@ class CausalSelfAttention(nn.Module):
     self.output_projection = nn.Linear(
         output_shape, dim, bias=config.output_proj_use_bias
     )
+    self.query_norm = builder.build_norm(
+        config.head_dim, config.query_norm_config
+    )
+    self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
     self.config = config
     self.enable_hlfb = enable_hlfb
     self.sdpa_func = (
@@ -224,6 +228,9 @@ class CausalSelfAttention(nn.Module):
         dim=-1,
     )

+    q = self.query_norm(q)
+    k = self.key_norm(k)
+
     q = q.reshape(B, T, -1, self.config.head_dim)
     k = k.reshape(B, T, -1, self.config.head_dim)
     v = v.reshape(B, T, -1, self.config.head_dim)
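The new query_norm and key_norm hooks normalize the query and key projections (built against config.head_dim) before the per-head reshape and SDPA. A standalone sketch of the idea using a scale-free RMS normalization, which is an assumed stand-in for whatever builder.build_norm returns from the configured norm:

    import torch

    def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
      # Normalize over the last dimension (one head's head_dim slice).
      return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

    B, T, n_heads, head_dim = 1, 8, 4, 64
    q = torch.randn(B, T, n_heads * head_dim)
    k = torch.randn(B, T, n_heads * head_dim)

    # Each head's query/key vector is normalized independently.
    q = rms_norm(q.reshape(B, T, n_heads, head_dim)).reshape(B, T, -1)
    k = rms_norm(k.reshape(B, T, n_heads, head_dim)).reshape(B, T, -1)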
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 # Builder class for individual components.
+from typing import Callable
+
 import ai_edge_torch.generative.layers.feed_forward as feed_forward
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.normalization as normalization
@@ -21,20 +23,34 @@ from torch import nn
 import torch.nn.functional as F


-class GeGLU(nn.Module):
-  """GeGLU is an activation function which is a variant of GELU.
+def build_glu(
+    act: Callable[[torch.Tensor], torch.Tensor], gate_is_front: bool = False
+) -> Callable[[torch.Tensor], torch.Tensor]:
+  """Builds an activation function with GLU (Gated Linear Unit).
+
+  If gate_is_front is True,
+    f(x) = act(x) * y
+  otherwise,
+    f(x) = x * act(y),
+  where x is the first half of the input and y is the second half of the input.

-  GeGLU(x) = (xW+b) * GELU(xV+c)
-  See: https://arxiv.org/abs/2002.05202v1
+  Args:
+    act (Callable[[torch.Tensor], torch.Tensor]): activation function to apply
+      to the gate.
+    gate_is_front: whether the gate is in front half of the input. Other part is
+      the output in GLU.
+
+  Returns:
+    A callable activation function with GLU.
   """

-  def __init__(self, d_in: int, d_out: int):
-    super().__init__()
-    self.proj = nn.Linear(d_in, d_out * 2)
+  def _glu(x):
+    x, y = x.chunk(2, dim=-1)
+    if gate_is_front:
+      return act(x) * y
+    return x * act(y)

-  def forward(self, x: torch.Tensor):
-    x, gate = self.proj(x).chunk(2, dim=-1)
-    return x * F.gelu(gate)
+  return _glu


 def build_norm(dim: int, config: cfg.NormalizationConfig):
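Unlike the old GeGLU module, build_glu returns a plain callable and owns no weights; the doubled projection now lives in the feed-forward layers (see the use_glu changes below). A quick usage sketch:

    import torch
    import torch.nn.functional as F

    geglu = build_glu(F.gelu)                       # x * gelu(y), gate in the back half
    swiglu = build_glu(F.silu, gate_is_front=True)  # silu(x) * y, gate in the front half

    x = torch.randn(2, 16)   # last dimension must be even; it is split in half
    print(geglu(x).shape)    # torch.Size([2, 8])
    print(swiglu(x).shape)   # torch.Size([2, 8])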
@@ -99,6 +115,10 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
       hidden_dim=config.intermediate_size,
       activation=activation,
       use_bias=config.use_bias,
+      use_glu=(
+          config.activation.type == cfg.ActivationType.GE_GLU
+          or config.activation.type == cfg.ActivationType.SILU_GLU
+      ),
       pre_ff_norm=pre_ff_norm,
       post_ff_norm=post_ff_norm,
   )
@@ -129,8 +149,10 @@ def get_activation(config: cfg.ActivationConfig):
     # See: https://github.com/hendrycks/GELUs
     return lambda x: x * F.sigmoid(1.702 * x)
   elif config.type == cfg.ActivationType.GE_GLU:
-    return GeGLU(config.dim_in, config.dim_out)
+    return build_glu(F.gelu, config.gate_is_front)
   elif config.type == cfg.ActivationType.RELU:
     return F.relu
+  elif config.type == cfg.ActivationType.SILU_GLU:
+    return build_glu(F.silu, config.gate_is_front)
   else:
     raise ValueError("Unsupported activation type.")
@@ -30,18 +30,27 @@ class SequentialFeedForward(nn.Module):
       hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
       use_bias=False,
+      use_glu=False,
       pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
       post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
   ):
     """Init function for feedforward layer.

-    Args: dim(int): embedding size. hidden_dim(int): hidden dim size of the
-      feedforward layer. activation(Callable): activation function used in this
-      block. use_bias(Boolean): whether to use bias. Default is false.
+    Args:
+      dim (int): embedding size.
+      hidden_dim (int): hidden dim size of the feedforward layer.
+      activation (Callable): activation function used in this block.
+      use_bias (Boolean): whether to use bias. Default is false.
+      use_glu (Boolean): whether to use glu in activation. Default is false.
+      pre_ff_norm (Callable): pre feedforward norm. Default is None.
+      post_ff_norm (Callable): post feedforward norm. Default is None.
     """
     super().__init__()
     self.act = activation
-    self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
+    if use_glu:
+      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    else:
+      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
     self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
     self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
@@ -72,18 +81,27 @@ class GatedFeedForward(nn.Module):
       hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
       use_bias=False,
+      use_glu=False,
       pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
       post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
   ):
     """Init function for feedforward layer.

-    Args: dim(int): embedding size. hidden_dim(int): hidden dim size of the
-      feedforward layer. activation(Callable): activation function used in this
-      block. use_bias(Boolean): whether to use bias. Default is false.
+    Args:
+      dim (int): embedding size.
+      hidden_dim (int): hidden dim size of the feedforward layer.
+      activation (Callable): activation function used in this block.
+      use_bias (Boolean): whether to use bias. Default is false.
+      use_glu (Boolean): whether to use glu in activation. Default is false.
+      pre_ff_norm (Callable): pre feedforward norm. Default is None.
+      post_ff_norm (Callable): post feedforward norm. Default is None.
     """
     super().__init__()
     self.act = activation
-    self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
+    if use_glu:
+      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    else:
+      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
     self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
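With use_glu=True, w1 projects to 2 * hidden_dim so the GLU activation can split its output into a value half and a gate half, while w2 (and w3 in the gated variant) keep their shapes. A shape walk-through under assumed sizes dim=8, hidden_dim=32:

    import torch
    from torch import nn
    import torch.nn.functional as F
    from ai_edge_torch.generative.layers.builder import build_glu  # defined in the builder.py hunk above

    dim, hidden_dim = 8, 32
    w1 = nn.Linear(dim, hidden_dim * 2)   # use_glu=True: doubled up projection
    w2 = nn.Linear(hidden_dim, dim)
    act = build_glu(F.silu)               # splits the doubled projection in half

    x = torch.randn(1, 4, dim)
    h = act(w1(x))                        # (1, 4, hidden_dim) after the GLU split
    y = w2(h)                             # (1, 4, dim)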
@@ -172,8 +172,8 @@ def _update_kv_base_impl(
     v_slice: torch.Tensor,
 ) -> KVCacheEntry:
   """Update the cache buffer without High Level Function Boundary annotation."""
-  k = cache.k_cache.index_copy(1, input_pos, k_slice)
-  v = cache.v_cache.index_copy(1, input_pos, v_slice)
+  k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
+  v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
   updated_cache = KVCacheEntry(k, v)
   return updated_cache

@@ -189,7 +189,7 @@ def _update_kv_hlfb_impl(
   k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
       cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
   )
-  k = k_cache.index_copy(1, input_pos, k_slice)
-  v = v_cache.index_copy(1, input_pos, v_slice)
+  k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
+  v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
   k, v = builder.mark_outputs(k, v)
   return KVCacheEntry(k, v)
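torch.Tensor.index_copy requires an int64 index tensor, so although the signature inputs are now traced as int32, input_pos is widened back to torch.long just for the cache update. A small self-contained check of that pattern:

    import torch

    cache = torch.zeros(1, 8, 4)                 # (batch, cache_len, head_dim)
    pos = torch.tensor([2, 3], dtype=torch.int)  # int32, as traced at the signature
    update = torch.ones(1, 2, 4)

    # index_copy expects an int64 index, hence the explicit .to(torch.long).
    k = cache.index_copy(1, pos.to(torch.long), update)
    print(k[0, 2:4].sum())  # tensor(8.)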