ai-edge-torch-nightly 0.2.0.dev20240801__py3-none-any.whl → 0.2.0.dev20240803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of ai-edge-torch-nightly might be problematic.

Files changed (89)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/convert/conversion.py +12 -8
  3. ai_edge_torch/convert/conversion_utils.py +38 -20
  4. ai_edge_torch/convert/converter.py +11 -5
  5. ai_edge_torch/convert/fx_passes/__init__.py +3 -4
  6. ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
  7. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
  8. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
  9. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
  10. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
  11. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
  16. ai_edge_torch/convert/test/test_convert.py +39 -16
  17. ai_edge_torch/convert/test/test_convert_composites.py +115 -86
  18. ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
  19. ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
  20. ai_edge_torch/convert/to_channel_last_io.py +6 -2
  21. ai_edge_torch/debug/culprit.py +41 -16
  22. ai_edge_torch/debug/test/test_culprit.py +4 -3
  23. ai_edge_torch/debug/test/test_search_model.py +4 -3
  24. ai_edge_torch/debug/utils.py +3 -1
  25. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
  26. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
  27. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
  28. ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
  29. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
  30. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
  31. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
  32. ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
  33. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
  34. ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
  35. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  36. ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
  37. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
  38. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
  39. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
  40. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  41. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
  42. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  43. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  44. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  45. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  46. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  47. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
  48. ai_edge_torch/generative/examples/t5/t5.py +158 -125
  49. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
  50. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
  52. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
  55. ai_edge_torch/generative/fx_passes/__init__.py +1 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
  57. ai_edge_torch/generative/layers/attention.py +19 -11
  58. ai_edge_torch/generative/layers/builder.py +3 -4
  59. ai_edge_torch/generative/layers/kv_cache.py +4 -3
  60. ai_edge_torch/generative/layers/model_config.py +6 -2
  61. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
  62. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
  63. ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
  64. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  65. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
  66. ai_edge_torch/generative/quantize/example.py +2 -3
  67. ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
  68. ai_edge_torch/generative/test/loader_test.py +5 -4
  69. ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
  70. ai_edge_torch/generative/test/test_model_conversion.py +2 -3
  71. ai_edge_torch/generative/test/test_quantize.py +45 -48
  72. ai_edge_torch/generative/utilities/loader.py +55 -28
  73. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
  74. ai_edge_torch/generative/utilities/t5_loader.py +77 -48
  75. ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
  76. ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
  79. ai_edge_torch/model.py +8 -5
  80. ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
  81. ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
  82. ai_edge_torch/quantize/quant_config.py +6 -2
  83. ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
  84. ai_edge_torch/version.py +16 -0
  85. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
  86. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
  87. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
  88. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
  89. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/layers/attention.py

@@ -16,16 +16,15 @@
 
 from typing import Optional, Tuple
 
-import torch
-from torch import nn
-import torch.nn.functional as F
-
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.kv_cache import KVCache
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
+import torch
+from torch import nn
+import torch.nn.functional as F
 
 
 def _embed_rope(
@@ -140,7 +139,9 @@ class CausalSelfAttention(nn.Module):
     shape = (config.num_heads + 2 * config.num_query_groups) * self.head_dim
     # Key, query, value projections for all heads.
     self.qkv_projection = nn.Linear(dim, shape, bias=config.qkv_use_bias)
-    self.output_projection = nn.Linear(dim, dim, bias=config.output_proj_use_bias)
+    self.output_projection = nn.Linear(
+        dim, dim, bias=config.output_proj_use_bias
+    )
     self.config = config
     self.kv_cache = None
     self.batch_size = batch_size
@@ -181,9 +182,10 @@ class CausalSelfAttention(nn.Module):
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
-    assert (
-        B == self.batch_size
-    ), "batch size of input tensor must match with the batch size specified in the model configuration."
+    assert B == self.batch_size, (
+        "batch size of input tensor must match with the batch size specified in"
+        " the model configuration."
+    )
 
     qkv = self.qkv_projection(x)
 
@@ -279,9 +281,15 @@ class CrossAttention(nn.Module):
     self.config = config
     self.head_dim = query_dim // config.num_heads
     self.n_heads = config.num_heads
-    self.q_projection = nn.Linear(query_dim, query_dim, bias=config.qkv_use_bias)
-    self.k_projection = nn.Linear(cross_dim, query_dim, bias=config.qkv_use_bias)
-    self.v_projection = nn.Linear(cross_dim, query_dim, bias=config.qkv_use_bias)
+    self.q_projection = nn.Linear(
+        query_dim, query_dim, bias=config.qkv_use_bias
+    )
+    self.k_projection = nn.Linear(
+        cross_dim, query_dim, bias=config.qkv_use_bias
+    )
+    self.v_projection = nn.Linear(
+        cross_dim, query_dim, bias=config.qkv_use_bias
+    )
     self.output_projection = nn.Linear(
         query_dim, query_dim, bias=config.output_proj_use_bias
    )
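
The reflowed CrossAttention hunk makes the projection dimensions easy to read: queries are mapped query_dim → query_dim, while keys and values are mapped cross_dim → query_dim, so both streams meet in the query space before attention. A minimal shape sketch; the dimensions (320, 768, 77 tokens) are illustrative assumptions, not taken from this diff:

import torch
from torch import nn

# Hypothetical dims: 320 for the query stream, 768 for the conditioning stream.
query_dim, cross_dim = 320, 768
q_projection = nn.Linear(query_dim, query_dim)
k_projection = nn.Linear(cross_dim, query_dim)
v_projection = nn.Linear(cross_dim, query_dim)

x = torch.randn(1, 64, query_dim)         # query tokens
context = torch.randn(1, 77, cross_dim)   # conditioning tokens
q, k, v = q_projection(x), k_projection(context), v_projection(context)
assert q.shape[-1] == k.shape[-1] == v.shape[-1] == query_dim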

ai_edge_torch/generative/layers/builder.py

@@ -13,13 +13,12 @@
 # limitations under the License.
 # ==============================================================================
 # Builder class for individual components.
-import torch
-from torch import nn
-import torch.nn.functional as F
-
 import ai_edge_torch.generative.layers.feed_forward as feed_forward
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.normalization as normalization
+import torch
+from torch import nn
+import torch.nn.functional as F
 
 
 class GeGLU(nn.Module):

ai_edge_torch/generative/layers/kv_cache.py

@@ -14,16 +14,17 @@
 # ==============================================================================
 # `nn.Module` which implements a KV cache.
 
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 from torch import nn
 import torch_xla
 
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
-
 
 class KVCache(nn.Module):
 
-  def __init__(self, batch_size, kv_cache_max, n_heads, head_dim, enable_hlfb=False):
+  def __init__(
+      self, batch_size, kv_cache_max, n_heads, head_dim, enable_hlfb=False
+  ):
    """Initializes the KVCache layer.
 
    Args:
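
The KVCache constructor keeps its argument order under the reflow; only the line wrapping changes. A minimal usage sketch with illustrative values, not taken from this diff:

from ai_edge_torch.generative.layers.kv_cache import KVCache

# Hypothetical sizes; the diff above only rewraps the constructor signature.
cache = KVCache(
    batch_size=1, kv_cache_max=1024, n_heads=8, head_dim=64, enable_hlfb=False
)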

ai_edge_torch/generative/layers/model_config.py

@@ -124,9 +124,13 @@ class ModelConfig:
       default_factory=NormalizationConfig
   )
   # The normalization applied to feed forward's input.
-  pre_ff_norm_config: NormalizationConfig = field(default_factory=NormalizationConfig)
+  pre_ff_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
   # The normalization applied before LM head.
-  final_norm_config: NormalizationConfig = field(default_factory=NormalizationConfig)
+  final_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
 
   # If set to True, only pre_attention_norm is applied to the input and the
   # decode's output is computed as `output = input + attn_out + ff_out` where

ai_edge_torch/generative/layers/rotary_position_embedding.py

@@ -16,7 +16,9 @@
 import torch
 
 
-def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+def apply_rope(
+    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> torch.Tensor:
   """Computes rotary positional embedding.
 
   Args:
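
apply_rope keeps its (x, cos, sin) signature; only the definition line is wrapped. For readers unfamiliar with the operation, here is a standard rotate-half RoPE sketch that matches this signature. The formulation is an assumption, not the library's verbatim implementation:

import torch

def rope_reference(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
  # Split the feature dim in half and rotate: (x1, x2) -> (-x2, x1).
  x1, x2 = x.chunk(2, dim=-1)
  rotated = torch.cat((-x2, x1), dim=-1)
  # Blend original and rotated features with the precomputed cos/sin tables.
  return x * cos + rotated * sin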

ai_edge_torch/generative/layers/scaled_dot_product_attention.py

@@ -17,11 +17,10 @@
 import math
 from typing import Optional
 
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 import torch.nn.functional as F
 
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
-
 
 def scaled_dot_product_attention(
     q: torch.Tensor,

ai_edge_torch/generative/layers/unet/blocks_2d.py

@@ -15,15 +15,14 @@
 
 from typing import List, Optional, Tuple
 
-import torch
-from torch import nn
-
 from ai_edge_torch.generative.layers.attention import CrossAttention
 from ai_edge_torch.generative.layers.attention import SelfAttention
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
 import ai_edge_torch.generative.layers.unet.builder as unet_builder
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+import torch
+from torch import nn
 
 
 class ResidualBlock2D(nn.Module):
@@ -41,7 +40,11 @@ class ResidualBlock2D(nn.Module):
         config.in_channels, config.normalization_config
     )
     self.conv_1 = nn.Conv2d(
-        config.in_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+        config.in_channels,
+        config.out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
     )
     if config.time_embedding_channels is not None:
       self.time_emb_proj = nn.Linear(
@@ -53,14 +56,22 @@ class ResidualBlock2D(nn.Module):
         config.out_channels, config.normalization_config
     )
     self.conv_2 = nn.Conv2d(
-        config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+        config.out_channels,
+        config.out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
     )
     self.act_fn = layers_builder.get_activation(config.activation_config)
     if config.in_channels == config.out_channels:
       self.residual_layer = nn.Identity()
     else:
       self.residual_layer = nn.Conv2d(
-          config.in_channels, config.out_channels, kernel_size=1, stride=1, padding=0
+          config.in_channels,
+          config.out_channels,
+          kernel_size=1,
+          stride=1,
+          padding=0,
       )
 
   def forward(
@@ -105,7 +116,9 @@ class AttentionBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
-    self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+    self.norm = layers_builder.build_norm(
+        config.dim, config.normalization_config
+    )
     self.attention = SelfAttention(
         config.attention_batch_size,
         config.dim,
@@ -125,7 +138,10 @@ class AttentionBlock2D(nn.Module):
     """
     residual = input_tensor
     B, C, H, W = input_tensor.shape
-    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+    if (
+        self.config.normalization_config.type
+        == layers_cfg.NormalizationType.GROUP_NORM
+    ):
       x = self.norm(input_tensor)
       x = x.view(B, C, H * W)
       x = x.transpose(-1, -2)
@@ -156,7 +172,9 @@ class CrossAttentionBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
-    self.norm = layers_builder.build_norm(config.query_dim, config.normalization_config)
+    self.norm = layers_builder.build_norm(
+        config.query_dim, config.normalization_config
+    )
     self.attention = CrossAttention(
         config.attention_batch_size,
         config.query_dim,
@@ -180,7 +198,10 @@ class CrossAttentionBlock2D(nn.Module):
     """
     residual = input_tensor
     B, C, H, W = input_tensor.shape
-    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+    if (
+        self.config.normalization_config.type
+        == layers_cfg.NormalizationType.GROUP_NORM
+    ):
       x = self.norm(input_tensor)
       x = x.view(B, C, H * W)
       x = x.transpose(-1, -2)
@@ -209,7 +230,9 @@ class FeedForwardBlock2D(nn.Module):
     super().__init__()
     self.config = config
     self.act = layers_builder.get_activation(config.activation_config)
-    self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+    self.norm = layers_builder.build_norm(
+        config.dim, config.normalization_config
+    )
     if config.activation_config.type == layers_cfg.ActivationType.GE_GLU:
       self.w1 = nn.Identity()
       self.w2 = nn.Linear(config.hidden_dim, config.dim)
@@ -220,7 +243,10 @@ class FeedForwardBlock2D(nn.Module):
   def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
     residual = input_tensor
     B, C, H, W = input_tensor.shape
-    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+    if (
+        self.config.normalization_config.type
+        == layers_cfg.NormalizationType.GROUP_NORM
+    ):
       x = self.norm(input_tensor)
       x = x.view(B, C, H * W)
       x = x.transpose(-1, -2)
@@ -287,7 +313,9 @@ class TransformerBlock2D(nn.Module):
         padding=0,
     )
     self.self_attention = AttentionBlock2D(config.attention_block_config)
-    self.cross_attention = CrossAttentionBlock2D(config.cross_attention_block_config)
+    self.cross_attention = CrossAttentionBlock2D(
+        config.cross_attention_block_config
+    )
     self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
     self.conv_out = nn.Conv2d(
         config.attention_block_config.dim,
@@ -371,7 +399,9 @@ class DownEncoderBlock2D(nn.Module):
     if config.transformer_block_config:
       transformers.append(TransformerBlock2D(config.transformer_block_config))
     self.resnets = nn.ModuleList(resnets)
-    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+    self.transformers = (
+        nn.ModuleList(transformers) if len(transformers) > 0 else None
+    )
     if config.add_downsample:
       self.downsampler = unet_builder.build_downsampling(config.sampling_config)
     else:
@@ -467,12 +497,18 @@ class UpDecoderBlock2D(nn.Module):
     if config.transformer_block_config:
       transformers.append(TransformerBlock2D(config.transformer_block_config))
     self.resnets = nn.ModuleList(resnets)
-    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+    self.transformers = (
+        nn.ModuleList(transformers) if len(transformers) > 0 else None
+    )
     if config.add_upsample:
       self.upsampler = unet_builder.build_upsampling(config.sampling_config)
       if config.upsample_conv:
         self.upsample_conv = nn.Conv2d(
-            config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+            config.out_channels,
+            config.out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
        )
     else:
       self.upsampler = None
@@ -548,9 +584,13 @@ class SkipUpDecoderBlock2D(nn.Module):
     transformers = []
     for i in range(config.num_layers):
       res_skip_channels = (
-          config.in_channels if (i == config.num_layers - 1) else config.out_channels
+          config.in_channels
+          if (i == config.num_layers - 1)
+          else config.out_channels
+      )
+      resnet_in_channels = (
+          config.prev_out_channels if i == 0 else config.out_channels
       )
-      resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
       resnets.append(
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
@@ -565,12 +605,18 @@ class SkipUpDecoderBlock2D(nn.Module):
     if config.transformer_block_config:
       transformers.append(TransformerBlock2D(config.transformer_block_config))
     self.resnets = nn.ModuleList(resnets)
-    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+    self.transformers = (
+        nn.ModuleList(transformers) if len(transformers) > 0 else None
+    )
     if config.add_upsample:
       self.upsampler = unet_builder.build_upsampling(config.sampling_config)
       if config.upsample_conv:
         self.upsample_conv = nn.Conv2d(
-            config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+            config.out_channels,
+            config.out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
        )
     else:
       self.upsampler = None
@@ -678,7 +724,9 @@ class MidBlock2D(nn.Module):
       )
     self.resnets = nn.ModuleList(resnets)
     self.attentions = nn.ModuleList(attentions) if len(attentions) > 0 else None
-    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+    self.transformers = (
+        nn.ModuleList(transformers) if len(transformers) > 0 else None
+    )
 
   def forward(
      self,

ai_edge_torch/generative/layers/unet/builder.py

@@ -14,9 +14,8 @@
 # ==============================================================================
 # Builder utils for individual components.
 
-from torch import nn
-
 import ai_edge_torch.generative.layers.unet.model_config as unet_config
+from torch import nn
 
 
 def build_upsampling(config: unet_config.UpSamplingConfig):
@@ -30,10 +29,14 @@ def build_upsampling(config: unet_config.UpSamplingConfig):
 
 def build_downsampling(config: unet_config.DownSamplingConfig):
   if config.mode == unet_config.SamplingType.AVERAGE:
-    return nn.AvgPool2d(config.kernel_size, config.stride, padding=config.padding)
+    return nn.AvgPool2d(
+        config.kernel_size, config.stride, padding=config.padding
+    )
   elif config.mode == unet_config.SamplingType.CONVOLUTION:
     out_channels = (
-        config.in_channels if config.out_channels is None else config.out_channels
+        config.in_channels
+        if config.out_channels is None
+        else config.out_channels
     )
     padding = (0, 1, 0, 1) if config.padding == 0 else config.padding
     return nn.Conv2d(

ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py

@@ -16,7 +16,6 @@
 import json
 
 from ai_edge_quantizer import quantizer
-
 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipe
 
@@ -44,7 +43,9 @@ def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
   raise ValueError('Unimplemented number of bits')
 
 
-def _get_dtype_from_dtype(dtype: quant_attrs.Dtype) -> quantizer.qtyping.TensorDataType:
+def _get_dtype_from_dtype(
+    dtype: quant_attrs.Dtype,
+) -> quantizer.qtyping.TensorDataType:
   if dtype == quant_attrs.Dtype.FP32 or dtype == quant_attrs.Dtype.FP16:
     return quantizer.qtyping.TensorDataType.FLOAT
   else:
@@ -59,7 +60,9 @@ def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
   raise ValueError('Unimplemented execution mode')
 
 
-def _get_channelwise_from_granularity(granularity: quant_attrs.Granularity) -> bool:
+def _get_channelwise_from_granularity(
+    granularity: quant_attrs.Granularity,
+) -> bool:
   if granularity == quant_attrs.Granularity.CHANNELWISE:
     return True
   elif granularity == quant_attrs.Granularity.NONE:
@@ -87,7 +90,9 @@ def _set_quant_config(
       weight_tensor_config=_TensorQuantConfig(
          num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
          symmetric=True,
-          channel_wise=_get_channelwise_from_granularity(layer_recipe.granularity),
+          channel_wise=_get_channelwise_from_granularity(
+              layer_recipe.granularity
+          ),
          dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
      ),
      execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),

ai_edge_torch/generative/quantize/example.py

@@ -13,12 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 
-import numpy as np
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
 from ai_edge_torch.generative.quantize import quant_recipes
+import numpy as np
+import torch
 
 
 def main():

ai_edge_torch/generative/quantize/quant_recipe.py

@@ -74,7 +74,8 @@ class LayerQuantRecipe:
 
     if not is_valid:
       raise ValueError(
-          'Unsupported LayerQuantRecipe configuration. See get_supported_recipe_matrix()'
+          'Unsupported LayerQuantRecipe configuration. See'
+          ' get_supported_recipe_matrix()'
      )
 
 

ai_edge_torch/generative/test/loader_test.py

@@ -18,11 +18,10 @@ import os
 import tempfile
 import unittest
 
-import safetensors.torch
-import torch
-
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import loader as loading_utils
+import safetensors.torch
+import torch
 
 
 class TestLoader(unittest.TestCase):
@@ -59,7 +58,9 @@ class TestLoader(unittest.TestCase):
         "model.layers.0.mlp.down_proj.weight": torch.randn((2048, 5632)),
         "model.layers.0.mlp.gate_proj.weight": torch.randn((5632, 2048)),
         "model.layers.0.mlp.up_proj.weight": torch.randn((5632, 2048)),
-        "model.layers.0.post_attention_layernorm.weight": torch.randn((2048,)),
+        "model.layers.0.post_attention_layernorm.weight": torch.randn((
+            2048,
+        )),
         "model.layers.0.self_attn.k_proj.weight": torch.randn((256, 2048)),
         "model.layers.0.self_attn.o_proj.weight": torch.randn((2048, 2048)),
         "model.layers.0.self_attn.q_proj.weight": torch.randn((2048, 2048)),

ai_edge_torch/generative/test/test_experimental_ekv.py

@@ -16,20 +16,23 @@
 
 import unittest
 
-import numpy as np
-import torch
-
 from ai_edge_torch.generative.examples.experimental.gemma import gemma
 from ai_edge_torch.generative.examples.experimental.phi import phi2
 from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama  # NOQA
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import numpy as np
+import torch
 
 
 class TestExternalKVLayers(unittest.TestCase):
 
-  def _get_test_config(self, num_layers, head_dim, num_query_groups, kv_cache_max_len):
-    attn_config = cfg.AttentionConfig(num_heads=1, num_query_groups=num_query_groups)
+  def _get_test_config(
+      self, num_layers, head_dim, num_query_groups, kv_cache_max_len
+  ):
+    attn_config = cfg.AttentionConfig(
+        num_heads=1, num_query_groups=num_query_groups
+    )
     config = cfg.ModelConfig(
         kv_cache_max_len=kv_cache_max_len,
         embedding_dim=head_dim,
@@ -56,23 +59,31 @@ class TestExternalKVLayers(unittest.TestCase):
     entry = kv.caches[0]
     # single-slice update
     input_pos = torch.tensor([1])
-    k_slice = v_slice = torch.full((1, 1, NUM_QG, HEAD_DIM), 5, dtype=torch.float)
+    k_slice = v_slice = torch.full(
+        (1, 1, NUM_QG, HEAD_DIM), 5, dtype=torch.float
+    )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
-        updated_entry.k_cache.numpy().flatten().tolist(), [0, 0, 5, 5, 0, 0, 0, 0]
+        updated_entry.k_cache.numpy().flatten().tolist(),
+        [0, 0, 5, 5, 0, 0, 0, 0],
     )
     self.assertEqual(
-        updated_entry.v_cache.numpy().flatten().tolist(), [0, 0, 5, 5, 0, 0, 0, 0]
+        updated_entry.v_cache.numpy().flatten().tolist(),
+        [0, 0, 5, 5, 0, 0, 0, 0],
    )
     # multi-slice update
     input_pos = torch.tensor([0, 3])
-    k_slice = v_slice = torch.full((1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float)
+    k_slice = v_slice = torch.full(
+        (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
+    )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
-        updated_entry.k_cache.numpy().flatten().tolist(), [7, 7, 0, 0, 0, 0, 7, 7]
+        updated_entry.k_cache.numpy().flatten().tolist(),
+        [7, 7, 0, 0, 0, 0, 7, 7],
    )
     self.assertEqual(
-        updated_entry.v_cache.numpy().flatten().tolist(), [7, 7, 0, 0, 0, 0, 7, 7]
+        updated_entry.v_cache.numpy().flatten().tolist(),
+        [7, 7, 0, 0, 0, 0, 7, 7],
    )
 
   def test_serialization(self):
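
The reflowed assertions also document the scatter semantics of kv_utils.update: slices are written into the cache at the sequence positions named by input_pos, one or many at a time. A sketch reproducing the single-slice expectation with an out-of-place index_copy; this mirrors the observable behavior only, and the library may implement update differently:

import torch

def update_reference(cache, input_pos, slice_):
  # Write `slice_` at positions `input_pos` along dim 1 of a
  # [batch, seq, groups, head_dim] cache, without mutating the input.
  return cache.index_copy(1, input_pos, slice_)

cache = torch.zeros(1, 4, 1, 2)  # kv_cache_max=4, NUM_QG=1, HEAD_DIM=2 (illustrative)
out = update_reference(cache, torch.tensor([1]), torch.full((1, 1, 1, 2), 5.0))
assert out.flatten().tolist() == [0, 0, 5, 5, 0, 0, 0, 0]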

ai_edge_torch/generative/test/test_model_conversion.py

@@ -18,15 +18,14 @@ import os
 import tempfile
 import unittest
 
-import numpy as np
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
 from ai_edge_torch.generative.examples.phi2 import phi2
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.testing import model_coverage
+import numpy as np
+import torch
 
 
 class TestModelConversion(unittest.TestCase):