ai-edge-torch-nightly 0.2.0.dev20240610__py3-none-any.whl → 0.2.0.dev20240617__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (30)
  1. ai_edge_torch/convert/conversion_utils.py +17 -5
  2. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
  3. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
  4. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
  5. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
  6. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
  7. ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
  8. ai_edge_torch/generative/layers/attention.py +154 -26
  9. ai_edge_torch/generative/layers/model_config.py +4 -0
  10. ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
  11. ai_edge_torch/generative/layers/unet/builder.py +20 -2
  12. ai_edge_torch/generative/layers/unet/model_config.py +157 -5
  13. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  14. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +164 -0
  15. ai_edge_torch/generative/quantize/quant_attrs.py +2 -0
  16. ai_edge_torch/generative/quantize/quant_recipe.py +49 -4
  17. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -2
  18. ai_edge_torch/generative/quantize/quant_recipes.py +3 -3
  19. ai_edge_torch/generative/quantize/supported_schemes.py +2 -1
  20. ai_edge_torch/generative/test/test_model_conversion.py +24 -0
  21. ai_edge_torch/generative/test/test_quantize.py +75 -20
  22. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
  23. ai_edge_torch/generative/utilities/t5_loader.py +33 -17
  24. ai_edge_torch/quantize/quant_config.py +11 -15
  25. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA +1 -1
  26. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD +29 -27
  27. ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
  28. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/LICENSE +0 -0
  29. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/WHEEL +0 -0
  30. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/t5/t5_attention.py

@@ -20,6 +20,7 @@ import torch
  from torch import nn
  import torch.nn.functional as F

+ from ai_edge_torch.generative.layers.attention import CrossAttention
  import ai_edge_torch.generative.layers.builder as builder
  from ai_edge_torch.generative.layers.kv_cache import KVCache
  import ai_edge_torch.generative.layers.model_config as cfg
@@ -122,7 +123,7 @@ class EncoderDecoderBlock(nn.Module):
      return hidden_states, position_bias, encoder_decoder_position_bias


- class T5Attention(nn.Module):
+ class T5Attention(CrossAttention):

    def __init__(
        self,
@@ -138,51 +139,21 @@ class T5Attention(nn.Module):
      Args:
        dim (int): causal attention's input/output dimmension.
        config (cfg.AttentionConfig): attention specific configurations.
+       norm_config (cfg.NormalizationConfig): normalization configure before attention.
        kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
        enable_hlfb (bool): whether hlfb is enabled or not.
        has_relative_attention_bias (bool): whether we compute relative bias.
      """
-     super().__init__()
+     super().__init__(dim, dim, config, kv_cache_max, enable_hlfb)
      self.pre_atten_norm = builder.build_norm(dim, norm_config)

      self.has_relative_attention_bias = has_relative_attention_bias
      self.relative_attention_num_buckets = config.relative_attention_num_buckets
-     self.d_model = dim
-     self.head_dim = dim // config.num_heads
-     self.n_heads = config.num_heads
-     self.inner_dim = self.n_heads * self.head_dim
-
-     self.q = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
-     self.k = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
-     self.v = nn.Linear(self.d_model, self.inner_dim, bias=config.qkv_use_bias)
-     # output projection
-     self.proj = nn.Linear(
-         self.inner_dim, self.d_model, bias=config.output_proj_use_bias
-     )
-
      if self.has_relative_attention_bias:
        self.relative_attention_bias = nn.Embedding(
            self.relative_attention_num_buckets, self.n_heads
        )

-     self.config = config
-     self.kv_cache = None
-     # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
-     # Now only supports a max batch_size of 1.
-     if config.enable_kv_cache:
-       self.kv_cache = KVCache(
-           1,
-           kv_cache_max,
-           config.num_query_groups,
-           self.head_dim,
-           enable_hlfb,
-       )
-
-     if enable_hlfb:
-       self.sdpa_func = scaled_dot_product_attention_with_hlfb
-     else:
-       self.sdpa_func = scaled_dot_product_attention
-
    def forward(
        self,
        x: torch.Tensor,
@@ -206,7 +177,7 @@ class T5Attention(nn.Module):

      x = self.pre_atten_norm(x)
      B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
-     query_states = self.q(x)
+     query_states = self.q_projection(x)
      query_states = query_states.reshape(B, T, -1, self.head_dim)  # (B, T, nh_q, hs)

      if key_value_states is not None:
@@ -217,13 +188,13 @@ class T5Attention(nn.Module):
        ) = (
            key_value_states.size()
        )  # batch size, sequence length, embedding dimensionality (n_embd)
-       key_states = self.k(key_value_states)
-       value_states = self.v(key_value_states)
+       key_states = self.k_projection(key_value_states)
+       value_states = self.v_projection(key_value_states)
        key_states = key_states.reshape(kvB, kvT, -1, self.head_dim)
        value_states = value_states.reshape(kvB, kvT, -1, self.head_dim)
      else:
-       key_states = self.k(x)
-       value_states = self.v(x)
+       key_states = self.k_projection(x)
+       value_states = self.v_projection(x)
        key_states = key_states.reshape(B, T, -1, self.head_dim)
        value_states = value_states.reshape(B, T, -1, self.head_dim)

@@ -251,5 +222,5 @@ class T5Attention(nn.Module):
      )
      y = y.reshape(B, T, C)  # re-assemble all head outputs side by side
      # output projection
-     y = self.proj(y)
+     y = self.output_projection(y)
      return y, position_bias
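The net effect of the t5_attention.py changes above is that T5Attention now inherits its q/k/v and output projections, KV cache, and SDPA dispatch from CrossAttention, keeping only the T5-specific pieces (pre-attention norm and relative attention bias). A hedged sketch of the attribute rename this implies when re-keying checkpoints; the key patterns are illustrative, not the actual tensor names handled in t5_loader.py:

    # Illustrative only: map old T5Attention attribute names to the new
    # CrossAttention-provided names when re-keying a state dict.
    _ATTR_RENAMES = {
        ".q.": ".q_projection.",
        ".k.": ".k_projection.",
        ".v.": ".v_projection.",
        ".proj.": ".output_projection.",
    }


    def rekey_state_dict(state_dict):
      """Return a copy of state_dict with old projection names rewritten."""
      rekeyed = {}
      for name, tensor in state_dict.items():
        for old, new in _ATTR_RENAMES.items():
          name = name.replace(old, new)
        rekeyed[name] = tensor
      return rekeyed
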
ai_edge_torch/generative/layers/attention.py

@@ -28,6 +28,33 @@ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_
  from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA


+ def _embed_rope(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     n_elem: int,
+     rope: Tuple[torch.Tensor, torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+   """Embed rotary positional embedding for query and key.
+
+   Args:
+     q (torch.Tensor): query tensor.
+     k (torch.Tensor): key tensor.
+     n_elem (int): number of elements to embed rotarty positional embedding.
+     rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+   """
+   if n_elem > 0:
+     cos, sin = rope
+     q_roped = rotary_pos_emb.apply_rope(
+         q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+     )
+     k_roped = rotary_pos_emb.apply_rope(
+         k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+     )
+     q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+     k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+   return q, k
+
+
  class TransformerBlock(nn.Module):

    def __init__(self, config: cfg.ModelConfig) -> None:
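_embed_rope is a new module-level helper that factors the rotary-embedding step out of CausalSelfAttention.forward so the new CrossAttention layer can reuse it; it rotates only the first n_elem channels of each head and passes the rest through unchanged. A self-contained sketch of the same partial-rotation idea, where apply_rope below is a hypothetical stand-in for the library's rotary_pos_emb.apply_rope:

    from typing import Tuple

    import torch


    def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
      # Hypothetical stand-in for rotary_pos_emb.apply_rope (rotate-half convention).
      x1, x2 = x.chunk(2, dim=-1)
      rotated = torch.cat((-x2, x1), dim=-1)
      return x * cos + rotated * sin


    def partial_rope(q, k, n_elem, rope):
      # Rotate only the first n_elem channels; pass the remaining channels through.
      if n_elem <= 0:
        return q, k
      cos, sin = rope
      q = torch.cat((apply_rope(q[..., :n_elem], cos, sin), q[..., n_elem:]), dim=-1)
      k = torch.cat((apply_rope(k[..., :n_elem], cos, sin), k[..., n_elem:]), dim=-1)
      return q, k


    # Example: rotary_percentage = 0.5 applies RoPE to half of each 16-dim head.
    B, T, H, D = 1, 8, 4, 16
    n_elem = int(0.5 * D)
    theta = 1.0 / (10000 ** (torch.arange(0, n_elem, 2).float() / n_elem))
    freqs = torch.outer(torch.arange(T).float(), theta)  # (T, n_elem // 2)
    cos = torch.cos(freqs).repeat(1, 2).unsqueeze(1)      # (T, 1, n_elem)
    sin = torch.sin(freqs).repeat(1, 2).unsqueeze(1)      # (T, 1, n_elem)
    q, k = partial_rope(torch.randn(B, T, H, D), torch.randn(B, T, H, D), n_elem, (cos, sin))
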
@@ -43,6 +70,7 @@ class TransformerBlock(nn.Module):
          config.embedding_dim, config.pre_attention_norm_config
      )
      self.atten_func = CausalSelfAttention(
+         config.batch_size,
          config.embedding_dim,
          config.attn_config,
          config.kv_cache_max,
@@ -92,6 +120,7 @@ class CausalSelfAttention(nn.Module):

    def __init__(
        self,
+       batch_size: int,
        dim: int,
        config: cfg.AttentionConfig,
        kv_cache_max: int,
@@ -100,6 +129,7 @@ class CausalSelfAttention(nn.Module):
      """Initialize an instance of CausalSelfAttention.

      Args:
+       batch_size (int): batch size of the input tensor.
        dim (int): causal attention's input/output dimmension.
        config (cfg.AttentionConfig): attention specific configurations.
        kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
@@ -113,13 +143,12 @@ class CausalSelfAttention(nn.Module):
      self.output_projection = nn.Linear(dim, dim, bias=config.output_proj_use_bias)
      self.config = config
      self.kv_cache = None
+     self.batch_size = batch_size

      # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
-     # Now only supports batch_size of 1.
-     # TODO(haoliang): support batch_size greater than 1.
      if config.enable_kv_cache:
        self.kv_cache = KVCache(
-           1,
+           batch_size,
            kv_cache_max,
            config.num_query_groups,
            self.head_dim,
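CausalSelfAttention now records the configured batch size and allocates its KV cache as (batch_size, kv_cache_max, n_heads, head_dim) instead of hard-coding a batch of 1, which is what allows exported models with batch sizes greater than one. A toy sketch of a cache with that layout; update below is a hypothetical stand-in for KVCache.update_cache:

    import torch


    class ToyKVCache:
      """Toy cache laid out as (batch_size, kv_cache_max, n_heads, head_dim)."""

      def __init__(self, batch_size, kv_cache_max, n_heads, head_dim):
        self.k = torch.zeros(batch_size, kv_cache_max, n_heads, head_dim)
        self.v = torch.zeros(batch_size, kv_cache_max, n_heads, head_dim)

      def update(self, input_pos, k_new, v_new):
        # Write new entries at the given sequence positions, return the full buffers.
        self.k[:, input_pos] = k_new
        self.v[:, input_pos] = v_new
        return self.k, self.v


    cache = ToyKVCache(batch_size=2, kv_cache_max=16, n_heads=4, head_dim=8)
    k_step = torch.randn(2, 1, 4, 8)  # one decode step for a batch of 2
    v_step = torch.randn(2, 1, 4, 8)
    k_all, v_all = cache.update(torch.tensor([5]), k_step, v_step)
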
@@ -152,42 +181,38 @@ class CausalSelfAttention(nn.Module):
      """
      # Batch size, sequence length, embedding dimensionality.
      B, T, E = x.size()
-     assert B == 1, "Currently only batch_size = 1 is supported."
+     assert (
+         B == self.batch_size
+     ), "batch size of input tensor must match with the batch size specified in the model configuration."

      qkv = self.qkv_projection(x)

      # Assemble into a number of query groups to support MHA, MQA and GQA.
      q_per_kv = self.config.num_heads // self.config.num_query_groups
-     total_qkv = q_per_kv + 2  # Each group has >=1 queries, 1 key, and 1 value.
+     # Each group has >=1 queries, 1 key, and 1 value.
      if self.config.qkv_transpose_before_split:
-       qkv = qkv.view(
-           B, T, total_qkv, self.config.num_query_groups, self.head_dim
-       )  # (B, T, total_qkv, num_query_groups, head_dim)
-       qkv_axis = -3
+       qkv = qkv.view(B, T, -1, self.head_dim)
+       q, k, v = qkv.split(
+           (
+               q_per_kv * self.config.num_query_groups,
+               self.config.num_query_groups,
+               self.config.num_query_groups,
+           ),
+           dim=-2,
+       )
      else:
-       qkv = qkv.view(
-           B, T, self.config.num_query_groups, total_qkv, self.head_dim
-       )  # (B, T, num_query_groups, total_qkv, head_dim)
-       qkv_axis = -2
+       qkv = qkv.view(B, T, self.config.num_query_groups, -1)
+       q, k, v = qkv.split(
+           (q_per_kv * self.head_dim, self.head_dim, self.head_dim), dim=-1
+       )

-     # Split batched computation into three.
-     q, k, v = qkv.split((q_per_kv, 1, 1), dim=qkv_axis)
      q = q.reshape(B, T, -1, self.head_dim)
      k = k.reshape(B, T, -1, self.head_dim)
      v = v.reshape(B, T, -1, self.head_dim)

      # Compute rotary positional embedding for query and key.
      n_elem = int(self.config.rotary_percentage * self.head_dim)
-     if n_elem > 0:
-       cos, sin = rope
-       q_roped = rotary_pos_emb.apply_rope(
-           q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-       )
-       k_roped = rotary_pos_emb.apply_rope(
-           k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-       )
-       q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-       k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+     q, k = _embed_rope(q, k, n_elem, rope)

      if self.kv_cache is not None:
        # TODO(haoliang): Handle when execeeding max sequence length.
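The fused QKV output is now split with explicit per-group sizes rather than reshaped onto a total_qkv axis; the same split covers MHA (num_query_groups == num_heads), GQA (1 < num_query_groups < num_heads), and MQA (num_query_groups == 1). A standalone sketch of the non-transposed branch with illustrative dimensions:

    import torch

    B, T = 1, 8
    num_heads, num_query_groups, head_dim = 8, 2, 16  # GQA: 4 query heads share each k/v head
    q_per_kv = num_heads // num_query_groups

    # Fused projection output: q_per_kv queries plus one key and one value per group.
    qkv = torch.randn(B, T, (num_heads + 2 * num_query_groups) * head_dim)

    qkv = qkv.view(B, T, num_query_groups, -1)
    q, k, v = qkv.split((q_per_kv * head_dim, head_dim, head_dim), dim=-1)
    q = q.reshape(B, T, -1, head_dim)  # (1, 8, 8, 16)
    k = k.reshape(B, T, -1, head_dim)  # (1, 8, 2, 16)
    v = v.reshape(B, T, -1, head_dim)  # (1, 8, 2, 16)
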
@@ -222,5 +247,108 @@ class SelfAttention(CausalSelfAttention):
      """
      B, T, _ = x.size()
      return super().forward(
-         x, rope=rope, mask=torch.zeros((B, T), dtype=torch.float32), input_pos=input_pos
+         x,
+         rope=rope,
+         mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
+         input_pos=input_pos,
      )
+
+
+ class CrossAttention(nn.Module):
+
+   def __init__(
+       self,
+       batch_size: int,
+       query_dim: int,
+       cross_dim: int,
+       config: cfg.AttentionConfig,
+       kv_cache_max: int,
+       enable_hlfb: bool,
+   ) -> None:
+     """Initialize an instance of CrossAttention.
+
+     Args:
+       batch_size (int): batch size of the input tensor.
+       query_dim (int): query tensor's dimension.
+       cross_dim (int): cross attention's dimensions, for key and value tensors.
+       config (cfg.AttentionConfig): attention specific configurations.
+       kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
+       enable_hlfb (bool): whether hlfb is enabled or not.
+     """
+     super().__init__()
+     self.config = config
+     self.head_dim = query_dim // config.num_heads
+     self.n_heads = config.num_heads
+     self.q_projection = nn.Linear(query_dim, query_dim, bias=config.qkv_use_bias)
+     self.k_projection = nn.Linear(cross_dim, query_dim, bias=config.qkv_use_bias)
+     self.v_projection = nn.Linear(cross_dim, query_dim, bias=config.qkv_use_bias)
+     self.output_projection = nn.Linear(
+         query_dim, query_dim, bias=config.output_proj_use_bias
+     )
+
+     self.kv_cache = None
+     # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
+     if config.enable_kv_cache:
+       self.kv_cache = KVCache(
+           batch_size,
+           kv_cache_max,
+           config.num_query_groups,
+           self.head_dim,
+           enable_hlfb,
+       )
+
+     if enable_hlfb:
+       self.sdpa_func = scaled_dot_product_attention_with_hlfb
+     else:
+       self.sdpa_func = scaled_dot_product_attention
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       y: torch.Tensor,
+       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       mask: Optional[torch.Tensor] = None,
+       input_pos: Optional[torch.Tensor] = None,
+   ):
+     """Forward function of the CrossAttention layer.
+
+     Args:
+       x (torch.Tensor): the target tensor, with shape [B, target_seq_len, ...].
+       y (torch.Tensor): the source tensor, with shape [B, source_seq_len, ...].
+       rope (Tuple[torch.Tensor, torch.Tensor]): the optional input rope tensor.
+       mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape [B, n_heads, target_seq_len, source_seq_len].
+       input_pos (torch.Tensor): the optional input position tensor.
+
+     Returns:
+       output activation from this cross attention layer.
+     """
+     batch_size = x.size()[0]
+     target_seq_len = x.size()[1]
+     source_seq_len = y.size()[1]
+
+     q = self.q_projection(x)
+     k = self.k_projection(y)
+     v = self.v_projection(y)
+
+     interim_shape = (batch_size, -1, self.n_heads, self.head_dim)
+     q = q.view(interim_shape)
+     k = k.view(interim_shape)
+     v = v.view(interim_shape)
+
+     # Compute rotary positional embedding for query and key.
+     n_elem = int(self.config.rotary_percentage * self.head_dim)
+     q, k = _embed_rope(q, k, n_elem, rope)
+
+     if self.kv_cache is not None:
+       # TODO(haoliang): Handle when execeeding max sequence length.
+       k, v = self.kv_cache.update_cache(input_pos, k, v)
+     if mask is None:
+       mask = torch.zeros(
+           (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
+       )
+     y = self.sdpa_func(q, k, v, self.head_dim, mask=mask)
+     y = y.reshape(batch_size, target_seq_len, -1)
+
+     # Compute the output projection.
+     y = self.output_projection(y)
+     return y
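CrossAttention projects the target sequence x to queries and a separate source sequence y (for example, text-encoder states in the stable diffusion examples) to keys and values, with query_dim and cross_dim allowed to differ. The layer routes through the library's sdpa_func (with or without HLFB); the plain-torch sketch below only illustrates the shape flow, with made-up dimensions:

    import torch
    import torch.nn.functional as F
    from torch import nn

    B, target_len, source_len = 1, 64, 77
    query_dim, cross_dim, n_heads = 320, 768, 8
    head_dim = query_dim // n_heads

    q_proj = nn.Linear(query_dim, query_dim)
    k_proj = nn.Linear(cross_dim, query_dim)
    v_proj = nn.Linear(cross_dim, query_dim)
    out_proj = nn.Linear(query_dim, query_dim)

    x = torch.randn(B, target_len, query_dim)  # target, e.g. image latents
    y = torch.randn(B, source_len, cross_dim)  # source, e.g. text embeddings

    # (B, seq, dim) -> (B, n_heads, seq, head_dim) for scaled dot-product attention.
    q = q_proj(x).view(B, target_len, n_heads, head_dim).transpose(1, 2)
    k = k_proj(y).view(B, source_len, n_heads, head_dim).transpose(1, 2)
    v = v_proj(y).view(B, source_len, n_heads, head_dim).transpose(1, 2)

    out = F.scaled_dot_product_attention(q, k, v)  # (B, n_heads, target_len, head_dim)
    out = out_proj(out.transpose(1, 2).reshape(B, target_len, query_dim))
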
ai_edge_torch/generative/layers/model_config.py

@@ -27,6 +27,7 @@ class ActivationType(enum.Enum):
    SILU = enum.auto()
    GELU = enum.auto()
    GELU_TANH = enum.auto()
+   GELU_QUICK = enum.auto()
    GE_GLU = enum.auto()
    RELU = enum.auto()

@@ -138,6 +139,9 @@ class ModelConfig:
    # The Attention computation will include relative positional bias.
    relative_attention: bool = False

+   # Default batch size of the exported model. Default value is 1.
+   batch_size: int = 1
+
    @property
    def kv_cache_max(self) -> int:
      if self.kv_cache_max_len > 0:
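The GELU_QUICK activation added to ActivationType above is commonly the sigmoid approximation of GELU used by CLIP-style text encoders and the stable diffusion UNet; assuming that is the variant the layer builder maps it to, it amounts to:

    import torch


    def gelu_quick(x: torch.Tensor) -> torch.Tensor:
      # Quick GELU: x * sigmoid(1.702 * x), a cheap approximation of exact GELU.
      return x * torch.sigmoid(1.702 * x)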