ai-edge-torch-nightly 0.3.0.dev20240813__py3-none-any.whl → 0.3.0.dev20240817__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (32)
  1. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +2 -2
  2. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +2 -2
  3. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +67 -0
  4. ai_edge_torch/generative/examples/gemma/gemma.py +3 -2
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +250 -0
  6. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -2
  7. ai_edge_torch/generative/examples/t5/t5.py +4 -4
  8. ai_edge_torch/generative/examples/t5/t5_attention.py +3 -3
  9. ai_edge_torch/generative/examples/test_models/toy_model.py +1 -1
  10. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +1 -1
  11. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +1 -1
  12. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
  13. ai_edge_torch/generative/layers/attention.py +12 -5
  14. ai_edge_torch/generative/layers/attention_utils.py +30 -0
  15. ai_edge_torch/generative/layers/builder.py +5 -0
  16. ai_edge_torch/generative/layers/feed_forward.py +15 -3
  17. ai_edge_torch/generative/layers/model_config.py +35 -13
  18. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +25 -9
  19. ai_edge_torch/generative/test/test_model_conversion.py +29 -1
  20. ai_edge_torch/generative/utilities/loader.py +29 -7
  21. ai_edge_torch/generative/utilities/t5_loader.py +8 -8
  22. ai_edge_torch/hlfb/test/test_mark_pattern.py +32 -8
  23. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +5 -0
  24. ai_edge_torch/lowertools/__init__.py +1 -0
  25. ai_edge_torch/lowertools/odml_torch_utils.py +3 -0
  26. ai_edge_torch/lowertools/test_utils.py +60 -0
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240817.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240817.dist-info}/RECORD +32 -29
  30. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240817.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240817.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240817.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/experimental/gemma/gemma.py
@@ -40,7 +40,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  attn_value_proj="model.layers.{}.self_attn.v_proj",
  attn_output_proj="model.layers.{}.self_attn.o_proj",
  pre_attn_norm="model.layers.{}.input_layernorm",
- pre_ff_norm="model.layers.{}.post_attention_layernorm",
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
  embedding="model.embed_tokens",
  final_norm="model.norm",
  lm_head=None,
@@ -150,7 +150,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  parallel_residual=False,
  lm_head_use_bias=False,

ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py
@@ -41,7 +41,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  attn_value_proj="model.layers.{}.self_attn.v_proj",
  attn_output_proj="model.layers.{}.self_attn.o_proj",
  pre_attn_norm="model.layers.{}.input_layernorm",
- pre_ff_norm="model.layers.{}.post_attention_layernorm",
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
  embedding="model.embed_tokens",
  final_norm="model.norm",
  lm_head="lm_head",
@@ -142,7 +142,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  enable_hlfb=True,
  )

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py (new file)
@@ -0,0 +1,67 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import os
+ from pathlib import Path
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.examples.gemma import gemma2
+ from ai_edge_torch.generative.quantize import quant_recipes
+ import torch
+
+
+ def convert_gemma_to_tflite(
+     checkpoint_path: str,
+     prefill_seq_len: int = 512,
+     kv_cache_max_len: int = 1024,
+     quantize: bool = True,
+ ):
+   """Converting a Gemma 2 2B model to multi-signature
+   tflite model.
+
+   Args:
+     checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+       Defaults to 512.
+     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+       including both prefill and decode. Defaults to 1024.
+     quantize (bool, optional): Whether the model should be quanized.
+       Defaults to True.
+   """
+   pytorch_model = gemma2.build_2b_model(
+       checkpoint_path, kv_cache_max_len=kv_cache_max_len
+   )
+   # Tensors used to trace the model graph during conversion.
+   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+   prefill_input_pos = torch.arange(0, prefill_seq_len)
+   decode_token = torch.tensor([[0]], dtype=torch.long)
+   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+
+   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+   edge_model = (
+       ai_edge_torch.signature(
+           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+       )
+       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
+       .convert(quant_config=quant_config)
+   )
+   edge_model.export(
+       f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+   )
+
+
+ if __name__ == '__main__':
+   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
+   convert_gemma_to_tflite(checkpoint_path)
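
Note: the script above exports a single .tflite file with two signatures, 'prefill' and 'decode'. As a rough illustration only (not part of this diff, assuming TensorFlow is installed and the default export path above), the exported signatures can be inspected and driven with the TF Lite interpreter; the exact input names depend on the exported signature, so check get_signature_list() before calling:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path='/tmp/gemma2_seq512_kv1024.tflite'
)
# Lists the 'prefill' and 'decode' signatures with their input/output names.
print(interpreter.get_signature_list())

prefill = interpreter.get_signature_runner('prefill')
tokens = np.zeros((1, 512), dtype=np.int64)
input_pos = np.arange(512, dtype=np.int64)
# Call with keyword arguments matching the names printed above, e.g.:
# logits = prefill(tokens=tokens, input_pos=input_pos)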

ai_edge_torch/generative/examples/gemma/gemma.py
@@ -35,7 +35,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  attn_value_proj="model.layers.{}.self_attn.v_proj",
  attn_output_proj="model.layers.{}.self_attn.o_proj",
  pre_attn_norm="model.layers.{}.input_layernorm",
- pre_ff_norm="model.layers.{}.post_attention_layernorm",
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
  embedding="model.embed_tokens",
  final_norm="model.norm",
  lm_head=None,
@@ -138,7 +138,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  parallel_residual=False,
  lm_head_use_bias=False,
@@ -160,6 +160,7 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
  # since embedding and lm-head use the same weight, we need to set strict
  # to False.
  loader.load(model, strict=False)
+ model.eval()
  return model



ai_edge_torch/generative/examples/gemma/gemma2.py (new file)
@@ -0,0 +1,250 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Example of building the Gemma2 2B model.
+
+ import os
+ from pathlib import Path
+ from typing import Optional, Tuple
+
+ from ai_edge_torch.generative.layers.attention import TransformerBlock
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.builder as builder
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.up_proj",
+     ff_down_proj="model.layers.{}.mlp.down_proj",
+     ff_gate_proj="model.layers.{}.mlp.gate_proj",
+     attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
+     attn_output_proj="model.layers.{}.self_attn.o_proj",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     post_attn_norm="model.layers.{}.post_attention_layernorm",
+     pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
+     post_ff_norm="model.layers.{}.post_feedforward_layernorm",
+     embedding="embedder",
+     final_norm="model.norm",
+     lm_head=None,
+ )
+
+
+ class Gemma2Block(TransformerBlock):
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       mask: Optional[torch.Tensor] = None,
+       input_pos: Optional[torch.Tensor] = None,
+   ) -> torch.Tensor:
+     """Forward function of the Gemma2Block.
+
+     Exactly the same as TransformerBlock but we call the post-attention norm
+     immediately after attention and not after the residual pointwise addition.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+       mask (torch.Tensor): the optional mask tensor.
+       input_pos (torch.Tensor): the optional input position tensor.
+
+     Returns:
+       output activation from this transformer block.
+     """
+
+     x_norm = self.pre_atten_norm(x)
+     attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+     attn_out_norm = self.post_atten_norm(attn_out)
+     x = x + attn_out_norm
+     output = x + self.ff(x)
+     return output
+
+
+ class Gemma2(nn.Module):
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__()
+
+     self.config = config
+     # Construct model layers.
+     self.tok_embedding = nn.Embedding(
+         config.vocab_size, config.embedding_dim, padding_idx=0
+     )
+     self.lm_head = nn.Linear(
+         config.embedding_dim,
+         config.vocab_size,
+         bias=config.lm_head_use_bias,
+     )
+     # Gemma re-uses the embedding as the head projection layer.
+     self.lm_head.weight.data = self.tok_embedding.weight.data
+     self.transformer_blocks = nn.ModuleList(
+         Gemma2Block(config) for _ in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     self.rope_cache = attn_utils.build_rope_cache(
+         size=config.kv_cache_max,
+         dim=int(
+             config.attn_config.rotary_percentage * config.attn_config.head_dim
+         ),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.kv_cache_max,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+     )
+
+     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
+         size=config.kv_cache_max,
+         window_size=self.config.attn_config.sliding_window_size,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+     )
+
+     self.config = config
+
+   def get_attention_mask(
+       self, idx: int, input_pos: torch.Tensor
+   ) -> torch.Tensor:
+     if self.config.attn_config.attn_types:
+       if (
+           self.config.attn_config.attn_types[idx]
+           == cfg.AttentionType.LOCAL_SLIDING
+       ):
+         return self.sliding_window_mask_cache.index_select(2, input_pos)
+
+     return self.mask_cache.index_select(2, input_pos)
+
+   @torch.inference_mode
+   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+     B, T = idx.size()
+     assert self.config.max_seq_len >= T, (
+         f"Cannot forward sequence of length {T}, max seq length is only"
+         f" {self.config.max_seq_len}"
+     )
+
+     cos, sin = self.rope_cache
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+
+     # token embeddings of shape (b, t, n_embd)
+     x = self.tok_embedding(idx)
+     x = x * (self.config.embedding_dim**0.5)
+
+     for i, block in enumerate(self.transformer_blocks):
+       mask = self.get_attention_mask(i, input_pos)
+       x = block(x, (cos, sin), mask, input_pos)
+
+     x = self.final_norm(x)
+     res = self.lm_head(x)  # (b, t, vocab_size)
+     if self.config.final_logit_softcap is not None:
+       res = res / self.config.final_logit_softcap
+       res = torch.tanh(res)
+       res = res * self.config.final_logit_softcap
+     return res
+
+
+ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   attn_config = cfg.AttentionConfig(
+       num_heads=8,
+       head_dim=256,
+       num_query_groups=4,
+       rotary_percentage=1.0,
+       qkv_transpose_before_split=True,
+       logit_softcap=50.0,
+       sliding_window_size=4096,
+       attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
+       * 13,
+   )
+
+   norm_config = cfg.NormalizationConfig(
+       type=cfg.NormalizationType.RMS_NORM,
+       epsilon=1e-6,
+       zero_centered=True,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+       intermediate_size=9216,
+       pre_ff_norm_config=norm_config,
+       post_ff_norm_config=norm_config,
+   )
+   config = cfg.ModelConfig(
+       vocab_size=256000,
+       num_layers=26,
+       max_seq_len=8192,
+       embedding_dim=2304,
+       kv_cache_max_len=kv_cache_max_len,
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       post_attention_norm_config=norm_config,
+       final_norm_config=norm_config,
+       parallel_residual=False,
+       lm_head_use_bias=False,
+       enable_hlfb=False,
+       final_logit_softcap=30.0,
+   )
+   return config
+
+
+ def get_fake_model_config_2b_for_test() -> cfg.ModelConfig:
+   config = get_model_config_2b()
+   config.num_layers = 2
+   return config
+
+
+ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+   config = get_model_config_2b(**kwargs)
+   model = Gemma2(config)
+   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+   # since embedding and lm-head use the same weight, we need to set strict
+   # to False.
+   loader.load(model, strict=False)
+   model.eval()
+   return model
+
+
+ def define_and_run_2b() -> None:
+   current_dir = Path(__file__).parent.resolve()
+   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
+   print("Running GEMMA 2")
+   kv_cache_max_len = 1024
+   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
+   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+   toks = torch.from_numpy(
+       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
+   )
+   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+   tokens[0, :9] = toks
+   input_pos = torch.arange(0, kv_cache_max_len)
+   out = model.forward(tokens, input_pos)
+   out_final = out[0, 8, :]
+   assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
+   print(out)
+
+
+ if __name__ == "__main__":
+   torch.set_printoptions(sci_mode=True)
+   define_and_run_2b()
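
Note: in Gemma2.forward above, the final logits are soft-capped with a tanh before being returned (final_logit_softcap=30.0 in the 2B config; the per-layer attention scores use the same trick with logit_softcap=50.0). The three inline lines are equivalent to this small helper, shown as a sketch for illustration only and not part of the package:

import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
  # Scale down, squash with tanh, scale back up: values stay in (-cap, cap)
  # while small logits pass through nearly unchanged.
  return cap * torch.tanh(logits / cap)

# soft_cap(res, 30.0) matches the res / cap -> tanh -> * cap sequence above.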

ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -35,7 +35,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  pre_attn_norm=(
  "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1"
  ),
- pre_ff_norm=(
+ post_attn_norm=(
  "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2"
  ),
  embedding=(
@@ -120,7 +120,7 @@ def get_model_config() -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  enable_hlfb=True,
  )

ai_edge_torch/generative/examples/t5/t5.py
@@ -38,7 +38,7 @@ ENCDEC_TENSOR_NAMES = {
  "{prefix}.block.0.layer.0.SelfAttention.relative_attention_bias"
  ),
  "pre_attn_norm": "{prefix}.block.{}.layer.0.layer_norm",
- "pre_ff_norm": "{prefix}.block.{}.layer.1.layer_norm",
+ "post_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
  "final_norm": "{prefix}.final_layer_norm",
  }

@@ -396,7 +396,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
  relative_attention=True,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  parallel_residual=False,
  lm_head_use_bias=False,
@@ -419,7 +419,7 @@ def build_t5_model(checkpoint_path: str) -> nn.Module:
  "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
  "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
  # In the decoder, the FF is layer 2 in the Transformer block
- "pre_ff_norm": "{prefix}.block.{}.layer.2.layer_norm",
+ "post_attn_norm": "{prefix}.block.{}.layer.2.layer_norm",
  # In the decoder, the cross attention is layer 1 in the Transformer block
  "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
  }
@@ -475,7 +475,7 @@ def build_t5_decoder_model(
  "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
  "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
  # In the decoder, the FF is layer 2 in the Transformer block
- "pre_ff_norm": "{prefix}.block.{}.layer.2.layer_norm",
+ "post_attn_norm": "{prefix}.block.{}.layer.2.layer_norm",
  # In the decoder, the cross attention is layer 1 in the Transformer block
  "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
  }

ai_edge_torch/generative/examples/t5/t5_attention.py
@@ -68,8 +68,8 @@ class EncoderDecoderBlock(nn.Module):
  else:
  self.cross_atten_func = None

- self.pre_ff_norm = builder.build_norm(
- config.embedding_dim, config.pre_ff_norm_config
+ self.post_atten_norm = builder.build_norm(
+ config.embedding_dim, config.post_attention_norm_config
  )
  self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
  self.config = config
@@ -118,7 +118,7 @@ class EncoderDecoderBlock(nn.Module):
  )
  attn_out = hidden_states + attn_out

- forwarded = self.pre_ff_norm(attn_out)
+ forwarded = self.post_atten_norm(attn_out)
  forwarded = self.ff(forwarded)
  hidden_states = attn_out + forwarded


ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -93,7 +93,7 @@ def get_model_config() -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  )
  return config

ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py
@@ -107,7 +107,7 @@ def get_model_config() -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  enable_hlfb=True,
  )

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -94,7 +94,7 @@ def get_model_config() -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  enable_hlfb=True,
  )

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -35,7 +35,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  attn_value_proj="model.layers.{}.self_attn.v_proj",
  attn_output_proj="model.layers.{}.self_attn.o_proj",
  pre_attn_norm="model.layers.{}.input_layernorm",
- pre_ff_norm="model.layers.{}.post_attention_layernorm",
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
  embedding="model.embed_tokens",
  final_norm="model.norm",
  lm_head="lm_head",
@@ -130,7 +130,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  enable_hlfb=True,
  )

ai_edge_torch/generative/layers/attention.py
@@ -74,8 +74,8 @@ class TransformerBlock(nn.Module):
  config.kv_cache_max,
  config.enable_hlfb,
  )
- self.pre_ff_norm = builder.build_norm(
- config.embedding_dim, config.pre_ff_norm_config
+ self.post_atten_norm = builder.build_norm(
+ config.embedding_dim, config.post_attention_norm_config
  )
  self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
  self.config = config
@@ -108,7 +108,7 @@ class TransformerBlock(nn.Module):
  x_norm = self.pre_atten_norm(x)
  attn_out = self.atten_func(x_norm, rope, mask, input_pos)
  x = x + attn_out
- x_norm = self.pre_ff_norm(x)
+ x_norm = self.post_atten_norm(x)
  output = x + self.ff(x_norm)

  return output
@@ -228,8 +228,15 @@ class CausalSelfAttention(nn.Module):
  # TODO(haoliang): Handle when execeeding max sequence length.
  k, v = self.kv_cache.update_cache(input_pos, k, v)

- y = self.sdpa_func(q, k, v, self.config.head_dim, mask=mask)
- y = y.reshape(B, T, E)
+ y = self.sdpa_func(
+ q,
+ k,
+ v,
+ self.config.head_dim,
+ mask=mask,
+ softcap=self.config.logit_softcap,
+ )
+ y = y.reshape(B, T, -1)

  # Compute the output projection.
  y = self.output_projection(y)
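
Note: CausalSelfAttention now forwards config.logit_softcap to sdpa_func; the capping itself lives in scaled_dot_product_attention.py (+25 -9 in this release, hunk not shown here). For orientation only, here is a generic sketch of how a logit softcap is typically folded into scaled dot-product attention; it is not claimed to match the library's implementation line for line:

import math
from typing import Optional
import torch

def sdpa_with_softcap(
    q: torch.Tensor,  # (B, n_heads, T, head_dim)
    k: torch.Tensor,
    v: torch.Tensor,
    head_dim: int,
    mask: Optional[torch.Tensor] = None,
    softcap: Optional[float] = None,
) -> torch.Tensor:
  scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
  if softcap is not None:
    # Cap the attention logits before masking and softmax.
    scores = softcap * torch.tanh(scores / softcap)
  if mask is not None:
    scores = scores + mask
  return torch.softmax(scores, dim=-1) @ v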

ai_edge_torch/generative/layers/attention_utils.py
@@ -74,12 +74,42 @@ def build_causal_mask_cache(
  Returns:
  torch.Tensor: Causal attention mask.
  """
+
  if device is None:
  device = torch.device('cpu')
  mask = torch.full((size, size), float('-inf'), dtype=dtype, device=device)
  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)


+ def build_sliding_window_mask_cache(
+ size: int,
+ window_size: int,
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = None,
+ ) -> torch.Tensor:
+ """Build a cache for a sliding window mask.
+
+ Args:
+ size (int): The size of the built mask cache.
+ window_size (int): The window size that is "seen" by a token.
+ dtype (torch.dtype, optional): Output tensor's data type. Defaults to
+ torch.float32.
+ device (torch.device, optional): Output tensor's data type. Defaults to
+ None in which case "cpu" is used.
+
+ Returns:
+ torch.Tensor: Causal attention mask.
+ """
+
+ mask = build_causal_mask_cache(size, dtype, device)
+ all_ones = torch.ones_like(mask)
+ window_size = min(size, window_size)
+ sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
+ all_ones, window_size - 1
+ )
+ return torch.where(sliding_mask == 1, mask, -2.3819763e38)
+
+
  def relative_position_bucket(
  relative_position: torch.Tensor,
  bidirectional: bool,
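
Note: a quick way to see what build_sliding_window_mask_cache produces (illustrative values only; Gemma2 above builds it with size=kv_cache_max and window_size=4096):

import ai_edge_torch.generative.layers.attention_utils as attn_utils

mask = attn_utils.build_sliding_window_mask_cache(size=6, window_size=3)
print(mask.shape)  # torch.Size([1, 1, 6, 6])
# mask[0, 0] is 0.0 where a token may attend (itself plus the previous
# window_size - 1 positions) and a large negative number everywhere else,
# so those positions vanish after the softmax.
print(mask[0, 0])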

ai_edge_torch/generative/layers/builder.py
@@ -89,11 +89,16 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):

  activation = get_activation(config.activation)

+ pre_ff_norm = build_norm(dim, config.pre_ff_norm_config)
+ post_ff_norm = build_norm(dim, config.post_ff_norm_config)
+
  return ff_module(
  dim=dim,
  hidden_dim=config.intermediate_size,
  activation=activation,
  use_bias=config.use_bias,
+ pre_ff_norm=pre_ff_norm,
+ post_ff_norm=post_ff_norm,
  )



ai_edge_torch/generative/layers/feed_forward.py
@@ -14,7 +14,7 @@
  # ==============================================================================
  # Common building blocks for FeedForward layers.

- from typing import Callable
+ from typing import Callable, Optional

  import torch
  from torch import nn
@@ -30,6 +30,8 @@ class SequentialFeedForward(nn.Module):
  hidden_dim: int,
  activation: Callable[[torch.Tensor], torch.Tensor],
  use_bias=False,
+ pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
  ):
  """Init function for feedforward layer.

@@ -41,6 +43,8 @@ class SequentialFeedForward(nn.Module):
  self.act = activation
  self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
  self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
+ self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
+ self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x

  def forward(self, x):
  """Forward pass for Feedforward layer.
@@ -51,7 +55,9 @@ class SequentialFeedForward(nn.Module):
  Returns:
  torch.Tensor: output tensor after feedforward.
  """
- return self.w2(self.act(self.w1(x)))
+ x_norm = self.pre_ff_norm(x)
+ out = self.w2(self.act(self.w1(x_norm)))
+ return self.post_ff_norm(out)


  class GatedFeedForward(nn.Module):
@@ -66,6 +72,8 @@ class GatedFeedForward(nn.Module):
  hidden_dim: int,
  activation: Callable[[torch.Tensor], torch.Tensor],
  use_bias=False,
+ pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
  ):
  """Init function for feedforward layer.

@@ -78,6 +86,8 @@ class GatedFeedForward(nn.Module):
  self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
  self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
  self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
+ self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
+ self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x

  def forward(self, x):
  """Forward pass for Feedforward layer.
@@ -88,4 +98,6 @@ class GatedFeedForward(nn.Module):
  Returns:
  torch.Tensor: output tensor after feedforward.
  """
- return self.w2(self.act(self.w1(x)) * self.w3(x))
+ x_norm = self.pre_ff_norm(x)
+ out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
+ return self.post_ff_norm(out)
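
Note: both feed-forward variants now accept optional pre_ff_norm / post_ff_norm callables and fall back to identity lambdas, so call sites that pass nothing keep the previous behavior; builder.build_ff (above) is what wires the configured norms in. A minimal stand-alone sketch of the new constructor arguments, using torch.nn.LayerNorm as a stand-in for the builder-produced RMSNorm:

import torch
from torch import nn
from ai_edge_torch.generative.layers.feed_forward import GatedFeedForward

x = torch.randn(2, 4, 8)

# No norms supplied: computes w2(act(w1(x)) * w3(x)) exactly as before.
ff = GatedFeedForward(dim=8, hidden_dim=16, activation=nn.functional.silu)
print(ff(x).shape)  # torch.Size([2, 4, 8])

# With norms supplied, the block becomes post_ff_norm(core(pre_ff_norm(x))).
ff_normed = GatedFeedForward(
    dim=8,
    hidden_dim=16,
    activation=nn.functional.silu,
    pre_ff_norm=nn.LayerNorm(8),
    post_ff_norm=nn.LayerNorm(8),
)
print(ff_normed(x).shape)  # torch.Size([2, 4, 8])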