ai-edge-torch-nightly 0.2.0.dev20240604__py3-none-any.whl → 0.2.0.dev20240606__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic; see the package registry's release page for more details.

@@ -0,0 +1,161 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # A toy example which has basic transformer block (w/ externalized KV-Cache).
16
+
17
+ from typing import List, Tuple
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch_xla
23
+
24
+ import ai_edge_torch
25
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
26
+ import ai_edge_torch.generative.layers.builder as builder
27
+ from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
28
+ import ai_edge_torch.generative.layers.model_config as cfg
29
+
30
+ RoPECache = Tuple[torch.Tensor, torch.Tensor]
31
+
32
+
33
class ToyModelWithExternalKV(torch.nn.Module):
  """Toy decoder-only transformer whose KV caches are explicit I/O.

  Instead of owning KV-cache buffers, the model receives the per-layer key and
  value caches as `forward` arguments and returns them, so that an exported
  graph carries the caches as signature inputs/outputs.
  """

  def __init__(self, config: cfg.ModelConfig) -> None:
    """Builds embeddings, transformer blocks, final norm and LM head.

    Args:
      config (cfg.ModelConfig): hyper-parameters of the toy model.
    """
    super().__init__()
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
    self.transformer_blocks = nn.ModuleList(
        TransformerBlock(config) for _ in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    # Precomputed (cos, sin) RoPE tables for positions up to max_seq_len;
    # forward() gathers the rows it needs with `input_pos`.
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.max_seq_len,
        dim=int(config.attn_config.rotary_percentage * config.head_dim),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device('cpu'),
    )
    # Precomputed causal mask; forward() selects rows along dim 2 by position.
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
    )
    self.config = config

  def forward(
      self,
      idx: torch.Tensor,
      input_pos: torch.Tensor,
      k_caches: torch.Tensor,
      v_caches: torch.Tensor,
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Runs the model over `idx` using externally supplied KV caches.

    Args:
      idx (torch.Tensor): token ids to embed (a [1, T] tensor in this example).
      input_pos (torch.Tensor): 1-D positions of the tokens in `idx`, used to
        gather rows of the RoPE and causal-mask caches.
      k_caches (torch.Tensor): stacked key caches; `k_caches[i]` is handed to
        the i-th transformer block.
      v_caches (torch.Tensor): stacked value caches, indexed like `k_caches`.

    Returns:
      A tuple `(logits, k_caches, v_caches)`. The returned cache tensors are
      the same objects that were passed in: each block's updated cache is
      written back into them in place.
    """
    x = self.tok_embedding(idx)
    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
    sin = sin.index_select(0, input_pos)
    mask = self.mask_cache.index_select(2, input_pos)
    # Keep at most max_seq_len key positions.
    mask = mask[:, :, :, : self.config.max_seq_len]

    for i, block in enumerate(self.transformer_blocks):
      input_k, input_v = k_caches[i], v_caches[i]
      x, (updated_k, updated_v) = block(
          x, (cos, sin), mask, input_pos, (input_k, input_v)
      )
      # Write back in place, so the caller's cache tensors see the update.
      k_caches[i], v_caches[i] = updated_k, updated_v

    x = self.final_norm(x)
    return self.lm_head(x), k_caches, v_caches
84
+
85
+
86
def _export_stablehlo_mlir(model, args):
  """Exports `model` with sample `args` and returns its StableHLO MLIR text.

  The model is captured with `torch.export.export` and the resulting
  ExportedProgram is lowered through torch_xla's StableHLO bridge.
  """
  exported_program = torch.export.export(model, args)
  return torch_xla.stablehlo.exported_program_to_stablehlo(
      exported_program
  ).get_stablehlo_text()
90
+
91
+
92
def get_model_config() -> cfg.ModelConfig:
  """Returns the small two-layer configuration used by this toy example.

  Attention uses 32 query heads grouped into 4 query groups. The feed-forward
  is a gated SiLU layer, and one RMSNorm config is shared by every norm.
  """
  shared_norm = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
  return cfg.ModelConfig(
      vocab_size=150,
      num_layers=2,
      max_seq_len=100,
      embedding_dim=128,
      attn_config=cfg.AttentionConfig(
          num_heads=32, num_query_groups=4, rotary_percentage=1.0
      ),
      ff_config=cfg.FeedForwardConfig(
          type=cfg.FeedForwardType.GATED,
          activation=cfg.ActivationType.SILU,
          intermediate_size=256,
      ),
      pre_attention_norm_config=shared_norm,
      pre_ff_norm_config=shared_norm,
      final_norm_config=shared_norm,
      enable_hlfb=True,
  )
115
+
116
+
117
def get_sample_prefill_inputs(
    seq_len: int = 100,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Returns sample prefill inputs covering positions [0, seq_len).

  Args:
    seq_len (int): number of prompt tokens. Defaults to 100, which matches the
      max_seq_len of the toy config.

  Returns:
    A tuple `(idx, input_pos)`. `idx` is a [1, seq_len] int64 tensor holding
    token ids 0..seq_len-1. `input_pos` is the matching [seq_len] tensor of
    positions.
  """
  # Build two separate tensors so neither aliases the other's storage.
  idx = torch.unsqueeze(torch.arange(0, seq_len), 0)
  input_pos = torch.arange(0, seq_len)
  return idx, input_pos
121
+
122
+
123
def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
  """Returns a single-token decode sample: token id 1 at position 10.

  Returns:
    A tuple `(idx, input_pos)`, where `idx` is a [1, 1] long tensor and
    `input_pos` is a one-element position tensor.
  """
  return (
      torch.tensor([[1]], dtype=torch.long),
      torch.tensor([10]),
  )
127
+
128
+
129
def define_and_run() -> None:
  """Builds the toy model, runs one prefill pass, then converts it to TFLite.

  If `dump_mlir` is set to True, the prefill graph is also exported as
  StableHLO MLIR to /tmp/toy_model_with_external_kv.stablehlo.mlir. The
  converted model, with 'prefill' and 'decode' signatures, is written to
  /tmp/toy_external_kv_cache.tflite.
  """
  dump_mlir = False

  config = get_model_config()
  model = ToyModelWithExternalKV(config)
  print('running an inference')
  # Derive the cache shape from the config instead of hard-coding it, so the
  # caches stay consistent if the toy config changes. The leading axis is
  # indexed per transformer block in ToyModelWithExternalKV.forward. The
  # remaining axes are assumed to be (batch, max_seq_len, num_query_groups,
  # head_dim) -- TODO confirm against the experimental TransformerBlock.
  cache_shape = (
      config.num_layers,
      1,
      config.max_seq_len,
      config.attn_config.num_query_groups,
      config.head_dim,
  )
  k_caches = torch.zeros(cache_shape, dtype=torch.float32)
  v_caches = torch.zeros(cache_shape, dtype=torch.float32)

  idx, input_pos = get_sample_prefill_inputs()
  decode_idx, decode_input_pos = get_sample_decode_inputs()
  # forward() writes updated caches back into its cache arguments, so run the
  # sample inference on copies. That keeps the zero-initialized caches intact
  # for the export/conversion sample inputs below.
  print(model.forward(idx, input_pos, k_caches.clone(), v_caches.clone()))

  if dump_mlir:
    mlir_text = _export_stablehlo_mlir(
        model, (idx, input_pos, k_caches, v_caches)
    )
    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
      f.write(mlir_text)

  # Convert model to tflite with 2 signatures (prefill + decode).
  # TODO(b/344014416): currently conversion will fail, because we generate
  # int64 index in dynamic update slice op.
  print('converting toy model to tflite with 2 signatures (prefill + decode)')
  edge_model = (
      ai_edge_torch.signature(
          'prefill', model, (idx, input_pos, k_caches, v_caches)
      )
      .signature(
          'decode', model, (decode_idx, decode_input_pos, k_caches, v_caches)
      )
      .convert()
  )
  edge_model.export('/tmp/toy_external_kv_cache.tflite')
157
+
158
+
159
if __name__ == '__main__':
  # Run the whole demo with autograd tracking disabled.
  inference_ctx = torch.inference_mode()
  with inference_ctx:
    define_and_run()
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
  # Common building blocks for Attention layer.
16
16
 
17
- import math
18
17
  from typing import Optional, Tuple
19
18
 
20
19
  import torch
@@ -25,101 +24,8 @@ import ai_edge_torch.generative.layers.builder as builder
25
24
  from ai_edge_torch.generative.layers.kv_cache import KVCache
26
25
  import ai_edge_torch.generative.layers.model_config as cfg
27
26
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
28
- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
29
-
30
-
31
- def scaled_dot_product_attention(
32
- q: torch.Tensor,
33
- k: torch.Tensor,
34
- v: torch.Tensor,
35
- head_size: int,
36
- mask: Optional[torch.Tensor] = None,
37
- scale: Optional[float] = None,
38
- ):
39
- """Scaled dot product attention.
40
-
41
- Args:
42
- q (torch.Tensor): Query tensor, with shape [B, T, N, H].
43
- k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
44
- v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
45
- head_size (int): head dimension.
46
- mask (torch.Tensor): the optional mask tensor.
47
-
48
- Returns:
49
- The output tensor of scaled_dot_product_attention.
50
- """
51
-
52
- if scale is None:
53
- scale = 1.0 / math.sqrt(head_size)
54
-
55
- q = q.transpose(1, 2)
56
- k = k.transpose(1, 2)
57
- v = v.transpose(1, 2)
58
- if q.size() != k.size():
59
- # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
60
- k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
61
- v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
62
- y = F.scaled_dot_product_attention(
63
- q,
64
- k,
65
- v,
66
- attn_mask=mask,
67
- dropout_p=0.0,
68
- is_causal=mask is None,
69
- scale=scale,
70
- )
71
- return y.transpose(1, 2)
72
-
73
-
74
- def scaled_dot_product_attention_with_hlfb(
75
- q: torch.Tensor,
76
- k: torch.Tensor,
77
- v: torch.Tensor,
78
- head_size: int,
79
- mask: Optional[torch.Tensor] = None,
80
- scale: Optional[float] = None,
81
- ):
82
- """Scaled dot product attention with high-level function boundary enabled.
83
-
84
- Args:
85
- q (torch.Tensor): Query tensor, with shape [B, T, N, H].
86
- k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
87
- v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
88
- head_size (int): head dimension.
89
- mask (torch.Tensor): the optional mask tensor.
90
-
91
- Returns:
92
- The output tensor of scaled_dot_product_attention.
93
- """
94
-
95
- if scale is None:
96
- scale = 1.0 / math.sqrt(head_size)
97
-
98
- builder = StableHLOCompositeBuilder(
99
- name="odml.scaled_dot_product_attention", attr={"scale": scale}
100
- )
101
- q, k, v, mask = builder.mark_inputs(q, k, v, mask)
102
-
103
- q = q.transpose(1, 2)
104
- k = k.transpose(1, 2)
105
- v = v.transpose(1, 2)
106
- if q.size() != k.size():
107
- # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
108
- k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
109
- v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
110
- y = F.scaled_dot_product_attention(
111
- q,
112
- k,
113
- v,
114
- attn_mask=mask,
115
- dropout_p=0.0,
116
- is_causal=mask is None,
117
- scale=scale,
118
- )
119
-
120
- result = y.transpose(1, 2)
121
- result = builder.mark_outputs(result)
122
- return result
27
+ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention # NOQA
28
+ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
123
29
 
124
30
 
125
31
  class TransformerBlock(nn.Module):
@@ -151,7 +57,7 @@ class TransformerBlock(nn.Module):
151
57
  def forward(
152
58
  self,
153
59
  x: torch.Tensor,
154
- rope: Tuple[torch.Tensor, torch.Tensor],
60
+ rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
155
61
  mask: Optional[torch.Tensor] = None,
156
62
  input_pos: Optional[torch.Tensor] = None,
157
63
  ) -> torch.Tensor:
@@ -182,7 +88,6 @@ class TransformerBlock(nn.Module):
182
88
  return output
183
89
 
184
90
 
185
- # CausalSelfAttention which can support MHQ, MQA or GQA.
186
91
  class CausalSelfAttention(nn.Module):
187
92
 
188
93
  def __init__(
@@ -229,11 +134,12 @@ class CausalSelfAttention(nn.Module):
229
134
  def forward(
230
135
  self,
231
136
  x: torch.Tensor,
232
- rope: Tuple[torch.Tensor, torch.Tensor],
137
+ rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
233
138
  mask: Optional[torch.Tensor] = None,
234
139
  input_pos: Optional[torch.Tensor] = None,
235
140
  ) -> torch.Tensor:
236
- """Forward function of the CausalSelfAttention layer.
141
+ """Forward function of the CausalSelfAttention layer, which can support
142
+ MQA, GQA and MHA.
237
143
 
238
144
  Args:
239
145
  x (torch.Tensor): the input tensor.
@@ -253,28 +159,35 @@ class CausalSelfAttention(nn.Module):
253
159
  # Assemble into a number of query groups to support MHA, MQA and GQA.
254
160
  q_per_kv = self.config.num_heads // self.config.num_query_groups
255
161
  total_qkv = q_per_kv + 2 # Each group has >=1 queries, 1 key, and 1 value.
256
- qkv = qkv.view(
257
- B, T, self.config.num_query_groups, total_qkv, self.head_dim
258
- ) # (B, T, num_query_groups, total_qkv, head_dim)
162
+ if self.config.qkv_transpose_before_split:
163
+ qkv = qkv.view(
164
+ B, T, total_qkv, self.config.num_query_groups, self.head_dim
165
+ ) # (B, T, total_qkv, num_query_groups, head_dim)
166
+ qkv_axis = -3
167
+ else:
168
+ qkv = qkv.view(
169
+ B, T, self.config.num_query_groups, total_qkv, self.head_dim
170
+ ) # (B, T, num_query_groups, total_qkv, head_dim)
171
+ qkv_axis = -2
259
172
 
260
173
  # Split batched computation into three.
261
- q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)
262
-
174
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=qkv_axis)
263
175
  q = q.reshape(B, T, -1, self.head_dim)
264
176
  k = k.reshape(B, T, -1, self.head_dim)
265
177
  v = v.reshape(B, T, -1, self.head_dim)
266
178
 
267
179
  # Compute rotary positional embedding for query and key.
268
180
  n_elem = int(self.config.rotary_percentage * self.head_dim)
269
- cos, sin = rope
270
- q_roped = rotary_pos_emb.apply_rope(
271
- q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
272
- )
273
- k_roped = rotary_pos_emb.apply_rope(
274
- k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
275
- )
276
- q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
277
- k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
181
+ if n_elem > 0:
182
+ cos, sin = rope
183
+ q_roped = rotary_pos_emb.apply_rope(
184
+ q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
185
+ )
186
+ k_roped = rotary_pos_emb.apply_rope(
187
+ k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
188
+ )
189
+ q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
190
+ k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
278
191
 
279
192
  if self.kv_cache is not None:
280
193
  # TODO(haoliang): Handle when execeeding max sequence length.
@@ -97,6 +97,10 @@ def _get_activation(type_: cfg.ActivationType):
97
97
  return F.gelu
98
98
  elif type_ == cfg.ActivationType.GELU_TANH:
99
99
  return lambda x: F.gelu(x, approximate="tanh")
100
+ elif type_ == cfg.ActivationType.GELU_QUICK:
101
+ # GELU approximation that is fast but somewhat inaccurate.
102
+ # See: https://github.com/hendrycks/GELUs
103
+ return lambda x: x * F.sigmoid(1.702 * x)
100
104
  elif type_ == cfg.ActivationType.RELU:
101
105
  return F.relu
102
106
  else:
@@ -27,6 +27,7 @@ class ActivationType(enum.Enum):
27
27
  SILU = enum.auto()
28
28
  GELU = enum.auto()
29
29
  GELU_TANH = enum.auto()
30
+ GELU_QUICK = enum.auto()
30
31
  RELU = enum.auto()
31
32
 
32
33
 
@@ -46,7 +47,7 @@ class FeedForwardType(enum.Enum):
46
47
 
47
48
  # `output = linear(act(linear(x)))`.
48
49
  SEQUENTIAL = enum.auto()
49
- # `output = linear(act(linear(x)) * lienar(x))`.
50
+ # `output = linear_2(act(linear_1(x)) * lienar_3(x))`.
50
51
  GATED = enum.auto()
51
52
 
52
53
 
@@ -60,6 +61,9 @@ class AttentionConfig:
60
61
  num_query_groups: Optional[int]
61
62
  # Percentage of Rotary Positional Embedding added Q and K projections.
62
63
  rotary_percentage: Optional[float] = None
64
+ # Whether to transpose the query groups of qkv bundled tensor before
65
+ # splitting into separated tensors.
66
+ qkv_transpose_before_split: bool = False
63
67
  # Whether to use bias with Query, Key, and Value projection.
64
68
  qkv_use_bias: bool = False
65
69
  # Whether to use bias with attention output projection.
@@ -0,0 +1,117 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Implements scaled dot product attention.
16
+
17
+ import math
18
+ from typing import Optional
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+
23
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
24
+
25
+
26
+ def scaled_dot_product_attention(
27
+ q: torch.Tensor,
28
+ k: torch.Tensor,
29
+ v: torch.Tensor,
30
+ head_size: int,
31
+ mask: Optional[torch.Tensor] = None,
32
+ scale: Optional[float] = None,
33
+ ):
34
+ """Scaled dot product attention.
35
+
36
+ Args:
37
+ q (torch.Tensor): Query tensor, with shape [B, T, N, H].
38
+ k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
39
+ v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
40
+ head_size (int): head dimension.
41
+ mask (torch.Tensor): the optional mask tensor.
42
+
43
+ Returns:
44
+ The output tensor of scaled_dot_product_attention.
45
+ """
46
+
47
+ if scale is None:
48
+ scale = 1.0 / math.sqrt(head_size)
49
+
50
+ q = q.transpose(1, 2)
51
+ k = k.transpose(1, 2)
52
+ v = v.transpose(1, 2)
53
+ if q.size() != k.size():
54
+ # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
55
+ k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
56
+ v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
57
+ y = F.scaled_dot_product_attention(
58
+ q,
59
+ k,
60
+ v,
61
+ attn_mask=mask,
62
+ dropout_p=0.0,
63
+ is_causal=mask is None,
64
+ scale=scale,
65
+ )
66
+ return y.transpose(1, 2)
67
+
68
+
69
def scaled_dot_product_attention_with_hlfb(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    head_size: int,
    mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
  """Scaled dot product attention with high-level function boundary enabled.

  Computes the same result as `scaled_dot_product_attention`. In addition, it
  marks its inputs and output with a StableHLOCompositeBuilder named
  `odml.scaled_dot_product_attention`, carrying `scale` as an attribute.

  Args:
    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
    k (torch.Tensor): Key tensor, with shape [B, KV_LEN, N_KV, H], where
      N is a multiple of N_KV.
    v (torch.Tensor): Value tensor, with shape [B, KV_LEN, N_KV, H].
    head_size (int): head dimension, used to derive the default scale.
    mask (torch.Tensor): optional attention mask, forwarded as `attn_mask` to
      `F.scaled_dot_product_attention`. It must broadcast to
      [B, N, T, KV_LEN]. When None, a (top-left aligned) causal mask is
      applied instead.
    scale (float): optional softmax scale; defaults to 1 / sqrt(head_size).

  Returns:
    The attention output, with shape [B, T, N, H].
  """

  if scale is None:
    scale = 1.0 / math.sqrt(head_size)

  builder = StableHLOCompositeBuilder(
      name="odml.scaled_dot_product_attention", attr={"scale": scale}
  )
  q, k, v, mask = builder.mark_inputs(q, k, v, mask)

  # Move heads to dim 1: [B, heads, seq, H], the layout SDPA expects.
  q = q.transpose(1, 2)
  k = k.transpose(1, 2)
  v = v.transpose(1, 2)
  num_q_heads, num_kv_heads = q.shape[1], k.shape[1]
  # Only the head counts matter for GQA/MQA. Comparing full sizes would also
  # emit a no-op repeat inside the composite whenever T != KV_LEN.
  if num_q_heads != num_kv_heads:
    k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1)
    v = v.repeat_interleave(num_q_heads // v.shape[1], dim=1)
  y = F.scaled_dot_product_attention(
      q,
      k,
      v,
      attn_mask=mask,
      dropout_p=0.0,
      is_causal=mask is None,
      scale=scale,
  )

  result = y.transpose(1, 2)
  result = builder.mark_outputs(result)
  return result