ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240617__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
- ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
- ai_edge_torch/generative/layers/attention.py +154 -26
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
- ai_edge_torch/generative/layers/unet/builder.py +20 -2
- ai_edge_torch/generative/layers/unet/model_config.py +157 -5
- ai_edge_torch/generative/test/test_model_conversion.py +24 -0
- ai_edge_torch/generative/test/test_quantize.py +1 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
- ai_edge_torch/generative/utilities/t5_loader.py +33 -17
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD +20 -20
- ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/top_level.txt +0 -0
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py
CHANGED
@@ -25,6 +25,25 @@ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layo
 from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA


+def can_partition(graph_module: torch.fx.GraphModule):
+  """Returns true if the input graph_module can be partitioned by min cut solver
+  in a reasonable time.
+
+  The min cut solver implements O(|V|^2|E|) Dinic's algorithm, which may
+  take a long time to complete for large graph module. This function determines
+  whether the graph module can be partitioned by the graph module size.
+
+  See go/pytorch-layout-transpose-optimization for more details.
+  """
+  graph = graph_module.graph
+  n_nodes = len(graph.nodes)
+  n_edges = sum(len(n.users) for n in graph.nodes)
+
+  # According to the experiments our model set, |V| < 2000 can
+  # be partitioned generally in a reasonable time.
+  return n_nodes**2 * n_edges < 2000**3
+
+
 class MinCutSolver:
   # A number that is large enough but can fit into int32 with all computations
   # in the maximum flow.
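For illustration, a minimal standalone sketch (not part of this diff) that applies the same |V|^2 * |E| budget check to a tiny traced graph; the function double_plus_one is a hypothetical example:

import torch

def double_plus_one(x):
  return x * 2 + 1

# Trace the function into a torch.fx.GraphModule, then count nodes and
# user edges the same way can_partition does above.
gm = torch.fx.symbolic_trace(double_plus_one)
n_nodes = len(gm.graph.nodes)
n_edges = sum(len(n.users) for n in gm.graph.nodes)
print(n_nodes**2 * n_edges < 2000**3)  # True: tiny graphs sit far below the 2000^3 budget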
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py
CHANGED
@@ -261,10 +261,17 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
     self.mark_const_nodes(exported_program)

     graph_module = exported_program.graph_module
-
+    partitioner = os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None)
+    if partitioner == "MINCUT":
       graph_module = layout_partitioners.min_cut.partition(graph_module)
-
+    elif partitioner == "GREEDY":
       graph_module = layout_partitioners.greedy.partition(graph_module)
+    else:
+      # By default use min cut partitioner if possible
+      if layout_partitioners.min_cut.can_partition(graph_module):
+        graph_module = layout_partitioners.min_cut.partition(graph_module)
+      else:
+        graph_module = layout_partitioners.greedy.partition(graph_module)

     graph = graph_module.graph
     for node in list(graph.nodes):
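As the hunk above shows, the partitioner choice can now be pinned through the AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER environment variable; a short usage sketch:

import os

# Force the greedy partitioner; "MINCUT" forces the min cut solver instead.
# Leaving the variable unset falls back to min cut whenever can_partition()
# reports the graph is small enough, and to greedy otherwise.
os.environ["AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER"] = "GREEDY"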
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
CHANGED
@@ -21,11 +21,11 @@ import torch
 import ai_edge_torch
 import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
 import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
-
+import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
-import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader


 @torch.inference_mode
@@ -45,11 +45,14 @@ def convert_stable_diffusion_to_tflite(
   encoder = Encoder()
   encoder.load_state_dict(torch.load(encoder_ckpt_path))

-
-
+  diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
+  diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
+      diffusion_ckpt_path, diffusion.TENSORS_NAMES
+  )
+  diffusion_loader.load(diffusion_model)

   decoder_model = decoder.Decoder(decoder.get_model_config())
-  decoder_loader = autoencoder_loader.AutoEncoderModelLoader(
+  decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
       decoder_ckpt_path, decoder.TENSORS_NAMES
   )
   decoder_loader.load(decoder_model)
@@ -84,7 +87,7 @@ def convert_stable_diffusion_to_tflite(
   # Diffusion
   ai_edge_torch.signature(
       'diffusion',
-
+      diffusion_model,
       (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
   ).convert().export('/tmp/stable_diffusion/diffusion.tflite')

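Together these hunks replace the removed autoencoder_loader with the new stable_diffusion_loader utilities; a hedged sketch of the resulting loading pattern (checkpoint paths are placeholders, everything else appears in the hunks above):

import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader

# Build the diffusion model, then map checkpoint tensors onto it by name.
diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
stable_diffusion_loader.DiffusionModelLoader(
    "/path/to/diffusion.ckpt", diffusion.TENSORS_NAMES  # placeholder path
).load(diffusion_model)

# The decoder follows the same pattern via AutoEncoderModelLoader.
decoder_model = decoder.Decoder(decoder.get_model_config())
stable_diffusion_loader.AutoEncoderModelLoader(
    "/path/to/decoder.ckpt", decoder.TENSORS_NAMES  # placeholder path
).load(decoder_model)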
ai_edge_torch/generative/examples/stable_diffusion/decoder.py
CHANGED
@@ -20,20 +20,20 @@ import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
 import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
-import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
+import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader

-TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
+TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
     post_quant_conv="0",
     conv_in="1",
-    mid_block_tensor_names=autoencoder_loader.MidBlockTensorNames(
+    mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames(
         residual_block_tensor_names=[
-            autoencoder_loader.ResidualBlockTensorNames(
+            stable_diffusion_loader.ResidualBlockTensorNames(
                 norm_1="2.groupnorm_1",
                 norm_2="2.groupnorm_2",
                 conv_1="2.conv_1",
                 conv_2="2.conv_2",
             ),
-            autoencoder_loader.ResidualBlockTensorNames(
+            stable_diffusion_loader.ResidualBlockTensorNames(
                 norm_1="4.groupnorm_1",
                 norm_2="4.groupnorm_2",
                 conv_1="4.conv_1",
@@ -41,7 +41,7 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
             ),
         ],
         attention_block_tensor_names=[
-            autoencoder_loader.AttentionBlockTensorNames(
+            stable_diffusion_loader.AttentionBlockTensorNames(
                 norm="3.groupnorm",
                 fused_qkv_proj="3.attention.in_proj",
                 output_proj="3.attention.out_proj",
@@ -49,21 +49,21 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
         ],
     ),
     up_decoder_blocks_tensor_names=[
-        autoencoder_loader.UpDecoderBlockTensorNames(
+        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="5.groupnorm_1",
                    norm_2="5.groupnorm_2",
                    conv_1="5.conv_1",
                    conv_2="5.conv_2",
                ),
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="6.groupnorm_1",
                    norm_2="6.groupnorm_2",
                    conv_1="6.conv_1",
                    conv_2="6.conv_2",
                ),
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="7.groupnorm_1",
                    norm_2="7.groupnorm_2",
                    conv_1="7.conv_1",
@@ -72,21 +72,21 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
            ],
            upsample_conv="9",
        ),
-        autoencoder_loader.UpDecoderBlockTensorNames(
+        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="10.groupnorm_1",
                    norm_2="10.groupnorm_2",
                    conv_1="10.conv_1",
                    conv_2="10.conv_2",
                ),
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="11.groupnorm_1",
                    norm_2="11.groupnorm_2",
                    conv_1="11.conv_1",
                    conv_2="11.conv_2",
                ),
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="12.groupnorm_1",
                    norm_2="12.groupnorm_2",
                    conv_1="12.conv_1",
@@ -95,22 +95,22 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
            ],
            upsample_conv="14",
        ),
-        autoencoder_loader.UpDecoderBlockTensorNames(
+        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="15.groupnorm_1",
                    norm_2="15.groupnorm_2",
                    conv_1="15.conv_1",
                    conv_2="15.conv_2",
                    residual_layer="15.residual_layer",
                ),
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="16.groupnorm_1",
                    norm_2="16.groupnorm_2",
                    conv_1="16.conv_1",
                    conv_2="16.conv_2",
                ),
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="17.groupnorm_1",
                    norm_2="17.groupnorm_2",
                    conv_1="17.conv_1",
@@ -119,22 +119,22 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
            ],
            upsample_conv="19",
        ),
-        autoencoder_loader.UpDecoderBlockTensorNames(
+        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="20.groupnorm_1",
                    norm_2="20.groupnorm_2",
                    conv_1="20.conv_1",
                    conv_2="20.conv_2",
                    residual_layer="20.residual_layer",
                ),
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="21.groupnorm_1",
                    norm_2="21.groupnorm_2",
                    conv_1="21.conv_1",
                    conv_2="21.conv_2",
                ),
-                autoencoder_loader.ResidualBlockTensorNames(
+                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1="22.groupnorm_1",
                    norm_2="22.groupnorm_2",
                    conv_1="22.conv_1",
@@ -225,8 +225,8 @@ class Decoder(nn.Module):
                 num_layers=config.layers_per_block,
                 add_upsample=not_final_block,
                 upsample_conv=True,
-                sampling_config=unet_cfg.
-
+                sampling_config=unet_cfg.UpSamplingConfig(
+                    mode=unet_cfg.SamplingType.NEAREST, scale_factor=2
                 ),
             )
         )
@@ -245,6 +245,14 @@ class Decoder(nn.Module):
     )

   def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
+    """Forward function of decoder model.
+
+    Args:
+      latents (torch.Tensor): latents space tensor.
+
+    Returns:
+      output decoded image tensor from decoder model.
+    """
     x = latents_tensor / self.config.scaling_factor
     x = self.post_quant_conv(x)
     x = self.conv_in(x)
@@ -271,7 +279,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
   )

   att_config = unet_cfg.AttentionBlock2DConfig(
-
+      dim=block_out_channels[-1],
       normalization_config=norm_config,
       attention_config=layers_cfg.AttentionConfig(
           num_heads=1,
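The decoder's upsampling is now configured through unet_cfg.UpSamplingConfig; a minimal sketch of just that piece, with field values taken verbatim from the hunk above:

import ai_edge_torch.generative.layers.unet.model_config as unet_cfg

# 2x nearest-neighbor upsampling, as passed to each non-final up decoder block.
sampling_config = unet_cfg.UpSamplingConfig(
    mode=unet_cfg.SamplingType.NEAREST, scale_factor=2
)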