ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240619__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
- ai_edge_torch/debug/__init__.py +1 -0
- ai_edge_torch/debug/culprit.py +70 -29
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
- ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
- ai_edge_torch/generative/layers/attention.py +154 -26
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
- ai_edge_torch/generative/layers/unet/builder.py +20 -2
- ai_edge_torch/generative/layers/unet/model_config.py +157 -5
- ai_edge_torch/generative/test/test_model_conversion.py +24 -0
- ai_edge_torch/generative/test/test_quantize.py +1 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
- ai_edge_torch/generative/utilities/t5_loader.py +33 -17
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/RECORD +23 -22
- ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/top_level.txt +0 -0
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py
CHANGED
|
@@ -25,6 +25,25 @@ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layo
|
|
|
25
25
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
|
26
26
|
|
|
27
27
|
|
|
28
|
+
def can_partition(graph_module: torch.fx.GraphModule):
|
|
29
|
+
"""Returns true if the input graph_module can be partitioned by the min cut solver
|
|
30
|
+
in a reasonable time.
|
|
31
|
+
|
|
32
|
+
The min cut solver implements O(|V|^2|E|) Dinic's algorithm, which may
|
|
33
|
+
take a long time to complete for a large graph module. This function determines
|
|
34
|
+
whether the graph module can be partitioned, based on the graph module size.
|
|
35
|
+
|
|
36
|
+
See go/pytorch-layout-transpose-optimization for more details.
|
|
37
|
+
"""
|
|
38
|
+
graph = graph_module.graph
|
|
39
|
+
n_nodes = len(graph.nodes)
|
|
40
|
+
n_edges = sum(len(n.users) for n in graph.nodes)
|
|
41
|
+
|
|
42
|
+
# According to the experiments on our model set, graphs with |V| < 2000 can
|
|
43
|
+
# be partitioned generally in a reasonable time.
|
|
44
|
+
return n_nodes**2 * n_edges < 2000**3
|
|
45
|
+
|
|
46
|
+
|
|
28
47
|
class MinCutSolver:
|
|
29
48
|
# A number that is large enough but can fit into int32 with all computations
|
|
30
49
|
# in the maximum flow.
|
|
@@ -261,10 +261,17 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
|
|
|
261
261
|
self.mark_const_nodes(exported_program)
|
|
262
262
|
|
|
263
263
|
graph_module = exported_program.graph_module
|
|
264
|
-
|
|
264
|
+
partitioner = os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None)
|
|
265
|
+
if partitioner == "MINCUT":
|
|
265
266
|
graph_module = layout_partitioners.min_cut.partition(graph_module)
|
|
266
|
-
|
|
267
|
+
elif partitioner == "GREEDY":
|
|
267
268
|
graph_module = layout_partitioners.greedy.partition(graph_module)
|
|
269
|
+
else:
|
|
270
|
+
# By default use min cut partitioner if possible
|
|
271
|
+
if layout_partitioners.min_cut.can_partition(graph_module):
|
|
272
|
+
graph_module = layout_partitioners.min_cut.partition(graph_module)
|
|
273
|
+
else:
|
|
274
|
+
graph_module = layout_partitioners.greedy.partition(graph_module)
|
|
268
275
|
|
|
269
276
|
graph = graph_module.graph
|
|
270
277
|
for node in list(graph.nodes):
|
ai_edge_torch/debug/__init__.py
CHANGED
ai_edge_torch/debug/culprit.py
CHANGED
|
@@ -21,7 +21,7 @@ import io
|
|
|
21
21
|
import operator
|
|
22
22
|
import os
|
|
23
23
|
import sys
|
|
24
|
-
from typing import Any, Generator, List, Optional, Tuple
|
|
24
|
+
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
|
|
25
25
|
|
|
26
26
|
from functorch.compile import minifier as fx_minifier
|
|
27
27
|
import torch
|
|
@@ -85,10 +85,9 @@ def _tensor_to_buffer(t: torch.Tensor):
|
|
|
85
85
|
|
|
86
86
|
|
|
87
87
|
@dataclasses.dataclass
|
|
88
|
-
class Culprit:
|
|
88
|
+
class SearchResult:
|
|
89
89
|
graph_module: torch.fx.GraphModule
|
|
90
90
|
inputs: Tuple[Any]
|
|
91
|
-
_runtime_errors: bool
|
|
92
91
|
|
|
93
92
|
@property
|
|
94
93
|
def graph(self) -> torch.fx.Graph:
|
|
@@ -98,6 +97,11 @@ class Culprit:
|
|
|
98
97
|
def graph(self, fx_g: torch.fx.Graph):
|
|
99
98
|
self.graph_module.graph = fx_g
|
|
100
99
|
|
|
100
|
+
|
|
101
|
+
@dataclasses.dataclass
|
|
102
|
+
class Culprit(SearchResult):
|
|
103
|
+
_runtime_errors: bool
|
|
104
|
+
|
|
101
105
|
@property
|
|
102
106
|
def stack_traces(self) -> List[str]:
|
|
103
107
|
stack_traces = set()
|
|
@@ -342,42 +346,42 @@ def _fx_minifier_checker(fx_gm, inputs, runtime_errors=False):
|
|
|
342
346
|
return False
|
|
343
347
|
|
|
344
348
|
|
|
345
|
-
def find_culprits(
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
runtime_errors: bool = False,
|
|
349
|
+
def _search_model(
|
|
350
|
+
predicate_f: Callable[[torch.fx.GraphModule, List[Any]], bool],
|
|
351
|
+
model: Union[torch.export.ExportedProgram, torch.nn.Module],
|
|
352
|
+
export_args: Tuple[Any] = None,
|
|
350
353
|
*,
|
|
354
|
+
max_granularity: Optional[int] = None,
|
|
351
355
|
enable_fx_minifier_logging: bool = False,
|
|
352
|
-
) -> Generator[
|
|
353
|
-
"""Finds culprits in the AI Edge Torch model conversion.
|
|
356
|
+
) -> Generator[SearchResult, None, None]:
|
|
357
|
+
"""Finds subgraphs in the torch model that satisfy a certain predicate function provided by the users.
|
|
354
358
|
|
|
355
359
|
Args:
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
360
|
+
predicate_f: a predicate function the users specify.
|
|
361
|
+
It takes a FX (sub)graph and the inputs to this graph,
|
|
362
|
+
return True if the graph satisfies the predicate,
|
|
363
|
+
return False otherwise.
|
|
364
|
+
model: model in which to search subgraph.
|
|
365
|
+
export_args: A set of args to trace the model with,
|
|
366
|
+
i.e. model(*args) must run.
|
|
359
367
|
max_granularity - FX minifier arg. The maximum granularity (number of nodes)
|
|
360
368
|
in the returned ATen FX subgraph of the culprit.
|
|
361
|
-
|
|
362
|
-
with converted model.
|
|
363
|
-
enable_fx_minifier_logging: If true, allows the underlying FX minifier to log
|
|
364
|
-
the progress.
|
|
369
|
+
enable_fx_minifier_logging: If true, allows the underlying FX minifier to log the progress.
|
|
365
370
|
"""
|
|
366
371
|
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
372
|
+
if isinstance(model, torch.nn.Module):
|
|
373
|
+
try:
|
|
374
|
+
ep = torch.export.export(model, export_args)
|
|
375
|
+
except Exception as err:
|
|
376
|
+
raise ValueError(
|
|
377
|
+
"Your model is not exportable by torch.export.export. Please modify your model to be torch-exportable first."
|
|
378
|
+
) from err
|
|
379
|
+
else:
|
|
380
|
+
ep = model
|
|
373
381
|
|
|
374
382
|
fx_gm, fx_inputs = utils.exported_program_to_fx_graph_module_and_inputs(ep)
|
|
375
383
|
fx_gm = _normalize_getitem_nodes(fx_gm)
|
|
376
384
|
|
|
377
|
-
fx_minifier_checker = functools.partial(
|
|
378
|
-
_fx_minifier_checker, runtime_errors=runtime_errors
|
|
379
|
-
)
|
|
380
|
-
|
|
381
385
|
# HACK: temporarily disable XLA_HLO_DEBUG so that fx_minifier won't dump
|
|
382
386
|
# intermediate stablehlo files to storage.
|
|
383
387
|
# https://github.com/pytorch/pytorch/blob/main/torch/_functorch/fx_minifier.py#L440
|
|
@@ -405,13 +409,13 @@ def find_culprits(
|
|
|
405
409
|
raw_min_fx_gm, raw_min_inputs = fx_minifier(
|
|
406
410
|
fx_gm,
|
|
407
411
|
fx_inputs,
|
|
408
|
-
|
|
412
|
+
predicate_f,
|
|
409
413
|
max_granularity=max_granularity,
|
|
410
414
|
)
|
|
411
415
|
|
|
412
416
|
min_fx_gm, min_inputs = _normalize_minified_fx_gm(raw_min_fx_gm, raw_min_inputs)
|
|
413
417
|
found_culprits_num += 1
|
|
414
|
-
yield
|
|
418
|
+
yield SearchResult(min_fx_gm, min_inputs)
|
|
415
419
|
|
|
416
420
|
fx_gm, fx_inputs = _erase_sub_gm_from_gm(
|
|
417
421
|
fx_gm, fx_inputs, raw_min_fx_gm, raw_min_inputs
|
|
@@ -421,3 +425,40 @@ def find_culprits(
|
|
|
421
425
|
if str(e) == "Input graph did not fail the tester" and found_culprits_num > 0:
|
|
422
426
|
break
|
|
423
427
|
raise e
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def find_culprits(
|
|
431
|
+
torch_model: torch.nn.Module,
|
|
432
|
+
args: Tuple[Any],
|
|
433
|
+
max_granularity: Optional[int] = None,
|
|
434
|
+
runtime_errors: bool = False,
|
|
435
|
+
*,
|
|
436
|
+
enable_fx_minifier_logging: bool = False,
|
|
437
|
+
) -> Generator[Culprit, None, None]:
|
|
438
|
+
"""Finds culprits in the AI Edge Torch model conversion.
|
|
439
|
+
|
|
440
|
+
Args:
|
|
441
|
+
torch_model: model to export and save
|
|
442
|
+
args: A set of args to trace the model with, i.e.
|
|
443
|
+
torch_model(*args) must run
|
|
444
|
+
max_granularity - FX minifier arg. The maximum granularity (number of nodes)
|
|
445
|
+
in the returned ATen FX subgraph of the culprit.
|
|
446
|
+
runtime_errors: If true, find culprits for Python runtime errors
|
|
447
|
+
with converted model.
|
|
448
|
+
enable_fx_minifier_logging: If true, allows the underlying FX minifier to log the progress.
|
|
449
|
+
"""
|
|
450
|
+
|
|
451
|
+
fx_minifier_checker = functools.partial(
|
|
452
|
+
_fx_minifier_checker, runtime_errors=runtime_errors
|
|
453
|
+
)
|
|
454
|
+
|
|
455
|
+
for search_result in _search_model(
|
|
456
|
+
fx_minifier_checker,
|
|
457
|
+
torch_model,
|
|
458
|
+
args,
|
|
459
|
+
max_granularity=max_granularity,
|
|
460
|
+
enable_fx_minifier_logging=enable_fx_minifier_logging,
|
|
461
|
+
):
|
|
462
|
+
yield Culprit(
|
|
463
|
+
search_result.graph_module, search_result.inputs, _runtime_errors=runtime_errors
|
|
464
|
+
)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
import unittest
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
from ai_edge_torch.debug import _search_model
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TestSearchModel(unittest.TestCase):
|
|
25
|
+
|
|
26
|
+
def test_search_model_with_ops(self):
|
|
27
|
+
class MultipleOpsModel(torch.nn.Module):
|
|
28
|
+
|
|
29
|
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
30
|
+
sub_0 = x - 1
|
|
31
|
+
add_0 = y + 1
|
|
32
|
+
mul_0 = x * y
|
|
33
|
+
add_1 = sub_0 + add_0
|
|
34
|
+
mul_1 = add_0 * mul_0
|
|
35
|
+
sub_1 = add_1 - mul_1
|
|
36
|
+
return sub_1
|
|
37
|
+
|
|
38
|
+
model = MultipleOpsModel().eval()
|
|
39
|
+
args = (torch.rand(10), torch.rand(10))
|
|
40
|
+
|
|
41
|
+
def find_subgraph_with_sub(fx_gm, inputs):
|
|
42
|
+
return torch.ops.aten.sub.Tensor in [n.target for n in fx_gm.graph.nodes]
|
|
43
|
+
|
|
44
|
+
results = list(_search_model(find_subgraph_with_sub, model, args))
|
|
45
|
+
self.assertEqual(len(results), 2)
|
|
46
|
+
self.assertIn(torch.ops.aten.sub.Tensor, [n.target for n in results[0].graph.nodes])
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
if __name__ == "__main__":
|
|
50
|
+
unittest.main()
|
|
@@ -21,11 +21,11 @@ import torch
|
|
|
21
21
|
import ai_edge_torch
|
|
22
22
|
import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
|
|
23
23
|
import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
|
|
24
|
-
|
|
24
|
+
import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
|
|
25
25
|
from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
|
|
26
26
|
import ai_edge_torch.generative.examples.stable_diffusion.util as util
|
|
27
|
-
import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
|
|
28
27
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
|
28
|
+
import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
@torch.inference_mode
|
|
@@ -45,11 +45,14 @@ def convert_stable_diffusion_to_tflite(
|
|
|
45
45
|
encoder = Encoder()
|
|
46
46
|
encoder.load_state_dict(torch.load(encoder_ckpt_path))
|
|
47
47
|
|
|
48
|
-
|
|
49
|
-
|
|
48
|
+
diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
|
|
49
|
+
diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
|
|
50
|
+
diffusion_ckpt_path, diffusion.TENSORS_NAMES
|
|
51
|
+
)
|
|
52
|
+
diffusion_loader.load(diffusion_model)
|
|
50
53
|
|
|
51
54
|
decoder_model = decoder.Decoder(decoder.get_model_config())
|
|
52
|
-
decoder_loader = autoencoder_loader.AutoEncoderModelLoader(
|
|
55
|
+
decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
|
|
53
56
|
decoder_ckpt_path, decoder.TENSORS_NAMES
|
|
54
57
|
)
|
|
55
58
|
decoder_loader.load(decoder_model)
|
|
@@ -84,7 +87,7 @@ def convert_stable_diffusion_to_tflite(
|
|
|
84
87
|
# Diffusion
|
|
85
88
|
ai_edge_torch.signature(
|
|
86
89
|
'diffusion',
|
|
87
|
-
|
|
90
|
+
diffusion_model,
|
|
88
91
|
(torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
|
|
89
92
|
).convert().export('/tmp/stable_diffusion/diffusion.tflite')
|
|
90
93
|
|
|
@@ -20,20 +20,20 @@ import ai_edge_torch.generative.layers.builder as layers_builder
|
|
|
20
20
|
import ai_edge_torch.generative.layers.model_config as layers_cfg
|
|
21
21
|
import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
|
|
22
22
|
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
|
|
23
|
-
import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
|
|
23
|
+
import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
|
|
24
24
|
|
|
25
|
-
TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
|
|
25
|
+
TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
|
|
26
26
|
post_quant_conv="0",
|
|
27
27
|
conv_in="1",
|
|
28
|
-
mid_block_tensor_names=
|
|
28
|
+
mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames(
|
|
29
29
|
residual_block_tensor_names=[
|
|
30
|
-
|
|
30
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
31
31
|
norm_1="2.groupnorm_1",
|
|
32
32
|
norm_2="2.groupnorm_2",
|
|
33
33
|
conv_1="2.conv_1",
|
|
34
34
|
conv_2="2.conv_2",
|
|
35
35
|
),
|
|
36
|
-
|
|
36
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
37
37
|
norm_1="4.groupnorm_1",
|
|
38
38
|
norm_2="4.groupnorm_2",
|
|
39
39
|
conv_1="4.conv_1",
|
|
@@ -41,7 +41,7 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
|
|
|
41
41
|
),
|
|
42
42
|
],
|
|
43
43
|
attention_block_tensor_names=[
|
|
44
|
-
|
|
44
|
+
stable_diffusion_loader.AttentionBlockTensorNames(
|
|
45
45
|
norm="3.groupnorm",
|
|
46
46
|
fused_qkv_proj="3.attention.in_proj",
|
|
47
47
|
output_proj="3.attention.out_proj",
|
|
@@ -49,21 +49,21 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
|
|
|
49
49
|
],
|
|
50
50
|
),
|
|
51
51
|
up_decoder_blocks_tensor_names=[
|
|
52
|
-
|
|
52
|
+
stable_diffusion_loader.UpDecoderBlockTensorNames(
|
|
53
53
|
residual_block_tensor_names=[
|
|
54
|
-
|
|
54
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
55
55
|
norm_1="5.groupnorm_1",
|
|
56
56
|
norm_2="5.groupnorm_2",
|
|
57
57
|
conv_1="5.conv_1",
|
|
58
58
|
conv_2="5.conv_2",
|
|
59
59
|
),
|
|
60
|
-
|
|
60
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
61
61
|
norm_1="6.groupnorm_1",
|
|
62
62
|
norm_2="6.groupnorm_2",
|
|
63
63
|
conv_1="6.conv_1",
|
|
64
64
|
conv_2="6.conv_2",
|
|
65
65
|
),
|
|
66
|
-
|
|
66
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
67
67
|
norm_1="7.groupnorm_1",
|
|
68
68
|
norm_2="7.groupnorm_2",
|
|
69
69
|
conv_1="7.conv_1",
|
|
@@ -72,21 +72,21 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
|
|
|
72
72
|
],
|
|
73
73
|
upsample_conv="9",
|
|
74
74
|
),
|
|
75
|
-
|
|
75
|
+
stable_diffusion_loader.UpDecoderBlockTensorNames(
|
|
76
76
|
residual_block_tensor_names=[
|
|
77
|
-
|
|
77
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
78
78
|
norm_1="10.groupnorm_1",
|
|
79
79
|
norm_2="10.groupnorm_2",
|
|
80
80
|
conv_1="10.conv_1",
|
|
81
81
|
conv_2="10.conv_2",
|
|
82
82
|
),
|
|
83
|
-
|
|
83
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
84
84
|
norm_1="11.groupnorm_1",
|
|
85
85
|
norm_2="11.groupnorm_2",
|
|
86
86
|
conv_1="11.conv_1",
|
|
87
87
|
conv_2="11.conv_2",
|
|
88
88
|
),
|
|
89
|
-
|
|
89
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
90
90
|
norm_1="12.groupnorm_1",
|
|
91
91
|
norm_2="12.groupnorm_2",
|
|
92
92
|
conv_1="12.conv_1",
|
|
@@ -95,22 +95,22 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
|
|
|
95
95
|
],
|
|
96
96
|
upsample_conv="14",
|
|
97
97
|
),
|
|
98
|
-
|
|
98
|
+
stable_diffusion_loader.UpDecoderBlockTensorNames(
|
|
99
99
|
residual_block_tensor_names=[
|
|
100
|
-
|
|
100
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
101
101
|
norm_1="15.groupnorm_1",
|
|
102
102
|
norm_2="15.groupnorm_2",
|
|
103
103
|
conv_1="15.conv_1",
|
|
104
104
|
conv_2="15.conv_2",
|
|
105
105
|
residual_layer="15.residual_layer",
|
|
106
106
|
),
|
|
107
|
-
|
|
107
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
108
108
|
norm_1="16.groupnorm_1",
|
|
109
109
|
norm_2="16.groupnorm_2",
|
|
110
110
|
conv_1="16.conv_1",
|
|
111
111
|
conv_2="16.conv_2",
|
|
112
112
|
),
|
|
113
|
-
|
|
113
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
114
114
|
norm_1="17.groupnorm_1",
|
|
115
115
|
norm_2="17.groupnorm_2",
|
|
116
116
|
conv_1="17.conv_1",
|
|
@@ -119,22 +119,22 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
|
|
|
119
119
|
],
|
|
120
120
|
upsample_conv="19",
|
|
121
121
|
),
|
|
122
|
-
|
|
122
|
+
stable_diffusion_loader.UpDecoderBlockTensorNames(
|
|
123
123
|
residual_block_tensor_names=[
|
|
124
|
-
|
|
124
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
125
125
|
norm_1="20.groupnorm_1",
|
|
126
126
|
norm_2="20.groupnorm_2",
|
|
127
127
|
conv_1="20.conv_1",
|
|
128
128
|
conv_2="20.conv_2",
|
|
129
129
|
residual_layer="20.residual_layer",
|
|
130
130
|
),
|
|
131
|
-
|
|
131
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
132
132
|
norm_1="21.groupnorm_1",
|
|
133
133
|
norm_2="21.groupnorm_2",
|
|
134
134
|
conv_1="21.conv_1",
|
|
135
135
|
conv_2="21.conv_2",
|
|
136
136
|
),
|
|
137
|
-
|
|
137
|
+
stable_diffusion_loader.ResidualBlockTensorNames(
|
|
138
138
|
norm_1="22.groupnorm_1",
|
|
139
139
|
norm_2="22.groupnorm_2",
|
|
140
140
|
conv_1="22.conv_1",
|
|
@@ -225,8 +225,8 @@ class Decoder(nn.Module):
|
|
|
225
225
|
num_layers=config.layers_per_block,
|
|
226
226
|
add_upsample=not_final_block,
|
|
227
227
|
upsample_conv=True,
|
|
228
|
-
sampling_config=unet_cfg.
|
|
229
|
-
|
|
228
|
+
sampling_config=unet_cfg.UpSamplingConfig(
|
|
229
|
+
mode=unet_cfg.SamplingType.NEAREST, scale_factor=2
|
|
230
230
|
),
|
|
231
231
|
)
|
|
232
232
|
)
|
|
@@ -245,6 +245,14 @@ class Decoder(nn.Module):
|
|
|
245
245
|
)
|
|
246
246
|
|
|
247
247
|
def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
|
|
248
|
+
"""Forward function of decoder model.
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
latents_tensor (torch.Tensor): latent space tensor.
|
|
252
|
+
|
|
253
|
+
Returns:
|
|
254
|
+
output decoded image tensor from decoder model.
|
|
255
|
+
"""
|
|
248
256
|
x = latents_tensor / self.config.scaling_factor
|
|
249
257
|
x = self.post_quant_conv(x)
|
|
250
258
|
x = self.conv_in(x)
|
|
@@ -271,7 +279,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
|
|
|
271
279
|
)
|
|
272
280
|
|
|
273
281
|
att_config = unet_cfg.AttentionBlock2DConfig(
|
|
274
|
-
|
|
282
|
+
dim=block_out_channels[-1],
|
|
275
283
|
normalization_config=norm_config,
|
|
276
284
|
attention_config=layers_cfg.AttentionConfig(
|
|
277
285
|
num_heads=1,
|