ai-edge-torch-nightly 0.3.0.dev20240916__py3-none-any.whl → 0.3.0.dev20240919__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (34)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
  3. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
  4. ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
  5. ai_edge_torch/generative/examples/openelm/verify.py +61 -0
  6. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
  7. ai_edge_torch/generative/examples/phi/phi2.py +4 -31
  8. ai_edge_torch/generative/examples/phi/verify.py +53 -0
  9. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
  10. ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
  11. ai_edge_torch/generative/examples/smollm/verify.py +59 -0
  12. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
  13. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
  14. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
  15. ai_edge_torch/generative/examples/tiny_llama/verify.py +61 -0
  16. ai_edge_torch/generative/layers/attention.py +8 -4
  17. ai_edge_torch/generative/layers/builder.py +3 -1
  18. ai_edge_torch/generative/layers/model_config.py +3 -0
  19. ai_edge_torch/generative/layers/normalization.py +31 -20
  20. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +11 -4
  22. ai_edge_torch/generative/layers/unet/model_config.py +3 -0
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +1 -1
  24. ai_edge_torch/generative/utilities/converter.py +82 -0
  25. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +10 -0
  26. ai_edge_torch/generative/utilities/verifier.py +200 -0
  27. ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
  28. ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
  29. ai_edge_torch/version.py +1 -1
  30. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA +1 -1
  31. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD +34 -28
  32. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/LICENSE +0 -0
  33. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/WHEEL +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/smollm_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_smollm_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts SmolLM model to multi-signature tflite model.
 
-  Args:
-      checkpoint_path (str): The filepath to the model checkpoint, or directory
-        holding the checkpoint.
-      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-        Defaults to 512.
-      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-        including both prefill and decode. Defaults to 1024.
-      quantize (bool, optional): Whether the model should be quanized. Defaults
-        to True.
-  """
+def main(_):
   pytorch_model = smollm.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
-  convert_smollm_to_tflite(path)
+  app.run(main)
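The conversion entry points for every example in this release (gemma, gemma2, openelm, phi, smollm, tiny_llama) now share the same shape: absl flags for the conversion options, then a single call into the new ai_edge_torch/generative/utilities/converter.py helper (+82 lines), whose body is not shown in this diff. As a hedged sketch only, reconstructed from the per-example code it replaces, the shared helper plausibly looks something like this; anything beyond the argument names used above is an assumption:

# Hypothetical reconstruction of converter.convert_to_tflite, assembled from
# the removed per-example conversion code; the released helper may differ.
import ai_edge_torch
import torch

from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.quantize import quant_recipes


def convert_to_tflite(
    pytorch_model: torch.nn.Module,
    tflite_path: str,
    prefill_seq_len: int = 512,
    quantize: bool = True,
):
  """Exports a prefill + decode multi-signature TFLite model (sketch)."""
  # Tracing tensors, copied from the removed per-example converters.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
  decode_token = torch.tensor([[0]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int)
  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
          'prefill',
          pytorch_model,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
              'kv_cache': kv,
          },
      )
      .signature(
          'decode',
          pytorch_model,
          sample_kwargs={
              'tokens': decode_token,
              'input_pos': decode_input_pos,
              'kv_cache': kv,
          },
      )
      .convert(quant_config=quant_config)
  )
  edge_model.export(tflite_path)

Under that reading, the example scripts become thin front-ends, and the per-model defaults (sequence length 512, KV cache 1024, int8 dynamic quantization) move into the flags.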
ai_edge_torch/generative/examples/smollm/smollm.py
@@ -16,15 +16,10 @@
 """Example of building a SmolLM model."""
 
 import copy
-import os
-import pathlib
 
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
 from torch import nn
 
 TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
@@ -104,28 +99,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a SmolLM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/smollm"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/smollm/verify.py
@@ -0,0 +1,59 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored SmolLM-135M model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = smollm.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
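The new verify scripts (also added for openelm and phi in this release) delegate the actual check to ai_edge_torch/generative/utilities/verifier.py (+200 lines), which this diff does not show. As a rough mental model only, the core comparison presumably resembles the golden-logits check that the removed define_and_run() helpers used to perform; the sketch below is assembled from that removed code, and the helper name, the single-prompt scope, and the config.kv_cache_max_len attribute are assumptions:

# Hedged sketch of the kind of check verify_reauthored_model likely performs;
# the shipped verifier utility also logs progress and may generate answers.
import torch

from ai_edge_torch.generative.layers import kv_cache as kv_utils


def logits_match(original_model, reauthored_model, tokenizer, prompt, atol=1e-04):
  """Compares next-token logits of the original and reauthored models."""
  input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
  seq_len = input_ids.shape[1]

  # Original HuggingFace model: plain forward pass.
  golden = original_model(input_ids).logits[0, -1, :]

  # Reauthored model: padded token buffer, explicit positions, and a KV cache,
  # mirroring the removed define_and_run() code above.
  kv_len = reauthored_model.config.kv_cache_max_len  # assumed config field
  tokens = torch.zeros((1, kv_len), dtype=torch.int)
  tokens[0, :seq_len] = input_ids.to(torch.int)
  input_pos = torch.arange(0, kv_len, dtype=torch.int)
  kv = kv_utils.KVCache.from_model_config(reauthored_model.config)
  output = reauthored_model.forward(tokens, input_pos, kv)
  reauthored = output["logits"][0, seq_len - 1, :]

  return torch.allclose(golden, reauthored, atol=atol)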
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -336,6 +336,8 @@ class Diffusion(nn.Module):
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=output_channel,
+              output_dim=output_channel,
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
@@ -406,6 +408,8 @@ class Diffusion(nn.Module):
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=mid_block_channels,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=mid_block_channels,
+              output_dim=mid_block_channels,
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
@@ -477,6 +481,8 @@ class Diffusion(nn.Module):
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
               cross_dim=config.transformer_cross_attention_dim,
+              hidden_dim=output_channel,
+              output_dim=output_channel,
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=build_attention_config(
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/tiny_llama_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_tiny_llama_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts TinyLlama model to multi-signature tflite model.
 
-  Args:
-      checkpoint_path (str): The filepath to the model checkpoint, or directory
-        holding the checkpoint.
-      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-        Defaults to 512.
-      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-        including both prefill and decode. Defaults to 1024.
-      quantize (bool, optional): Whether the model should be quanized. Defaults
-        to True.
-  """
+def main(_):
   pytorch_model = tiny_llama.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
-  convert_tiny_llama_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -15,16 +15,12 @@
 
 """Example of building a TinyLlama model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -179,28 +175,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a TinyLlama model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/tiny_llama/verify.py
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored TinyLlama-1.1B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Show me the program to add 2 and 3.",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/layers/attention.py
@@ -298,6 +298,8 @@ class CrossAttention(nn.Module):
       batch_size: int,
       query_dim: int,
       cross_dim: int,
+      hidden_dim: int,
+      output_dim: int,
       config: cfg.AttentionConfig,
       enable_hlfb: bool,
   ):
@@ -307,6 +309,8 @@ class CrossAttention(nn.Module):
       batch_size (int): batch size of the input tensor.
       query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
+      hidden_dim (int): hidden dimension that q, k, v tensors project to.
+      output_dim (int): output tensor's dimension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
@@ -314,16 +318,16 @@ class CrossAttention(nn.Module):
     self.config = config
     self.n_heads = config.num_heads
     self.q_projection = nn.Linear(
-        query_dim, query_dim, bias=config.qkv_use_bias
+        query_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.k_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
    )
     self.v_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
    )
     self.output_projection = nn.Linear(
-        query_dim, query_dim, bias=config.output_proj_use_bias
+        hidden_dim, output_dim, bias=config.output_proj_use_bias
    )
 
     self.sdpa_func = (
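For context, CrossAttention previously projected q, k, v, and the output all to query_dim; the new hidden_dim and output_dim arguments decouple the attention width from the query and output widths, which is what the CrossAttentionBlock2DConfig changes in diffusion.py above feed in. A small self-contained shape check of the new projection layout (the concrete dimensions are invented for illustration):

# Illustrative only: mirrors the four projections in the updated CrossAttention.
import torch
from torch import nn

batch, q_len, kv_len = 1, 64, 77
query_dim, cross_dim, hidden_dim, output_dim = 320, 768, 320, 320

q_projection = nn.Linear(query_dim, hidden_dim)
k_projection = nn.Linear(cross_dim, hidden_dim)
v_projection = nn.Linear(cross_dim, hidden_dim)
output_projection = nn.Linear(hidden_dim, output_dim)

x = torch.randn(batch, q_len, query_dim)   # queries (e.g. image latents)
y = torch.randn(batch, kv_len, cross_dim)  # context (e.g. text embeddings)

q, k, v = q_projection(x), k_projection(y), v_projection(y)
assert q.shape == (batch, q_len, hidden_dim)
assert k.shape == v.shape == (batch, kv_len, hidden_dim)

# Attention over (q, k, v) yields a (batch, q_len, hidden_dim) tensor, which the
# output projection maps to output_dim.
attended = torch.randn(batch, q_len, hidden_dim)
out = output_projection(attended)
assert out.shape == (batch, q_len, output_dim)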
ai_edge_torch/generative/layers/builder.py
@@ -75,7 +75,9 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
         zero_centered_gamma=config.zero_centered,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
+    return normalization.LayerNorm(
+        dim, config.epsilon, config.enable_hlfb, config.use_input_shape
+    )
   elif config.type == cfg.NormalizationType.GROUP_NORM:
     return normalization.GroupNorm(
         config.group_num, dim, config.epsilon, config.enable_hlfb
ai_edge_torch/generative/layers/model_config.py
@@ -69,6 +69,9 @@ class NormalizationConfig:
   zero_centered: bool = False
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
+  # Whether to use the input shape to determine the dimension of normalization
+  # when type is LAYER_NORM.
+  use_input_shape: bool = True
 
 
 @dataclass
ai_edge_torch/generative/layers/normalization.py
@@ -78,7 +78,7 @@ class GroupNorm(torch.nn.Module):
       group_num (int): Number of groups to separate the channels into.
       dim (int): Dimension of the input tensor.
       eps (float): A small float value to ensure numerical stability (default:
-        1e-6).
+        1e-5).
       enable_hlfb (bool): Whether to convert this normalization into a single
         op.
     """
@@ -112,7 +112,13 @@ class GroupNorm(torch.nn.Module):
 
 
 class LayerNorm(torch.nn.Module):
 
-  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+  def __init__(
+      self,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+      use_input_shape: bool = True,
+  ):
     """Initialize the LayerNorm layer.
 
     Args:
@@ -121,9 +127,12 @@ class LayerNorm(torch.nn.Module):
         1e-6).
       enable_hlfb (bool): Whether to convert this normalization into a single
         op.
+      use_input_shape (bool): Whether to use the input shape to determine the
+        dimension of normalization (default: True).
     """
     super().__init__()
     self.enable_hlfb = enable_hlfb
+    self.use_input_shape = use_input_shape
     self.eps = eps
     self.weight = torch.nn.Parameter(torch.ones(dim))
     self.bias = torch.nn.Parameter(torch.ones(dim))
@@ -139,19 +148,18 @@
     """
     if self.enable_hlfb:
       return layer_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.eps,
+          x, self.weight, self.bias, self.eps, self.use_input_shape
       )
+
+    if self.use_input_shape:
+      normalized_shape = x.shape
+      weight = self.weight.broadcast_to(x.shape)
+      bias = self.bias.broadcast_to(x.shape)
     else:
-      return F.layer_norm(
-          x,
-          x.shape,
-          self.weight.broadcast_to(x.shape),
-          self.bias.broadcast_to(x.shape),
-          self.eps,
-      )
+      normalized_shape = self.weight.shape
+      weight = self.weight
+      bias = self.bias
+    return F.layer_norm(x, normalized_shape, weight, bias, self.eps)
 
 
 def group_norm_with_hlfb(
@@ -193,6 +201,7 @@ def layer_norm_with_hlfb(
     w: torch.Tensor,
     b: torch.Tensor,
     eps: float,
+    use_input_shape: bool,
 ):
   """Layer Normalization with high-level function boundary enabled.
 
@@ -201,18 +210,20 @@
     w (torch.Tensor): The weight tensor for the normalization.
     b (torch.Tensor): The bias tensor for the normalization.
     eps (float): A small float value to ensure numerical stability.
+    use_input_shape (bool): Whether to use the input shape to determine the
+      dimension of normalization.
 
   Returns:
     The output tensor of Layer Normalization.
   """
   builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
   x, w, b = builder.mark_inputs(x, w, b)
-  y = F.layer_norm(
-      x,
-      x.shape,
-      weight=w.broadcast_to(x.shape),
-      bias=b.broadcast_to(x.shape),
-      eps=eps,
-  )
+  if use_input_shape:
+    normalized_shape = x.shape
+    w = w.broadcast_to(x.shape)
+    b = b.broadcast_to(x.shape)
+  else:
+    normalized_shape = w.shape
+  y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
   y = builder.mark_outputs(y)
   return y
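In short, use_input_shape=True keeps the previous behavior (normalize over the whole input shape with the per-channel weight and bias broadcast to it), while use_input_shape=False gives a standard last-dimension LayerNorm. A minimal standalone illustration of the two F.layer_norm calls the module now selects between, with arbitrary shapes:

# Demonstrates the two normalized_shape choices behind use_input_shape.
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)  # (batch, tokens, channels)
weight = torch.ones(8)
bias = torch.zeros(8)
eps = 1e-5

# use_input_shape=True (default, previous behavior): statistics are computed
# over the entire input, and weight/bias are broadcast to the input shape.
y_full = F.layer_norm(
    x, x.shape, weight.broadcast_to(x.shape), bias.broadcast_to(x.shape), eps
)

# use_input_shape=False: standard LayerNorm over the trailing weight dimensions.
y_last = F.layer_norm(x, weight.shape, weight, bias, eps)

assert y_full.shape == y_last.shape == x.shape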
ai_edge_torch/generative/layers/scaled_dot_product_attention.py
@@ -119,15 +119,25 @@ def scaled_dot_product_attention_with_hlfb(
   # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
   k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
   v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
+  if softcap is None:
+    y = F.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=mask,
+        dropout_p=0.0,
+        is_causal=mask is None,
+        scale=scale,
+    )
+  else:
+    q.mul_(scale)
+    scores = q @ k.transpose(-1, -2)
+    scores = scores / softcap
+    scores = torch.tanh(scores)
+    scores = scores * softcap
+    scores = scores + mask
+    out = F.softmax(scores.float(), dim=-1).type_as(q)
+    y = torch.matmul(out, v)
 
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
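scaled_dot_product_attention_with_hlfb now accepts an optional softcap. When it is set, the fused F.scaled_dot_product_attention call is bypassed in favor of an explicit path that scales q, soft-caps the attention logits with tanh, adds the mask, and applies softmax in float32; this is the logit soft-capping scheme used by models such as Gemma 2 (note the companion convert_gemma2_to_tflite.py change in this release). A standalone sketch of that branch, with invented shapes and softcap value:

# Functional sketch of the soft-capped attention branch added above.
import torch
import torch.nn.functional as F


def softcap_attention(q, k, v, mask, scale, softcap):
  """q, k, v: (batch, heads, seq, head_dim); mask: additive attention mask."""
  q = q * scale  # the library code scales in place with q.mul_(scale)
  scores = q @ k.transpose(-1, -2)
  scores = torch.tanh(scores / softcap) * softcap  # cap logits to (-softcap, softcap)
  scores = scores + mask
  probs = F.softmax(scores.float(), dim=-1).type_as(q)
  return probs @ v


# Toy usage.
b, h, s, d = 1, 4, 16, 32
q = torch.randn(b, h, s, d)
k = torch.randn(b, h, s, d)
v = torch.randn(b, h, s, d)
mask = torch.zeros(b, 1, s, s)  # a causal mask would instead put -inf above the diagonal
out = softcap_attention(q, k, v, mask, scale=d**-0.5, softcap=50.0)
assert out.shape == (b, h, s, d)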