ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240618__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic; the original diff page links to further details.

Files changed (24)
  1. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
  2. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
  3. ai_edge_torch/debug/__init__.py +1 -0
  4. ai_edge_torch/debug/culprit.py +70 -29
  5. ai_edge_torch/debug/test/test_search_model.py +50 -0
  6. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
  7. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
  8. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
  9. ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
  10. ai_edge_torch/generative/layers/attention.py +154 -26
  11. ai_edge_torch/generative/layers/model_config.py +3 -0
  12. ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
  13. ai_edge_torch/generative/layers/unet/builder.py +20 -2
  14. ai_edge_torch/generative/layers/unet/model_config.py +157 -5
  15. ai_edge_torch/generative/test/test_model_conversion.py +24 -0
  16. ai_edge_torch/generative/test/test_quantize.py +1 -0
  17. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
  18. ai_edge_torch/generative/utilities/t5_loader.py +33 -17
  19. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/RECORD +23 -22
  21. ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
  22. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/unet/model_config.py
@@ -22,16 +22,28 @@ from typing import List, Optional
22
22
  import ai_edge_torch.generative.layers.model_config as layers_cfg
23
23
 
24
24
 
25
- @dataclass
25
+ @enum.unique
26
26
  class SamplingType(enum.Enum):
27
27
  NEAREST = enum.auto()
28
28
  BILINEAR = enum.auto()
29
+ AVERAGE = enum.auto()
30
+ CONVOLUTION = enum.auto()
29
31
 
30
32
 
31
33
  @dataclass
32
- class SamplingConfig:
34
+ class UpSamplingConfig:
35
+ mode: SamplingType
33
36
  scale_factor: float
37
+
38
+
39
+ @dataclass
40
+ class DownSamplingConfig:
34
41
  mode: SamplingType
42
+ in_channels: int
43
+ kernel_size: int
44
+ stride: int
45
+ padding: int
46
+ out_channels: Optional[int] = None
35
47
 
36
48
 
37
49
  @dataclass
@@ -46,9 +58,38 @@ class ResidualBlock2DConfig:
46
58
 
47
59
  @dataclass
48
60
  class AttentionBlock2DConfig:
49
- dims: int
61
+ dim: int
62
+ normalization_config: layers_cfg.NormalizationConfig
63
+ attention_config: layers_cfg.AttentionConfig
64
+ enable_hlfb: bool = True
65
+ attention_batch_size: int = 1
66
+
67
+
68
+ @dataclass
69
+ class CrossAttentionBlock2DConfig:
70
+ query_dim: int
71
+ cross_dim: int
50
72
  normalization_config: layers_cfg.NormalizationConfig
51
73
  attention_config: layers_cfg.AttentionConfig
74
+ enable_hlfb: bool = True
75
+ attention_batch_size: int = 1
76
+
77
+
78
+ @dataclass
79
+ class FeedForwardBlock2DConfig:
80
+ dim: int
81
+ hidden_dim: int
82
+ normalization_config: layers_cfg.NormalizationConfig
83
+ activation_config: layers_cfg.ActivationConfig
84
+ use_bias: bool
85
+
86
+
87
+ @dataclass
88
+ class TransformerBlock2Dconfig:
89
+ pre_conv_normalization_config: layers_cfg.NormalizationConfig
90
+ attention_block_config: AttentionBlock2DConfig
91
+ cross_attention_block_config: CrossAttentionBlock2DConfig
92
+ feed_forward_block_config: FeedForwardBlock2DConfig
52
93
 
53
94
 
54
95
  @dataclass
@@ -58,14 +99,62 @@ class UpDecoderBlock2DConfig:
58
99
  normalization_config: layers_cfg.NormalizationConfig
59
100
  activation_config: layers_cfg.ActivationConfig
60
101
  num_layers: int
61
- # Optional time embedding channels if the residual blocks take a time embedding context as input
102
+ # Optional time embedding channels if the residual blocks take a time embedding as input
62
103
  time_embedding_channels: Optional[int] = None
63
104
  # Whether to add upsample operation after residual blocks
64
105
  add_upsample: bool = True
65
106
  # Whether to add a conv2d layer after upsample
66
107
  upsample_conv: bool = True
67
108
  # Optional sampling config if add_upsample is True.
68
- sampling_config: Optional[SamplingConfig] = None
109
+ sampling_config: Optional[UpSamplingConfig] = None
110
+ # Optional config of transformer blocks interleaved with residual blocks
111
+ transformer_block_config: Optional[TransformerBlock2Dconfig] = None
112
+ # Optional dimension of context tensor if context tensor is given as input.
113
+ context_dim: Optional[int] = None
114
+
115
+
116
+ @dataclass
117
+ class SkipUpDecoderBlock2DConfig:
118
+ in_channels: int
119
+ out_channels: int
120
+ # The dimension of output channels of previous connected block
121
+ prev_out_channels: int
122
+ normalization_config: layers_cfg.NormalizationConfig
123
+ activation_config: layers_cfg.ActivationConfig
124
+ num_layers: int
125
+ # Optional time embedding channels if the residual blocks take a time embedding as input
126
+ time_embedding_channels: Optional[int] = None
127
+ # Whether to add upsample operation after residual blocks
128
+ add_upsample: bool = True
129
+ # Whether to add a conv2d layer after upsample
130
+ upsample_conv: bool = True
131
+ # Optional sampling config if add_upsample is True.
132
+ sampling_config: Optional[UpSamplingConfig] = None
133
+ # Optional config of transformer blocks interleaved with residual blocks
134
+ transformer_block_config: Optional[TransformerBlock2Dconfig] = None
135
+ # Optional dimension of context tensor if context tensor is given as input.
136
+ context_dim: Optional[int] = None
137
+
138
+
139
+ @dataclass
140
+ class DownEncoderBlock2DConfig:
141
+ in_channels: int
142
+ out_channels: int
143
+ normalization_config: layers_cfg.NormalizationConfig
144
+ activation_config: layers_cfg.ActivationConfig
145
+ num_layers: int
146
+ # Padding for the downsampling convolution.
147
+ padding: int = 1
148
+ # Optional time embedding channels if the residual blocks take a time embedding as input
149
+ time_embedding_channels: Optional[int] = None
150
+ # Whether to add downsample operation after residual blocks
151
+ add_downsample: bool = True
152
+ # Optional sampling config if add_upsample is True.
153
+ sampling_config: Optional[DownSamplingConfig] = None
154
+ # Optional config of transformer blocks interleaved with residual blocks
155
+ transformer_block_config: Optional[TransformerBlock2Dconfig] = None
156
+ # Optional dimension of context tensor if context tensor is given as input.
157
+ context_dim: Optional[int] = None
69
158
 
70
159
 
71
160
  @dataclass
@@ -78,6 +167,10 @@ class MidBlock2DConfig:
78
167
  time_embedding_channels: Optional[int] = None
79
168
  # Optional config of attention blocks interleaved with residual blocks
80
169
  attention_block_config: Optional[AttentionBlock2DConfig] = None
170
+ # Optional config of transformer blocks interleaved with residual blocks
171
+ transformer_block_config: Optional[TransformerBlock2Dconfig] = None
172
+ # Optional dimension of context tensor if context tensor is given as input.
173
+ context_dim: Optional[int] = None
81
174
 
82
175
 
83
176
  @dataclass
@@ -115,3 +208,62 @@ class AutoEncoderConfig:
115
208
 
116
209
  # The configuration of middle blocks, that is, after the last block of encoder and before the first block of decoder.
117
210
  mid_block_config: MidBlock2DConfig
211
+
212
+
213
+ @dataclass
214
+ class DiffusionModelConfig:
215
+ """Configurations of Diffusion model."""
216
+
217
+ # Number of channels in the input tensor.
218
+ in_channels: int
219
+
220
+ # Number of channels in the output tensor.
221
+ out_channels: int
222
+
223
+ # The output channels of each block.
224
+ block_out_channels: List[int]
225
+
226
+ # The layesr number of each block.
227
+ layers_per_block: int
228
+
229
+ # The padding to use for the downsampling.
230
+ downsample_padding: int
231
+
232
+ # Normalization config used in residual blocks.
233
+ residual_norm_config: layers_cfg.NormalizationConfig
234
+
235
+ # Activation config used in residual blocks
236
+ residual_activation_type: layers_cfg.ActivationType
237
+
238
+ # The batch size used in transformer blocks, for attention layers.
239
+ transformer_batch_size: int
240
+
241
+ # The number of attention heads used in transformer blocks.
242
+ transformer_num_attention_heads: int
243
+
244
+ # The dimension of cross attention used in transformer blocks.
245
+ transformer_cross_attention_dim: int
246
+
247
+ # Normalization config used in prev conv layer of transformer blocks.
248
+ transformer_pre_conv_norm_config: layers_cfg.NormalizationConfig
249
+
250
+ # Normalization config used in transformer blocks.
251
+ transformer_norm_config: layers_cfg.NormalizationConfig
252
+
253
+ # Activation type of feed forward used in transformer blocks.
254
+ transformer_ff_activation_type: layers_cfg.ActivationType
255
+
256
+ # Number of layers in mid block.
257
+ mid_block_layers: int
258
+
259
+ # Dimension of time embedding.
260
+ time_embedding_dim: int
261
+
262
+ # Time embedding dimensions for blocks.
263
+ time_embedding_blocks_dim: int
264
+
265
+ # Normalization config used for final layer
266
+ final_norm_config: layers_cfg.NormalizationConfig
267
+
268
+ # Activation type used in final layer
269
+ final_activation_type: layers_cfg.ActivationType
ai_edge_torch/generative/test/test_model_conversion.py
@@ -55,6 +55,30 @@ class TestModelConversion(unittest.TestCase):
55
55
  )
56
56
  )
57
57
 
58
+ def test_toy_model_with_multi_batches(self):
59
+ config = toy_model_with_kv_cache.get_model_config()
60
+ config.batch_size = 2
61
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
62
+ idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
63
+ [10], dtype=torch.int64
64
+ )
65
+
66
+ edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
67
+
68
+ # TODO(b/338288901): re-enable test to check output tensors.
69
+ skip_output_check = True
70
+ if skip_output_check is False:
71
+ self.assertTrue(
72
+ model_coverage.compare_tflite_torch(
73
+ edge_model,
74
+ pytorch_model,
75
+ (idx, input_pos),
76
+ num_valid_inputs=1,
77
+ atol=1e-5,
78
+ rtol=1e-5,
79
+ )
80
+ )
81
+
58
82
  def test_toy_model_with_kv_cache_with_hlfb(self):
59
83
  config = toy_model_with_kv_cache.get_model_config()
60
84
  config.enable_hlfb = True
ai_edge_torch/generative/test/test_quantize.py
@@ -116,6 +116,7 @@ class TestQuantizeConvert(unittest.TestCase):
116
116
  ]
117
117
  )
118
118
  def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
119
+ self.skipTest("b/346896669")
119
120
  config = toy_model_with_kv_cache.get_model_config()
120
121
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
121
122
  idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(