ai-edge-torch-nightly 0.2.0.dev20240604__py3-none-any.whl → 0.2.0.dev20240606__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

@@ -99,6 +99,36 @@ def _aten_hardswish(gm: GraphModule, node: Node):
   node.target = hardswish
 
 
+@_register_composite_builder(torch.ops.aten.gelu.default)
+def _aten_gelu(gm: GraphModule, node: Node):
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+
+  def gelu(*args, **kwargs):
+    nonlocal op, args_mapper
+
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    # TFLite supports exact and tanh approximate.
+    if full_kwargs["approximate"] != "none" and full_kwargs["approximate"] != "tanh":
+      return op(*args, **kwargs)
+
+    builder = StableHLOCompositeBuilder(
+        "aten.gelu.default",
+        attr=_tree_map_to_composite_attr_values(
+            {
+                "approximate": full_kwargs["approximate"],
+            }
+        ),
+    )
+    full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
+    output = op(full_kwargs["self"])
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = gelu
+
+
 @_register_composite_builder(torch.ops.aten.avg_pool2d.default)
 def _aten_avg_pool2d(gm: GraphModule, node: Node):
   op = node.target
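For context, the new _aten_gelu builder follows the same pattern as the neighbouring hardswish and avg_pool2d builders: wrap the original op between mark_inputs and mark_outputs so the exported program carries a named StableHLO composite with its attributes. A standalone sketch of that pattern, assuming StableHLOCompositeBuilder is importable from ai_edge_torch.hlfb and accepts a plain attribute dict (the pass itself runs the dict through _tree_map_to_composite_attr_values first):

import torch
from ai_edge_torch.hlfb import StableHLOCompositeBuilder  # import path assumed

def gelu_as_composite(x: torch.Tensor) -> torch.Tensor:
  builder = StableHLOCompositeBuilder("aten.gelu.default", attr={"approximate": "tanh"})
  x = builder.mark_inputs(x)       # boundary marker: composite region starts here
  y = torch.ops.aten.gelu.default(x, approximate="tanh")
  return builder.mark_outputs(y)   # everything between the markers is outlined as one composite op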
@@ -169,6 +169,24 @@ class TestConvertComposites(unittest.TestCase):
         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
     )
 
+  def test_convert_gelu(self):
+    """Tests conversion of a GELU module."""
+
+    args = (torch.randn((5, 10)),)
+    torch_module = torch.nn.GELU().eval()
+    edge_model = ai_edge_torch.convert(torch_module, args)
+
+    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+  def test_convert_gelu_approximate(self):
+    """Tests conversion of an Approximate GELU module."""
+
+    args = (torch.randn((5, 10)),)
+    torch_module = torch.nn.GELU('tanh').eval()
+    edge_model = ai_edge_torch.convert(torch_module, args)
+
+    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
 
 if __name__ == '__main__':
   unittest.main()
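The new tests only check numerical parity; a converted GELU model can also be serialized to a flatbuffer with the edge model's export method, as the stable diffusion converter further below does for its submodels. A minimal sketch (output path is a placeholder):

import torch
import ai_edge_torch

torch_module = torch.nn.GELU('tanh').eval()
args = (torch.randn((5, 10)),)
edge_model = ai_edge_torch.convert(torch_module, args)
edge_model.export('/tmp/gelu.tflite')  # placeholder output path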
@@ -15,65 +15,99 @@
 
 import torch
 from torch import nn
-from torch._prims_common import mask_tensor
-from torch._prims_common.wrappers import out_wrapper
 
-from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attention_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="layers.{}.linear_1",
+    ff_down_proj="layers.{}.linear_2",
+    ff_gate_proj="layers.{}.linear_1",
+    attn_fused_qkv_proj="layers.{}.attention.in_proj",
+    attn_output_proj="layers.{}.attention.out_proj",
+    pre_attn_norm="layers.{}.layernorm_1",
+    pre_ff_norm="layers.{}.layernorm_2",
+    embedding="embedding.token_embedding",
+    embedding_position="embedding.position_value",
+    final_norm="layernorm",
+    lm_head=None,
+)
 
 
-class CLIPEmbedding(nn.Module):
-
-  def __init__(self, n_vocab: int, n_embd: int, n_token: int):
-    super().__init__()
-    self.token_embedding = nn.Embedding(n_vocab, n_embd)
-    self.position_value = nn.Parameter(torch.zeros((n_token, n_embd)))
-
-  def forward(self, tokens):
-    x = self.token_embedding(tokens)
-    x += self.position_value
-    return x
-
-
-class CLIPLayer(nn.Module):
+class CLIP(nn.Module):
+  """CLIP text encoder
+  For details, see https://arxiv.org/abs/2103.00020
+  """
 
-  def __init__(self, n_head: int, n_embd: int):
+  def __init__(self, config: cfg.ModelConfig):
     super().__init__()
-    self.layernorm_1 = nn.LayerNorm(n_embd)
-    self.attention = SelfAttention(n_head, n_embd)
-    self.layernorm_2 = nn.LayerNorm(n_embd)
-    self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-    self.linear_2 = nn.Linear(4 * n_embd, n_embd)
-
-  def forward(self, x):
-    residue = x
-    x = self.layernorm_1(x)
-    x = self.attention(x, causal_mask=True)
-    x += residue
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.tok_embedding_position = nn.Parameter(
+        torch.zeros((config.max_seq_len, config.embedding_dim))
+    )
 
-    residue = x
-    x = self.layernorm_2(x)
-    x = self.linear_1(x)
-    x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation function
-    x = self.linear_2(x)
-    x += residue
+    self.config = config
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)
 
-    return x
-
-
-class CLIP(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.embedding = CLIPEmbedding(49408, 768, 77)
-    self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
-    self.layernorm = nn.LayerNorm(768)
+    self.mask_cache = attention_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32
+    )
 
   @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
     tokens = tokens.type(torch.long)
 
-    state = self.embedding(tokens)
-    for layer in self.layers:
-      state = layer(state)
-    output = self.layernorm(state)
+    state = self.tok_embedding(tokens) + self.tok_embedding_position
+    for layer in self.transformer_blocks:
+      state = layer(state, mask=self.mask_cache)
+    output = self.final_norm(state)
     return output
+
+
+def get_model_config() -> cfg.ModelConfig:
+  max_seq_len = 77
+  vocab_size = 49408
+  num_layers = 12
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 768
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationType.GELU_QUICK,
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config
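The rewritten CLIP is configuration-driven. A minimal sketch of instantiating and running it on dummy token ids (random weights and dummy inputs are illustrative; a real run loads a checkpoint as in the converter script below):

import torch
import ai_edge_torch.generative.examples.stable_diffusion.clip as clip

config = clip.get_model_config()   # 12 layers, 768-dim embeddings, 77-token context
model = clip.CLIP(config)          # random weights; checkpoint loading shown below
tokens = torch.zeros((1, config.max_seq_len), dtype=torch.long)
embeddings = model(tokens)         # shape (1, 77, 768): one 768-dim embedding per token position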
@@ -19,11 +19,12 @@ from pathlib import Path
 import torch
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP
+import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
 from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder
 from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion  # NOQA
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+import ai_edge_torch.generative.utilities.loader as loading_utils
 
 
 @torch.inference_mode
@@ -36,8 +37,9 @@ def convert_stable_diffusion_to_tflite(
     image_width: int = 512,
 ):
 
-  clip = CLIP()
-  clip.load_state_dict(torch.load(clip_ckpt_path))
+  clip_model = clip.CLIP(clip.get_model_config())
+  loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
+  loader.load(clip_model, strict=False)
 
   encoder = Encoder()
   encoder.load_state_dict(torch.load(encoder_ckpt_path))
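The TENSOR_NAMES entries defined in clip.py are per-layer templates: the loader fills the {} placeholder with the layer index when mapping checkpoint tensors onto the new module, and strict=False presumably tolerates checkpoint entries that have no one-to-one counterpart in the rewritten model. A small illustration of how one template expands (the .weight suffix is an assumption about the checkpoint's key naming):

ff_up_proj = "layers.{}.linear_1"   # template from clip.TENSOR_NAMES
num_layers = 12                     # from clip.get_model_config()
keys = [f"{ff_up_proj.format(i)}.weight" for i in range(num_layers)]
# -> ["layers.0.linear_1.weight", ..., "layers.11.linear_1.weight"]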
@@ -59,13 +61,13 @@ def convert_stable_diffusion_to_tflite(
   )
 
   input_latents = encoder(input_image, noise)
-  context_cond = clip(prompt_tokens)
+  context_cond = clip_model(prompt_tokens)
   context_uncond = torch.zeros_like(context_cond)
   context = torch.cat([context_cond, context_uncond], axis=0)
   time_embedding = util.get_time_embedding(timestamp)
 
   # CLIP text encoder
-  ai_edge_torch.signature('encode', clip, (prompt_tokens,)).convert().export(
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
       '/tmp/stable_diffusion/clip.tflite'
   )
 
@@ -202,11 +202,6 @@ class UNet(nn.Module):
 
     x = self.bottleneck(x, context, time)
 
-    # print('x shape:')
-    # print(list(x.shape))
-    # print('time shape:')
-    # print(list(time.shape))
-
     for layers in self.decoders:
       x = torch.cat((x, skip_connections.pop()), dim=1)
       x = layers(x, context, time)
@@ -214,199 +209,6 @@ class UNet(nn.Module):
     return x
 
 
-# The encoder component.
-class UNetEncoder(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.time_embedding = TimeEmbedding(320)
-    self.encoders = nn.ModuleList(
-        [
-            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
-            SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
-            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)),
-            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(1280, 1280)),
-            SwitchSequential(ResidualBlock(1280, 1280)),
-        ]
-    )
-
-  def forward(self, x, context, time):
-    time_embedding = self.time_embedding(time)
-    skip_connections = []
-    for layers in self.encoders:
-      x = layers(x, context, time_embedding)
-      skip_connections.append(x)
-
-    return x, skip_connections, time_embedding
-
-
-class UNetBottleNeck(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.bottleneck = SwitchSequential(
-        ResidualBlock(1280, 1280),
-        AttentionBlock(8, 160),
-        ResidualBlock(1280, 1280),
-    )
-
-  def forward(self, x, context, time):
-    x = self.bottleneck(x, context, time)
-    # print('shape')
-    # print(list(x.shape))
-    return x
-
-
-# Unet decoder.
-class UNetDecoder1(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-        ]
-    )
-
-  def forward(self, x, context, time, s9, s10, s11, s12):
-    x = torch.cat((x, s12), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s11), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s10), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s9), dim=1)
-    x = self.decoders[3](x, context, time)
-
-    return x
-
-
-class UNetDecoder2(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(
-                ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
-            ),
-            SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
-        ]
-    )
-
-  def forward(self, x, context, time, s5, s6, s7, s8):
-    x = torch.cat((x, s8), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s7), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s6), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s5), dim=1)
-    x = self.decoders[3](x, context, time)
-    return x
-
-
-class UNetDecoder3(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(
-                ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
-            ),
-            SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-        ]
-    )
-    self.final = FinalLayer(320, 4)
-
-  def forward(self, x, context, time, s1, s2, s3, s4):
-    x = torch.cat((x, s4), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s3), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s2), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s1), dim=1)
-    x = self.decoders[3](x, context, time)
-
-    x = self.final(x)
-    return x
-
-
-class UNetDecoder(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(
-                ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
-            ),
-            SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
-            SwitchSequential(
-                ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
-            ),
-            SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-        ]
-    )
-    self.final = FinalLayer(320, 4)
-
-  def forward(
-      self, x, context, time, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12
-  ):
-    x = torch.cat((x, s12), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s11), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s10), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s9), dim=1)
-    x = self.decoders[3](x, context, time)
-    x = torch.cat((x, s8), dim=1)
-    x = self.decoders[4](x, context, time)
-    x = torch.cat((x, s7), dim=1)
-    x = self.decoders[5](x, context, time)
-    x = torch.cat((x, s6), dim=1)
-    x = self.decoders[6](x, context, time)
-    x = torch.cat((x, s5), dim=1)
-    x = self.decoders[7](x, context, time)
-    x = torch.cat((x, s4), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s3), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s2), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s1), dim=1)
-    x = self.decoders[3](x, context, time)
-
-    x = self.final(x)
-
-    return x
-
-
 class FinalLayer(nn.Module):
 
   def __init__(self, in_channels, out_channels):
@@ -432,68 +234,6 @@ class Diffusion(nn.Module):
   @torch.inference_mode
   def forward(self, latent, context, time):
     time = self.time_embedding(time)
-    # print('time:')
-    # print(list(time.shape))
     output = self.unet(latent, context, time)
     output = self.final(output)
     return output
-
-
-# Calling code as if Diffusion is splitted into two parts.
-class DiffusionSplitted(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.unet_encoder = UNetEncoder()
-    self.bottleneck = UNetBottleNeck()
-    self.unet_decoder1 = UNetDecoder1()
-    self.unet_decoder2 = UNetDecoder2()
-    self.unet_decoder3 = UNetDecoder3()
-
-  def get_skip_connections(self, latent, context, time):
-    _, skip_connections, _ = self.unet_encoder(latent, context, time)
-    return skip_connections
-
-  def forward(self, latent, context, time):
-    output, skip_connections, time = self.unet_encoder(latent, context, time)
-    # print("output shape of unet encoder...")
-    # print(list(output.shape))
-    # print("output shape of time...")
-    # print(list(time.shape))
-    output = self.bottleneck(output, context, time)
-    # print("output shape of bn")
-    # print(list(output.shape))
-    output = self.unet_decoder1(
-        output,
-        context,
-        time,
-        skip_connections[8],
-        skip_connections[9],
-        skip_connections[10],
-        skip_connections[11],
-    )
-    # print("output shape of d1:")
-    # print(list(output.shape))
-
-    output = self.unet_decoder2(
-        output,
-        context,
-        time,
-        skip_connections[4],
-        skip_connections[5],
-        skip_connections[6],
-        skip_connections[7],
-    )
-
-    # print("output shape of d2:")
-    # print(list(output.shape))
-    output = self.unet_decoder3(
-        output,
-        context,
-        time,
-        skip_connections[0],
-        skip_connections[1],
-        skip_connections[2],
-        skip_connections[3],
-    )
-    return output
@@ -20,11 +20,11 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention  # NOQA
-from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.kv_cache import KVCache
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
+from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 
 
 class EncoderDecoderBlock(nn.Module):
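Any downstream code that imported these helpers from ai_edge_torch.generative.layers.attention needs the matching one-line change; a sketch:

# Before (no longer exported from the attention module):
# from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention
# After:
from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention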