ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240619__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (24) hide show
  1. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
  2. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
  3. ai_edge_torch/debug/__init__.py +1 -0
  4. ai_edge_torch/debug/culprit.py +70 -29
  5. ai_edge_torch/debug/test/test_search_model.py +50 -0
  6. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
  7. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
  8. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
  9. ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
  10. ai_edge_torch/generative/layers/attention.py +154 -26
  11. ai_edge_torch/generative/layers/model_config.py +3 -0
  12. ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
  13. ai_edge_torch/generative/layers/unet/builder.py +20 -2
  14. ai_edge_torch/generative/layers/unet/model_config.py +157 -5
  15. ai_edge_torch/generative/test/test_model_conversion.py +24 -0
  16. ai_edge_torch/generative/test/test_quantize.py +1 -0
  17. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
  18. ai_edge_torch/generative/utilities/t5_loader.py +33 -17
  19. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/RECORD +23 -22
  21. ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
  22. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240619.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,25 @@ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layo
25
25
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
26
26
 
27
27
 
28
+ def can_partition(graph_module: torch.fx.GraphModule):
29
+ """Returns true if the input graph_module can be partitioned by min cut solver
30
+ in a reasonable time.
31
+
32
+ The min cut solver implements O(|V|^2|E|) Dinic's algorithm, which may
33
+ take a long time to complete for a large graph module. This function determines
34
+ whether the graph module can be partitioned by the graph module size.
35
+
36
+ See go/pytorch-layout-transpose-optimization for more details.
37
+ """
38
+ graph = graph_module.graph
39
+ n_nodes = len(graph.nodes)
40
+ n_edges = sum(len(n.users) for n in graph.nodes)
41
+
42
+ # According to experiments on our model set, |V| < 2000 can
43
+ # be partitioned generally in a reasonable time.
44
+ return n_nodes**2 * n_edges < 2000**3
45
+
46
+
28
47
  class MinCutSolver:
29
48
  # A number that is large enough but can fit into int32 with all computations
30
49
  # in the maximum flow.
@@ -261,10 +261,17 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
261
261
  self.mark_const_nodes(exported_program)
262
262
 
263
263
  graph_module = exported_program.graph_module
264
- if os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_USE_MINCUT_PARTITIONER"):
264
+ partitioner = os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None)
265
+ if partitioner == "MINCUT":
265
266
  graph_module = layout_partitioners.min_cut.partition(graph_module)
266
- else:
267
+ elif partitioner == "GREEDY":
267
268
  graph_module = layout_partitioners.greedy.partition(graph_module)
269
+ else:
270
+ # By default use min cut partitioner if possible
271
+ if layout_partitioners.min_cut.can_partition(graph_module):
272
+ graph_module = layout_partitioners.min_cut.partition(graph_module)
273
+ else:
274
+ graph_module = layout_partitioners.greedy.partition(graph_module)
268
275
 
269
276
  graph = graph_module.graph
270
277
  for node in list(graph.nodes):
@@ -13,4 +13,5 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from .culprit import _search_model
16
17
  from .culprit import find_culprits
@@ -21,7 +21,7 @@ import io
21
21
  import operator
22
22
  import os
23
23
  import sys
24
- from typing import Any, Generator, List, Optional, Tuple
24
+ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
25
25
 
26
26
  from functorch.compile import minifier as fx_minifier
27
27
  import torch
@@ -85,10 +85,9 @@ def _tensor_to_buffer(t: torch.Tensor):
85
85
 
86
86
 
87
87
  @dataclasses.dataclass
88
- class Culprit:
88
+ class SearchResult:
89
89
  graph_module: torch.fx.GraphModule
90
90
  inputs: Tuple[Any]
91
- _runtime_errors: bool
92
91
 
93
92
  @property
94
93
  def graph(self) -> torch.fx.Graph:
@@ -98,6 +97,11 @@ class Culprit:
98
97
  def graph(self, fx_g: torch.fx.Graph):
99
98
  self.graph_module.graph = fx_g
100
99
 
100
+
101
+ @dataclasses.dataclass
102
+ class Culprit(SearchResult):
103
+ _runtime_errors: bool
104
+
101
105
  @property
102
106
  def stack_traces(self) -> List[str]:
103
107
  stack_traces = set()
@@ -342,42 +346,42 @@ def _fx_minifier_checker(fx_gm, inputs, runtime_errors=False):
342
346
  return False
343
347
 
344
348
 
345
- def find_culprits(
346
- torch_model: torch.nn.Module,
347
- args: Tuple[Any],
348
- max_granularity: Optional[int] = None,
349
- runtime_errors: bool = False,
349
+ def _search_model(
350
+ predicate_f: Callable[[torch.fx.GraphModule, List[Any]], bool],
351
+ model: Union[torch.export.ExportedProgram, torch.nn.Module],
352
+ export_args: Tuple[Any] = None,
350
353
  *,
354
+ max_granularity: Optional[int] = None,
351
355
  enable_fx_minifier_logging: bool = False,
352
- ) -> Generator[Culprit, None, None]:
353
- """Finds culprits in the AI Edge Torch model conversion.
356
+ ) -> Generator[SearchResult, None, None]:
357
+ """Finds subgraphs in the torch model that satisfy a certain predicate function provided by the users.
354
358
 
355
359
  Args:
356
- torch_model: model to export and save
357
- args: A set of args to trace the model with, i.e.
358
- torch_model(*args) must run
360
+ predicate_f: a predicate function the users specify.
361
+ It takes a FX (sub)graph and the inputs to this graph,
362
+ return True if the graph satisfies the predicate,
363
+ return False otherwise.
364
+ model: model in which to search subgraph.
365
+ export_args: A set of args to trace the model with,
366
+ i.e. model(*args) must run.
359
367
  max_granularity - FX minifier arg. The maximum granularity (number of nodes)
360
368
  in the returned ATen FX subgraph of the culprit.
361
- runtime_errors: If true, find culprits for Python runtime errors
362
- with converted model.
363
- enable_fx_minifier_logging: If true, allows the underlying FX minifier to log
364
- the progress.
369
+ enable_fx_minifier_logging: If true, allows the underlying FX minifier to log the progress.
365
370
  """
366
371
 
367
- try:
368
- ep = torch.export.export(torch_model, args)
369
- except Exception as err:
370
- raise ValueError(
371
- "Your model is not exportable by torch.export.export. Please modify your model to be torch-exportable first."
372
- ) from err
372
+ if isinstance(model, torch.nn.Module):
373
+ try:
374
+ ep = torch.export.export(model, export_args)
375
+ except Exception as err:
376
+ raise ValueError(
377
+ "Your model is not exportable by torch.export.export. Please modify your model to be torch-exportable first."
378
+ ) from err
379
+ else:
380
+ ep = model
373
381
 
374
382
  fx_gm, fx_inputs = utils.exported_program_to_fx_graph_module_and_inputs(ep)
375
383
  fx_gm = _normalize_getitem_nodes(fx_gm)
376
384
 
377
- fx_minifier_checker = functools.partial(
378
- _fx_minifier_checker, runtime_errors=runtime_errors
379
- )
380
-
381
385
  # HACK: temporarily disable XLA_HLO_DEBUG so that fx_minifier won't dump
382
386
  # intermediate stablehlo files to storage.
383
387
  # https://github.com/pytorch/pytorch/blob/main/torch/_functorch/fx_minifier.py#L440
@@ -405,13 +409,13 @@ def find_culprits(
405
409
  raw_min_fx_gm, raw_min_inputs = fx_minifier(
406
410
  fx_gm,
407
411
  fx_inputs,
408
- fx_minifier_checker,
412
+ predicate_f,
409
413
  max_granularity=max_granularity,
410
414
  )
411
415
 
412
416
  min_fx_gm, min_inputs = _normalize_minified_fx_gm(raw_min_fx_gm, raw_min_inputs)
413
417
  found_culprits_num += 1
414
- yield Culprit(min_fx_gm, min_inputs, _runtime_errors=runtime_errors)
418
+ yield SearchResult(min_fx_gm, min_inputs)
415
419
 
416
420
  fx_gm, fx_inputs = _erase_sub_gm_from_gm(
417
421
  fx_gm, fx_inputs, raw_min_fx_gm, raw_min_inputs
@@ -421,3 +425,40 @@ def find_culprits(
421
425
  if str(e) == "Input graph did not fail the tester" and found_culprits_num > 0:
422
426
  break
423
427
  raise e
428
+
429
+
430
+ def find_culprits(
431
+ torch_model: torch.nn.Module,
432
+ args: Tuple[Any],
433
+ max_granularity: Optional[int] = None,
434
+ runtime_errors: bool = False,
435
+ *,
436
+ enable_fx_minifier_logging: bool = False,
437
+ ) -> Generator[Culprit, None, None]:
438
+ """Finds culprits in the AI Edge Torch model conversion.
439
+
440
+ Args:
441
+ torch_model: model to export and save
442
+ args: A set of args to trace the model with, i.e.
443
+ torch_model(*args) must run
444
+ max_granularity - FX minifier arg. The maximum granularity (number of nodes)
445
+ in the returned ATen FX subgraph of the culprit.
446
+ runtime_errors: If true, find culprits for Python runtime errors
447
+ with converted model.
448
+ enable_fx_minifier_logging: If true, allows the underlying FX minifier to log the progress.
449
+ """
450
+
451
+ fx_minifier_checker = functools.partial(
452
+ _fx_minifier_checker, runtime_errors=runtime_errors
453
+ )
454
+
455
+ for search_result in _search_model(
456
+ fx_minifier_checker,
457
+ torch_model,
458
+ args,
459
+ max_granularity=max_granularity,
460
+ enable_fx_minifier_logging=enable_fx_minifier_logging,
461
+ ):
462
+ yield Culprit(
463
+ search_result.graph_module, search_result.inputs, _runtime_errors=runtime_errors
464
+ )
@@ -0,0 +1,50 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import unittest
18
+
19
+ import torch
20
+
21
+ from ai_edge_torch.debug import _search_model
22
+
23
+
24
+ class TestSearchModel(unittest.TestCase):
25
+
26
+ def test_search_model_with_ops(self):
27
+ class MultipleOpsModel(torch.nn.Module):
28
+
29
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
30
+ sub_0 = x - 1
31
+ add_0 = y + 1
32
+ mul_0 = x * y
33
+ add_1 = sub_0 + add_0
34
+ mul_1 = add_0 * mul_0
35
+ sub_1 = add_1 - mul_1
36
+ return sub_1
37
+
38
+ model = MultipleOpsModel().eval()
39
+ args = (torch.rand(10), torch.rand(10))
40
+
41
+ def find_subgraph_with_sub(fx_gm, inputs):
42
+ return torch.ops.aten.sub.Tensor in [n.target for n in fx_gm.graph.nodes]
43
+
44
+ results = list(_search_model(find_subgraph_with_sub, model, args))
45
+ self.assertEqual(len(results), 2)
46
+ self.assertIn(torch.ops.aten.sub.Tensor, [n.target for n in results[0].graph.nodes])
47
+
48
+
49
+ if __name__ == "__main__":
50
+ unittest.main()
@@ -21,11 +21,11 @@ import torch
21
21
  import ai_edge_torch
22
22
  import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
23
23
  import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
24
- from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion # NOQA
24
+ import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
25
25
  from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
26
26
  import ai_edge_torch.generative.examples.stable_diffusion.util as util
27
- import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
28
27
  import ai_edge_torch.generative.utilities.loader as loading_utils
28
+ import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
29
29
 
30
30
 
31
31
  @torch.inference_mode
@@ -45,11 +45,14 @@ def convert_stable_diffusion_to_tflite(
45
45
  encoder = Encoder()
46
46
  encoder.load_state_dict(torch.load(encoder_ckpt_path))
47
47
 
48
- diffusion = Diffusion()
49
- diffusion.load_state_dict(torch.load(diffusion_ckpt_path))
48
+ diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
49
+ diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
50
+ diffusion_ckpt_path, diffusion.TENSORS_NAMES
51
+ )
52
+ diffusion_loader.load(diffusion_model)
50
53
 
51
54
  decoder_model = decoder.Decoder(decoder.get_model_config())
52
- decoder_loader = autoencoder_loader.AutoEncoderModelLoader(
55
+ decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
53
56
  decoder_ckpt_path, decoder.TENSORS_NAMES
54
57
  )
55
58
  decoder_loader.load(decoder_model)
@@ -84,7 +87,7 @@ def convert_stable_diffusion_to_tflite(
84
87
  # Diffusion
85
88
  ai_edge_torch.signature(
86
89
  'diffusion',
87
- diffusion,
90
+ diffusion_model,
88
91
  (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
89
92
  ).convert().export('/tmp/stable_diffusion/diffusion.tflite')
90
93
 
@@ -20,20 +20,20 @@ import ai_edge_torch.generative.layers.builder as layers_builder
20
20
  import ai_edge_torch.generative.layers.model_config as layers_cfg
21
21
  import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
22
22
  import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
23
- import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader
23
+ import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
24
24
 
25
- TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
25
+ TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
26
26
  post_quant_conv="0",
27
27
  conv_in="1",
28
- mid_block_tensor_names=autoencoder_loader.MidBlockTensorNames(
28
+ mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames(
29
29
  residual_block_tensor_names=[
30
- autoencoder_loader.ResidualBlockTensorNames(
30
+ stable_diffusion_loader.ResidualBlockTensorNames(
31
31
  norm_1="2.groupnorm_1",
32
32
  norm_2="2.groupnorm_2",
33
33
  conv_1="2.conv_1",
34
34
  conv_2="2.conv_2",
35
35
  ),
36
- autoencoder_loader.ResidualBlockTensorNames(
36
+ stable_diffusion_loader.ResidualBlockTensorNames(
37
37
  norm_1="4.groupnorm_1",
38
38
  norm_2="4.groupnorm_2",
39
39
  conv_1="4.conv_1",
@@ -41,7 +41,7 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
41
41
  ),
42
42
  ],
43
43
  attention_block_tensor_names=[
44
- autoencoder_loader.AttnetionBlockTensorNames(
44
+ stable_diffusion_loader.AttentionBlockTensorNames(
45
45
  norm="3.groupnorm",
46
46
  fused_qkv_proj="3.attention.in_proj",
47
47
  output_proj="3.attention.out_proj",
@@ -49,21 +49,21 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
49
49
  ],
50
50
  ),
51
51
  up_decoder_blocks_tensor_names=[
52
- autoencoder_loader.UpDecoderBlockTensorNames(
52
+ stable_diffusion_loader.UpDecoderBlockTensorNames(
53
53
  residual_block_tensor_names=[
54
- autoencoder_loader.ResidualBlockTensorNames(
54
+ stable_diffusion_loader.ResidualBlockTensorNames(
55
55
  norm_1="5.groupnorm_1",
56
56
  norm_2="5.groupnorm_2",
57
57
  conv_1="5.conv_1",
58
58
  conv_2="5.conv_2",
59
59
  ),
60
- autoencoder_loader.ResidualBlockTensorNames(
60
+ stable_diffusion_loader.ResidualBlockTensorNames(
61
61
  norm_1="6.groupnorm_1",
62
62
  norm_2="6.groupnorm_2",
63
63
  conv_1="6.conv_1",
64
64
  conv_2="6.conv_2",
65
65
  ),
66
- autoencoder_loader.ResidualBlockTensorNames(
66
+ stable_diffusion_loader.ResidualBlockTensorNames(
67
67
  norm_1="7.groupnorm_1",
68
68
  norm_2="7.groupnorm_2",
69
69
  conv_1="7.conv_1",
@@ -72,21 +72,21 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
72
72
  ],
73
73
  upsample_conv="9",
74
74
  ),
75
- autoencoder_loader.UpDecoderBlockTensorNames(
75
+ stable_diffusion_loader.UpDecoderBlockTensorNames(
76
76
  residual_block_tensor_names=[
77
- autoencoder_loader.ResidualBlockTensorNames(
77
+ stable_diffusion_loader.ResidualBlockTensorNames(
78
78
  norm_1="10.groupnorm_1",
79
79
  norm_2="10.groupnorm_2",
80
80
  conv_1="10.conv_1",
81
81
  conv_2="10.conv_2",
82
82
  ),
83
- autoencoder_loader.ResidualBlockTensorNames(
83
+ stable_diffusion_loader.ResidualBlockTensorNames(
84
84
  norm_1="11.groupnorm_1",
85
85
  norm_2="11.groupnorm_2",
86
86
  conv_1="11.conv_1",
87
87
  conv_2="11.conv_2",
88
88
  ),
89
- autoencoder_loader.ResidualBlockTensorNames(
89
+ stable_diffusion_loader.ResidualBlockTensorNames(
90
90
  norm_1="12.groupnorm_1",
91
91
  norm_2="12.groupnorm_2",
92
92
  conv_1="12.conv_1",
@@ -95,22 +95,22 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
95
95
  ],
96
96
  upsample_conv="14",
97
97
  ),
98
- autoencoder_loader.UpDecoderBlockTensorNames(
98
+ stable_diffusion_loader.UpDecoderBlockTensorNames(
99
99
  residual_block_tensor_names=[
100
- autoencoder_loader.ResidualBlockTensorNames(
100
+ stable_diffusion_loader.ResidualBlockTensorNames(
101
101
  norm_1="15.groupnorm_1",
102
102
  norm_2="15.groupnorm_2",
103
103
  conv_1="15.conv_1",
104
104
  conv_2="15.conv_2",
105
105
  residual_layer="15.residual_layer",
106
106
  ),
107
- autoencoder_loader.ResidualBlockTensorNames(
107
+ stable_diffusion_loader.ResidualBlockTensorNames(
108
108
  norm_1="16.groupnorm_1",
109
109
  norm_2="16.groupnorm_2",
110
110
  conv_1="16.conv_1",
111
111
  conv_2="16.conv_2",
112
112
  ),
113
- autoencoder_loader.ResidualBlockTensorNames(
113
+ stable_diffusion_loader.ResidualBlockTensorNames(
114
114
  norm_1="17.groupnorm_1",
115
115
  norm_2="17.groupnorm_2",
116
116
  conv_1="17.conv_1",
@@ -119,22 +119,22 @@ TENSORS_NAMES = autoencoder_loader.AutoEncoderModelLoader.TensorNames(
119
119
  ],
120
120
  upsample_conv="19",
121
121
  ),
122
- autoencoder_loader.UpDecoderBlockTensorNames(
122
+ stable_diffusion_loader.UpDecoderBlockTensorNames(
123
123
  residual_block_tensor_names=[
124
- autoencoder_loader.ResidualBlockTensorNames(
124
+ stable_diffusion_loader.ResidualBlockTensorNames(
125
125
  norm_1="20.groupnorm_1",
126
126
  norm_2="20.groupnorm_2",
127
127
  conv_1="20.conv_1",
128
128
  conv_2="20.conv_2",
129
129
  residual_layer="20.residual_layer",
130
130
  ),
131
- autoencoder_loader.ResidualBlockTensorNames(
131
+ stable_diffusion_loader.ResidualBlockTensorNames(
132
132
  norm_1="21.groupnorm_1",
133
133
  norm_2="21.groupnorm_2",
134
134
  conv_1="21.conv_1",
135
135
  conv_2="21.conv_2",
136
136
  ),
137
- autoencoder_loader.ResidualBlockTensorNames(
137
+ stable_diffusion_loader.ResidualBlockTensorNames(
138
138
  norm_1="22.groupnorm_1",
139
139
  norm_2="22.groupnorm_2",
140
140
  conv_1="22.conv_1",
@@ -225,8 +225,8 @@ class Decoder(nn.Module):
225
225
  num_layers=config.layers_per_block,
226
226
  add_upsample=not_final_block,
227
227
  upsample_conv=True,
228
- sampling_config=unet_cfg.SamplingConfig(
229
- 2, unet_cfg.SamplingType.NEAREST
228
+ sampling_config=unet_cfg.UpSamplingConfig(
229
+ mode=unet_cfg.SamplingType.NEAREST, scale_factor=2
230
230
  ),
231
231
  )
232
232
  )
@@ -245,6 +245,14 @@ class Decoder(nn.Module):
245
245
  )
246
246
 
247
247
  def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
248
+ """Forward function of decoder model.
249
+
250
+ Args:
251
+ latents_tensor (torch.Tensor): latent space tensor.
252
+
253
+ Returns:
254
+ output decoded image tensor from decoder model.
255
+ """
248
256
  x = latents_tensor / self.config.scaling_factor
249
257
  x = self.post_quant_conv(x)
250
258
  x = self.conv_in(x)
@@ -271,7 +279,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
271
279
  )
272
280
 
273
281
  att_config = unet_cfg.AttentionBlock2DConfig(
274
- dims=block_out_channels[-1],
282
+ dim=block_out_channels[-1],
275
283
  normalization_config=norm_config,
276
284
  attention_config=layers_cfg.AttentionConfig(
277
285
  num_heads=1,