ai-edge-torch-nightly 0.2.0.dev20240801__py3-none-any.whl → 0.2.0.dev20240803__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Note: this version of ai-edge-torch-nightly was flagged as potentially problematic by the registry.
Files changed (89)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/convert/conversion.py +12 -8
  3. ai_edge_torch/convert/conversion_utils.py +38 -20
  4. ai_edge_torch/convert/converter.py +11 -5
  5. ai_edge_torch/convert/fx_passes/__init__.py +3 -4
  6. ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
  7. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
  8. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
  9. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
  10. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
  11. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
  16. ai_edge_torch/convert/test/test_convert.py +39 -16
  17. ai_edge_torch/convert/test/test_convert_composites.py +115 -86
  18. ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
  19. ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
  20. ai_edge_torch/convert/to_channel_last_io.py +6 -2
  21. ai_edge_torch/debug/culprit.py +41 -16
  22. ai_edge_torch/debug/test/test_culprit.py +4 -3
  23. ai_edge_torch/debug/test/test_search_model.py +4 -3
  24. ai_edge_torch/debug/utils.py +3 -1
  25. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
  26. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
  27. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
  28. ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
  29. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
  30. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
  31. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
  32. ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
  33. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
  34. ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
  35. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  36. ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
  37. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
  38. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
  39. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
  40. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  41. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
  42. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  43. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  44. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  45. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  46. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  47. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
  48. ai_edge_torch/generative/examples/t5/t5.py +158 -125
  49. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
  50. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
  52. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
  55. ai_edge_torch/generative/fx_passes/__init__.py +1 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
  57. ai_edge_torch/generative/layers/attention.py +19 -11
  58. ai_edge_torch/generative/layers/builder.py +3 -4
  59. ai_edge_torch/generative/layers/kv_cache.py +4 -3
  60. ai_edge_torch/generative/layers/model_config.py +6 -2
  61. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
  62. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
  63. ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
  64. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  65. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
  66. ai_edge_torch/generative/quantize/example.py +2 -3
  67. ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
  68. ai_edge_torch/generative/test/loader_test.py +5 -4
  69. ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
  70. ai_edge_torch/generative/test/test_model_conversion.py +2 -3
  71. ai_edge_torch/generative/test/test_quantize.py +45 -48
  72. ai_edge_torch/generative/utilities/loader.py +55 -28
  73. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
  74. ai_edge_torch/generative/utilities/t5_loader.py +77 -48
  75. ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
  76. ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
  79. ai_edge_torch/model.py +8 -5
  80. ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
  81. ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
  82. ai_edge_torch/quantize/quant_config.py +6 -2
  83. ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
  84. ai_edge_torch/version.py +16 -0
  85. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
  86. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
  87. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
  88. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
  89. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0
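The hunks below reproduce the per-file unified diffs. To regenerate a comparison like this locally, here is a minimal sketch using only the Python standard library (the wheel filenames are illustrative of the two versions compared on this page and must be downloaded first):

import difflib
import zipfile

OLD = "ai_edge_torch_nightly-0.2.0.dev20240801-py3-none-any.whl"
NEW = "ai_edge_torch_nightly-0.2.0.dev20240803-py3-none-any.whl"

# A .whl is a zip archive; diff the text of every file present in both wheels.
with zipfile.ZipFile(OLD) as old, zipfile.ZipFile(NEW) as new:
  for name in sorted(set(old.namelist()) & set(new.namelist())):
    a = old.read(name).decode("utf-8", errors="replace").splitlines()
    b = new.read(name).decode("utf-8", errors="replace").splitlines()
    for line in difflib.unified_diff(a, b, fromfile=name, tofile=name, lineterm=""):
      print(line)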
ai_edge_torch/generative/examples/t5/t5.py

@@ -19,15 +19,14 @@ import os
  from pathlib import Path
  from typing import Optional, Tuple

- import numpy as np
- import torch
- import torch.nn as nn
-
  from ai_edge_torch.generative.examples.t5.t5_attention import EncoderDecoderBlock  # NOQA
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.t5_loader as loading_utils
+ import numpy as np
+ import torch
+ import torch.nn as nn

  ENCDEC_TENSOR_NAMES = {
      "ff_up_proj": "{prefix}.block.{}.layer.{num}.DenseReluDense.wi",
@@ -36,7 +35,9 @@ ENCDEC_TENSOR_NAMES = {
      "attn_key_proj": "{prefix}.block.{}.layer.0.SelfAttention.k",
      "attn_value_proj": "{prefix}.block.{}.layer.0.SelfAttention.v",
      "attn_output_proj": "{prefix}.block.{}.layer.0.SelfAttention.o",
-     "relative_attn_bias": "{prefix}.block.0.layer.0.SelfAttention.relative_attention_bias",
+     "relative_attn_bias": (
+         "{prefix}.block.0.layer.0.SelfAttention.relative_attention_bias"
+     ),
      "pre_attn_norm": "{prefix}.block.{}.layer.0.layer_norm",
      "pre_ff_norm": "{prefix}.block.{}.layer.1.layer_norm",
      "final_norm": "{prefix}.final_layer_norm",
@@ -52,13 +53,13 @@ class T5Stack(nn.Module):
      self.config = config
      self.embed_tokens = embed_tokens
      self.is_decoder = config.is_decoder
-     self.transformer_blocks = nn.ModuleList(
-         [
-             EncoderDecoderBlock(config, has_relative_attention_bias=bool(i == 0))
-             for i in range(config.num_layers)
-         ]
+     self.transformer_blocks = nn.ModuleList([
+         EncoderDecoderBlock(config, has_relative_attention_bias=bool(i == 0))
+         for i in range(config.num_layers)
+     ])
+     self.final_norm = builder.build_norm(
+         config.embedding_dim, config.final_norm_config
      )
-     self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)

  def forward(
      self,
@@ -81,15 +82,17 @@
      encoder_decoder_position_bias = None
      for i, layer_module in enumerate(self.transformer_blocks):
        # EncoderDecoderBlock.forward
-       hidden_states, position_bias, encoder_decoder_position_bias = layer_module(
-           hidden_states,
-           input_pos,
-           mask=attention_mask,
-           relative_position=relative_position,
-           position_bias=position_bias,
-           encoder_hidden_states=encoder_hidden_states,
-           encoder_attention_mask=encoder_attention_mask,
-           encoder_decoder_position_bias=encoder_decoder_position_bias,
+       hidden_states, position_bias, encoder_decoder_position_bias = (
+           layer_module(
+               hidden_states,
+               input_pos,
+               mask=attention_mask,
+               relative_position=relative_position,
+               position_bias=position_bias,
+               encoder_hidden_states=encoder_hidden_states,
+               encoder_attention_mask=encoder_attention_mask,
+               encoder_decoder_position_bias=encoder_decoder_position_bias,
+           )
        )

      hidden_states = self.final_norm(hidden_states)
@@ -130,7 +133,9 @@ class T5(nn.Module):
      )

      self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
-         size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+         size=config.kv_cache_max,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
      )

      self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
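For context, a causal mask cache of the kind built above is conventionally an upper-triangular additive mask, broadcastable over (batch, heads, q_len, kv_len). A minimal sketch of the usual semantics (the actual attn_utils.build_causal_mask_cache implementation may differ in details):

import torch

def causal_mask_cache_sketch(size: int) -> torch.Tensor:
  # -inf above the diagonal blocks attention to future positions;
  # zeros elsewhere leave allowed positions untouched.
  mask = torch.full((size, size), float("-inf"), dtype=torch.float32)
  return torch.triu(mask, diagonal=1).reshape(1, 1, size, size)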
@@ -159,9 +164,10 @@
        pad_mask: torch.Tensor,
    ) -> torch.Tensor:
      B, T = input_ids.size()
-     assert (
-         self.config.max_seq_len >= T
-     ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+     assert self.config.max_seq_len >= T, (
+         f"Cannot forward sequence of length {T}, max seq length is only"
+         f" {self.config.max_seq_len}"
+     )

      enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
      enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
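The rewritten assert above is worth a note: parentheses may safely wrap the message (adjacent f-strings concatenate at compile time), but wrapping the condition and message together would create a two-element tuple, which is always truthy, so the assert could never fire. A self-contained illustration:

T, max_seq_len = 8, 512

# Safe: the parentheses only group the multi-line message.
assert max_seq_len >= T, (
    f"Cannot forward sequence of length {T}, max seq length is only"
    f" {max_seq_len}"
)

# Broken: this asserts a non-empty tuple, which is always True.
# assert (max_seq_len >= T, "message")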
@@ -170,10 +176,18 @@
      dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
      dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
      enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
-     enc_relative_position = enc_relative_position[:, :, :, : self.config.kv_cache_max]
-     dec_relative_position = self.enc_rel_pos_mask.index_select(2, decoder_input_pos)
-     dec_relative_position = dec_relative_position[:, :, :, : self.config.kv_cache_max]
-     enc_attention_mask = self.enc_attn_mask_cache.index_select(2, decoder_input_pos)
+     enc_relative_position = enc_relative_position[
+         :, :, :, : self.config.kv_cache_max
+     ]
+     dec_relative_position = self.enc_rel_pos_mask.index_select(
+         2, decoder_input_pos
+     )
+     dec_relative_position = dec_relative_position[
+         :, :, :, : self.config.kv_cache_max
+     ]
+     enc_attention_mask = self.enc_attn_mask_cache.index_select(
+         2, decoder_input_pos
+     )
      # Mask off any "pad" tokens that shouldn't contribute to cross attention
      enc_attention_mask[:, :, :, :] += pad_mask
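All four rewritten statements follow the same pattern: index_select(2, pos) gathers the rows for the current query positions, and the trailing slice truncates the key axis to kv_cache_max. A small shape-only illustration (the dimensions are made up):

import torch

kv_cache_max, max_len = 16, 32
cache = torch.zeros(1, 1, max_len, max_len)   # e.g. a mask or bucket cache
input_pos = torch.tensor([0, 1, 2, 3])
rows = cache.index_select(2, input_pos)       # -> (1, 1, 4, max_len)
rows = rows[:, :, :, :kv_cache_max]           # -> (1, 1, 4, kv_cache_max)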
@@ -210,7 +224,9 @@ class T5Encoder(nn.Module):

      self.config = config
      # Construct model layers.
-     assert embedding_layer != None, "Passed in embedding layer should not be None!"
+     assert (
+         embedding_layer != None
+     ), "Passed in embedding layer should not be None!"
      self.tok_embedding = embedding_layer

      encoder_config = copy.deepcopy(config)
@@ -244,16 +260,19 @@
        pad_mask: torch.Tensor,
    ) -> torch.Tensor:
      B, T = input_ids.size()
-     assert (
-         self.config.max_seq_len >= T
-     ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+     assert self.config.max_seq_len >= T, (
+         f"Cannot forward sequence of length {T}, max seq length is only"
+         f" {self.config.max_seq_len}"
+     )

      enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
      enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
      # Mask off any "pad" tokens that shouldn't contribute to self-attention
      enc_mask[:, :, :, :] += pad_mask
      enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
-     enc_relative_position = enc_relative_position[:, :, :, : self.config.kv_cache_max]
+     enc_relative_position = enc_relative_position[
+         :, :, :, : self.config.kv_cache_max
+     ]

      # Convert encoder inputs in embeddings if needed
      encoder_hidden_states = self.encoder(
@@ -273,7 +292,9 @@ class T5Decoder(nn.Module):

      self.config = config
      # Construct model layers.
-     assert embedding_layer != None, "Passed in embedding layer should not be None!"
+     assert (
+         embedding_layer != None
+     ), "Passed in embedding layer should not be None!"
      self.tok_embedding = embedding_layer

      decoder_config = copy.deepcopy(config)
@@ -302,7 +323,9 @@
      )

      self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
-         size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+         size=config.kv_cache_max,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
      )

  @torch.inference_mode
@@ -315,9 +338,15 @@
    ) -> torch.Tensor:
      dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
      dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
-     dec_relative_position = self.enc_rel_pos_mask.index_select(2, decoder_input_pos)
-     dec_relative_position = dec_relative_position[:, :, :, : self.config.kv_cache_max]
-     enc_attention_mask = self.enc_attn_mask_cache.index_select(2, decoder_input_pos)
+     dec_relative_position = self.enc_rel_pos_mask.index_select(
+         2, decoder_input_pos
+     )
+     dec_relative_position = dec_relative_position[
+         :, :, :, : self.config.kv_cache_max
+     ]
+     enc_attention_mask = self.enc_attn_mask_cache.index_select(
+         2, decoder_input_pos
+     )
      # Mask off any "pad" tokens that shouldn't contribute to cross attention
      enc_attention_mask[:, :, :, :] += pad_mask

@@ -470,89 +499,85 @@ def build_t5_decoder_model(


  def get_sample_encoder_input_ids() -> torch.Tensor:
-   idx = torch.tensor(
-       [
-           [
-               3856,
-               27111,
-               10,
-               4425,
-               51,
-               4008,
-               31,
-               7,
-               2306,
-               16576,
-               47,
-               4381,
-               16,
-               8,
-               3414,
-               13,
-               1410,
-               16,
-               932,
-               11,
-               1515,
-               2766,
-               6,
-               11,
-               4838,
-               16,
-               23964,
-               16,
-               1797,
-               13,
-               24,
-               215,
-               5,
-               94,
-               47,
-               2017,
-               168,
-               1204,
-               57,
-               6800,
-               7,
-               11,
-               9443,
-               38,
-               3673,
-               8,
-               4016,
-               13,
-               66,
-               70,
-               14234,
-               5,
-               2449,
-               1215,
-               83,
-               17,
-               16,
-               8782,
-               70,
-               723,
-               30,
-               8,
-               6162,
-               13,
-               1410,
-               12,
-               48,
-               833,
-               250,
-               13,
-               149,
-               231,
-               79,
-               1858,
-               16576,
-               5,
-               1,
-           ]
-       ]
-   )
+   idx = torch.tensor([[
+       3856,
+       27111,
+       10,
+       4425,
+       51,
+       4008,
+       31,
+       7,
+       2306,
+       16576,
+       47,
+       4381,
+       16,
+       8,
+       3414,
+       13,
+       1410,
+       16,
+       932,
+       11,
+       1515,
+       2766,
+       6,
+       11,
+       4838,
+       16,
+       23964,
+       16,
+       1797,
+       13,
+       24,
+       215,
+       5,
+       94,
+       47,
+       2017,
+       168,
+       1204,
+       57,
+       6800,
+       7,
+       11,
+       9443,
+       38,
+       3673,
+       8,
+       4016,
+       13,
+       66,
+       70,
+       14234,
+       5,
+       2449,
+       1215,
+       83,
+       17,
+       16,
+       8782,
+       70,
+       723,
+       30,
+       8,
+       6162,
+       13,
+       1410,
+       12,
+       48,
+       833,
+       250,
+       13,
+       149,
+       231,
+       79,
+       1858,
+       16576,
+       5,
+       1,
+   ]])
    return idx
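The 77 ids above end with 1, the EOS id in the standard T5 SentencePiece vocabulary. A hedged sketch of how a comparable sample could be produced with Hugging Face transformers (an assumption for illustration only; the source does not say how this particular sample was generated):

# Requires: pip install transformers sentencepiece
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
idx = tokenizer("summarize: ...", return_tensors="pt").input_ids  # shape (1, seq_len)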
@@ -584,9 +609,15 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
    t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")

    config = get_model_config_t5()
-   embedding_layer = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=0)
-   t5_encoder_model = build_t5_encoder_model(config, embedding_layer, checkpoint_path)
-   t5_decoder_model = build_t5_decoder_model(config, embedding_layer, checkpoint_path)
+   embedding_layer = nn.Embedding(
+       config.vocab_size, config.embedding_dim, padding_idx=0
+   )
+   t5_encoder_model = build_t5_encoder_model(
+       config, embedding_layer, checkpoint_path
+   )
+   t5_decoder_model = build_t5_decoder_model(
+       config, embedding_layer, checkpoint_path
+   )
    idx = get_sample_encoder_input_ids()

    tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
@@ -595,7 +626,9 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:

    decode_d_token = torch.tensor([[0]], dtype=torch.int64)
    decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
-   pad_mask = torch.zeros([t5_encoder_model.config.kv_cache_max], dtype=torch.float32)
+   pad_mask = torch.zeros(
+       [t5_encoder_model.config.kv_cache_max], dtype=torch.float32
+   )
    pad_mask[77:] = float("-inf")
    hidden_states = t5_encoder_model.forward(tokens, input_pos, pad_mask)
    lm_logits = t5_decoder_model.forward(
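The pad_mask trick here relies on additive masking: the encoder runs on a 512-token buffer, but only the first 77 positions hold real tokens (the sample above has exactly 77 ids), so every later key position is set to -inf before being added to the attention scores. Softmax then assigns those positions exactly zero weight. A self-contained illustration:

import torch
import torch.nn.functional as F

scores = torch.randn(1, 4)       # raw attention scores for 4 key positions
mask = torch.zeros(4)
mask[2:] = float("-inf")         # pretend positions 2 and 3 are padding
weights = F.softmax(scores + mask, dim=-1)
print(weights)                   # weights[..., 2:] are exactly 0.0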
ai_edge_torch/generative/examples/t5/t5_attention.py

@@ -16,16 +16,15 @@
  from typing import Optional, Tuple

- import torch
- from torch import nn
- import torch.nn.functional as F
-
  from ai_edge_torch.generative.layers.attention import CrossAttention
  import ai_edge_torch.generative.layers.builder as builder
  from ai_edge_torch.generative.layers.kv_cache import KVCache
  import ai_edge_torch.generative.layers.model_config as cfg
  from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
  from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
+ import torch
+ from torch import nn
+ import torch.nn.functional as F

  BATCH_SIZE = 1

@@ -181,9 +180,13 @@ class T5Attention(CrossAttention):
    """

    x = self.pre_atten_norm(x)
-   B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+   B, T, C = (
+       x.size()
+   )  # batch size, sequence length, embedding dimensionality (n_embd)
    query_states = self.q_projection(x)
-   query_states = query_states.reshape(B, T, -1, self.head_dim)  # (B, T, nh_q, hs)
+   query_states = query_states.reshape(
+       B, T, -1, self.head_dim
+   )  # (B, T, nh_q, hs)

    if key_value_states is not None:
      (
@@ -223,7 +226,12 @@

    mask = mask + position_bias
    y = self.sdpa_func(
-       query_states, key_states, value_states, self.head_dim, mask=mask, scale=1.0
+       query_states,
+       key_states,
+       value_states,
+       self.head_dim,
+       mask=mask,
+       scale=1.0,
    )
    y = y.reshape(B, T, C)  # re-assemble all head outputs side by side
    # output projection
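One detail worth flagging: scale=1.0 is deliberate. T5 omits the usual 1/sqrt(head_dim) attention scaling (it is effectively folded into the trained weights), so the call must disable the default rescaling. A sketch of the equivalent call with stock PyTorch (the library's own sdpa_func additionally handles head layout and the optional HLFB marking; the scale keyword needs PyTorch >= 2.1):

import torch
import torch.nn.functional as F

B, H, T, D = 1, 8, 4, 64
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
mask = torch.zeros(B, 1, T, T)  # additive mask, already including position_bias
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=1.0)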
ai_edge_torch/generative/examples/test_models/toy_model.py

@@ -15,15 +15,14 @@
  # A toy example which has a single-layer transformer block.
  from typing import Tuple

- import numpy as np
- import torch
- import torch.nn as nn
-
  import ai_edge_torch
  from ai_edge_torch.generative.layers.attention import TransformerBlock
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
+ import numpy as np
+ import torch
+ import torch.nn as nn

  RoPECache = Tuple[torch.Tensor, torch.Tensor]
  KV_CACHE_MAX_LEN = 100
@@ -72,7 +71,10 @@ class ToySingleLayerModel(torch.nn.Module):

  def get_model_config() -> cfg.ModelConfig:
    attn_config = cfg.AttentionConfig(
-       num_heads=32, num_query_groups=4, rotary_percentage=1.0, enable_kv_cache=False
+       num_heads=32,
+       num_query_groups=4,
+       rotary_percentage=1.0,
+       enable_kv_cache=False,
    )
    ff_config = cfg.FeedForwardConfig(
        type=cfg.FeedForwardType.GATED,
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py

@@ -16,16 +16,15 @@
  from typing import Tuple

- import torch
- import torch.nn as nn
- import torch_xla
-
  import ai_edge_torch
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
  from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
  import ai_edge_torch.generative.layers.model_config as cfg
+ import torch
+ import torch.nn as nn
+ import torch_xla

  RoPECache = Tuple[torch.Tensor, torch.Tensor]

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -15,16 +15,15 @@
  # A toy example which has basic transformer block (w/ KV-Cache).
  from typing import List, Tuple

- import numpy as np
- import torch
- import torch.nn as nn
- import torch_xla
-
  import ai_edge_torch
  from ai_edge_torch.generative.layers.attention import TransformerBlock
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch_xla

  RoPECache = Tuple[torch.Tensor, torch.Tensor]

ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -16,11 +16,10 @@
  import os
  from pathlib import Path

- import torch
-
  import ai_edge_torch
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
  from ai_edge_torch.generative.quantize import quant_recipes
+ import torch


  def convert_tiny_llama_to_tflite(
@@ -58,7 +57,9 @@ def convert_tiny_llama_to_tflite(
        .signature('decode', pytorch_model, (decode_token, decode_input_pos))
        .convert(quant_config=quant_config)
    )
-   edge_model.export(f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
+   edge_model.export(
+       f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+   )


  if __name__ == '__main__':
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -17,15 +17,14 @@
  import os
  from pathlib import Path

- import numpy as np
- import torch
- import torch.nn as nn
-
  from ai_edge_torch.generative.layers.attention import TransformerBlock
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
+ import numpy as np
+ import torch
+ import torch.nn as nn

  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -72,7 +71,9 @@ class TinyLLamma(nn.Module):
          device=torch.device("cpu"),
      )
      self.mask_cache = attn_utils.build_causal_mask_cache(
-         size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+         size=config.kv_cache_max,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
      )
      self.config = config

@@ -82,9 +83,10 @@ class TinyLLamma(nn.Module):
  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    B, T = idx.size()
-   assert (
-       self.config.max_seq_len >= T
-   ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+   assert self.config.max_seq_len >= T, (
+       f"Cannot forward sequence of length {T}, max seq length is only"
+       f" {self.config.max_seq_len}"
+   )

    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
ai_edge_torch/generative/fx_passes/__init__.py

@@ -12,11 +12,10 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- import torch
-
  from ai_edge_torch.convert.fx_passes import CanonicalizePass
  from ai_edge_torch.convert.fx_passes import run_passes
  from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass  # NOQA
+ import torch


  def run_generative_passes(
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py

@@ -12,10 +12,9 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- import torch
-
  from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
  from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+ import torch


  class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
@@ -36,7 +35,11 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
        # Composite info:
        # - name: odml.scaled_dot_product_attention
        # - inputs: q, k, v, mask
-       if name == "odml.scaled_dot_product_attention" and is_input and io_position == 3:
+       if (
+           name == "odml.scaled_dot_product_attention"
+           and is_input
+           and io_position == 3
+       ):
          if self.is_zero_tensor_node(source):
            # Remove the mark_tensor call on the mask input by
            # replacing the target with an identity function.
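The hunk references is_zero_tensor_node(source) without showing it. A hedged sketch of what such a predicate plausibly checks on a torch.fx graph (hypothetical; the pass's real implementation is not part of this diff):

import torch

def is_zero_tensor_node_sketch(node: torch.fx.Node) -> bool:
  # True only for a direct call to a zero-producing factory op; the real
  # predicate may also fold constants or inspect meta tensors.
  return node.op == "call_function" and node.target in (
      torch.zeros,
      torch.zeros_like,
      torch.ops.aten.zeros.default,
      torch.ops.aten.zeros_like.default,
  )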