cache-dit: cache_dit-0.2.32-py3-none-any.whl → cache_dit-0.2.34-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +14 -20
- cache_dit/cache_factory/block_adapters/block_adapters.py +47 -3
- cache_dit/cache_factory/block_adapters/block_registers.py +3 -2
- cache_dit/cache_factory/cache_adapters.py +8 -8
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +23 -62
- cache_dit/cache_factory/cache_blocks/pattern_base.py +23 -168
- cache_dit/cache_factory/cache_contexts/cache_context.py +18 -64
- cache_dit/cache_factory/cache_contexts/cache_manager.py +23 -71
- cache_dit/cache_factory/cache_contexts/taylorseer.py +11 -13
- cache_dit/cache_factory/cache_interface.py +9 -9
- cache_dit/cache_factory/patch_functors/__init__.py +1 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +142 -52
- cache_dit/cache_factory/patch_functors/functor_dit.py +130 -0
- cache_dit/quantize/quantize_ao.py +3 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/METADATA +184 -39
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/RECORD +21 -21
- cache_dit/quantize/quantize_svdq.py +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/top_level.txt +0 -0
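The headline changes are a reworked ChromaPatchFunctor in functor_chroma.py (per-block modulation slicing now happens inside the patched block forwards) and a new DiTPatchFunctor in functor_dit.py for DiTTransformer2DModel. Both patch functors appear to be applied when caching is enabled; the new DiT functor asserts that cache_dit.enable_cache is called before any parallelization. A minimal usage sketch, assuming the one-argument form of enable_cache (the model id is a placeholder and the exact signature may differ between releases):

import torch
from diffusers import DiffusionPipeline

import cache_dit

pipe = DiffusionPipeline.from_pretrained(
    "some/dit-or-chroma-pipeline",  # placeholder model id
    torch_dtype=torch.bfloat16,
)

# Enable caching (and thereby the patch functors) before any
# parallelization, as required by the assert in DiTPatchFunctor.apply.
cache_dit.enable_cache(pipe)

image = pipe("a prompt").images[0]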
cache_dit/cache_factory/patch_functors/functor_chroma.py

@@ -1,10 +1,9 @@
-import inspect
-
 import torch
 import numpy as np
 from typing import Tuple, Optional, Dict, Any, Union
 from diffusers import ChromaTransformer2DModel
 from diffusers.models.transformers.transformer_chroma import (
+    ChromaTransformerBlock,
     ChromaSingleTransformerBlock,
     Transformer2DModelOutput,
 )
@@ -27,24 +26,31 @@ class ChromaPatchFunctor(PatchFunctor):
     def apply(
         self,
         transformer: ChromaTransformer2DModel,
-        blocks: torch.nn.ModuleList = None,
         **kwargs,
     ) -> ChromaTransformer2DModel:
         if hasattr(transformer, "_is_patched"):
             return transformer

-        if blocks is None:
-            blocks = transformer.single_transformer_blocks
-
         is_patched = False
-        for block in
-
-
-
-
-
-
-
+        for index_block, block in enumerate(transformer.transformer_blocks):
+            assert isinstance(block, ChromaTransformerBlock)
+            img_offset = 3 * len(transformer.single_transformer_blocks)
+            txt_offset = img_offset + 6 * len(transformer.transformer_blocks)
+            img_modulation = img_offset + 6 * index_block
+            text_modulation = txt_offset + 6 * index_block
+            block._img_modulation = img_modulation
+            block._text_modulation = text_modulation
+            block.forward = __patch_double_forward__.__get__(block)
+
+        for index_block, block in enumerate(
+            transformer.single_transformer_blocks
+        ):
+            assert isinstance(block, ChromaSingleTransformerBlock)
+            start_idx = 3 * index_block
+            block._start_idx = start_idx
+            block.forward = __patch_single_forward__.__get__(block)
+
+            is_patched = True

         cls_name = transformer.__class__.__name__

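Both loops above rewire blocks by binding a module-level function as the block's forward through the descriptor protocol: __patch_double_forward__.__get__(block) returns a method bound to that block, so the replacement receives the block as self and can read the per-block attributes (_img_modulation, _text_modulation, _start_idx) stashed just before. A self-contained sketch of the same technique with hypothetical names, not code from this package:

import torch


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale


def __patched_forward__(self: TinyBlock, x: torch.Tensor) -> torch.Tensor:
    # Reads per-instance state stashed by the patcher, mirroring
    # block._img_modulation / block._start_idx above.
    return x * self.scale + self._shift


block = TinyBlock()
block._shift = 0.5                                   # per-instance attribute
block.forward = __patched_forward__.__get__(block)   # bound to this instance only

print(block.forward(torch.ones(2)))  # tensor([1.5000, 1.5000], grad_fn=...)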
@@ -69,25 +75,123 @@ class ChromaPatchFunctor(PatchFunctor):
         return transformer


+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+def __patch_double_forward__(
+    self: ChromaTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    pooled_temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # TODO: Fuse controlnet into block forward
+    img_modulation = self._img_modulation
+    text_modulation = self._text_modulation
+    temb = torch.cat(
+        (
+            pooled_temb[:, img_modulation : img_modulation + 6],
+            pooled_temb[:, text_modulation : text_modulation + 6],
+        ),
+        dim=1,
+    )
+
+    temb_img, temb_txt = temb[:, :6], temb[:, 6:]
+    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+        hidden_states, emb=temb_img
+    )
+
+    (
+        norm_encoder_hidden_states,
+        c_gate_msa,
+        c_shift_mlp,
+        c_scale_mlp,
+        c_gate_mlp,
+    ) = self.norm1_context(encoder_hidden_states, emb=temb_txt)
+    joint_attention_kwargs = joint_attention_kwargs or {}
+    if attention_mask is not None:
+        attention_mask = (
+            attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+        )
+
+    # Attention.
+    attention_outputs = self.attn(
+        hidden_states=norm_hidden_states,
+        encoder_hidden_states=norm_encoder_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+        attention_mask=attention_mask,
+        **joint_attention_kwargs,
+    )
+
+    if len(attention_outputs) == 2:
+        attn_output, context_attn_output = attention_outputs
+    elif len(attention_outputs) == 3:
+        attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+    # Process attention outputs for the `hidden_states`.
+    attn_output = gate_msa.unsqueeze(1) * attn_output
+    hidden_states = hidden_states + attn_output
+
+    norm_hidden_states = self.norm2(hidden_states)
+    norm_hidden_states = (
+        norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+    )
+
+    ff_output = self.ff(norm_hidden_states)
+    ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+    hidden_states = hidden_states + ff_output
+    if len(attention_outputs) == 3:
+        hidden_states = hidden_states + ip_attn_output
+
+    # Process attention outputs for the `encoder_hidden_states`.
+
+    context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+    encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+    norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+    norm_encoder_hidden_states = (
+        norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
+        + c_shift_mlp[:, None]
+    )
+
+    context_ff_output = self.ff_context(norm_encoder_hidden_states)
+    encoder_hidden_states = (
+        encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+    )
+    if encoder_hidden_states.dtype == torch.float16:
+        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+    return encoder_hidden_states, hidden_states
+
+
 # adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
 def __patch_single_forward__(
     self: ChromaSingleTransformerBlock,  # Almost same as FluxSingleTransformerBlock
     hidden_states: torch.Tensor,
-
-    temb: torch.Tensor,
+    pooled_temb: torch.Tensor,
     image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-) ->
-
-
+) -> torch.Tensor:
+    # TODO: Fuse controlnet into block forward
+    start_idx = self._start_idx
+    temb = pooled_temb[:, start_idx : start_idx + 3]

     residual = hidden_states
     norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
     mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
     joint_attention_kwargs = joint_attention_kwargs or {}
+
+    if attention_mask is not None:
+        attention_mask = (
+            attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+        )
+
     attn_output = self.attn(
         hidden_states=norm_hidden_states,
         image_rotary_emb=image_rotary_emb,
+        attention_mask=attention_mask,
         **joint_attention_kwargs,
     )
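In the new layout, each double block slices its own modulation vectors out of pooled_temb instead of receiving a pre-sliced temb from the transformer loop: the first 3 * num_single_blocks rows belong to the single blocks, followed by 6 rows per double block for the image stream and then 6 per double block for the text stream, exactly as the offsets computed in apply() encode. A worked sketch of that indexing with made-up block counts (the real counts depend on the Chroma checkpoint):

# Made-up counts, chosen only to illustrate the offset arithmetic.
num_double_blocks = 4
num_single_blocks = 8

img_offset = 3 * num_single_blocks               # 24
txt_offset = img_offset + 6 * num_double_blocks  # 48

for index_block in range(num_double_blocks):
    img_modulation = img_offset + 6 * index_block
    text_modulation = txt_offset + 6 * index_block
    # Double block i reads pooled_temb[:, img_modulation : img_modulation + 6]
    # and pooled_temb[:, text_modulation : text_modulation + 6].
    print(index_block, (img_modulation, img_modulation + 6),
          (text_modulation, text_modulation + 6))

for index_block in range(num_single_blocks):
    start_idx = 3 * index_block
    # Single block i reads pooled_temb[:, start_idx : start_idx + 3].
    print(index_block, (start_idx, start_idx + 3))

# The final norm_out still takes the last two rows: pooled_temb[:, -2:].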
@@ -98,11 +202,7 @@ def __patch_single_forward__(
     if hidden_states.dtype == torch.float16:
         hidden_states = hidden_states.clip(-65504, 65504)

-
-        hidden_states[:, :text_seq_len],
-        hidden_states[:, text_seq_len:],
-    )
-    return encoder_hidden_states, hidden_states
+    return hidden_states


 # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
@@ -174,24 +274,13 @@ def __patch_transformer_forward__(
         joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

     for index_block, block in enumerate(self.transformer_blocks):
-        img_offset = 3 * len(self.single_transformer_blocks)
-        txt_offset = img_offset + 6 * len(self.transformer_blocks)
-        img_modulation = img_offset + 6 * index_block
-        text_modulation = txt_offset + 6 * index_block
-        temb = torch.cat(
-            (
-                pooled_temb[:, img_modulation : img_modulation + 6],
-                pooled_temb[:, text_modulation : text_modulation + 6],
-            ),
-            dim=1,
-        )
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             encoder_hidden_states, hidden_states = (
                 self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     encoder_hidden_states,
-
+                    pooled_temb,
                     image_rotary_emb,
                     attention_mask,
                 )
@@ -201,12 +290,13 @@ def __patch_transformer_forward__(
             encoder_hidden_states, hidden_states = block(
                 hidden_states=hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-
+                pooled_temb=pooled_temb,
                 image_rotary_emb=image_rotary_emb,
                 attention_mask=attention_mask,
                 joint_attention_kwargs=joint_attention_kwargs,
             )

+        # TODO: Fuse controlnet into block forward
         # controlnet residual
         if controlnet_block_samples is not None:
             interval_control = len(self.transformer_blocks) / len(
@@ -227,43 +317,43 @@ def __patch_transformer_forward__(
                 + controlnet_block_samples[index_block // interval_control]
             )

+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
     for index_block, block in enumerate(self.single_transformer_blocks):
-        start_idx = 3 * index_block
-        temb = pooled_temb[:, start_idx : start_idx + 3]
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-
-
-
-
-
-
-            )
+            hidden_states = self._gradient_checkpointing_func(
+                block,
+                hidden_states,
+                pooled_temb,
+                image_rotary_emb,
+                attention_mask,
+                joint_attention_kwargs,
             )

         else:
-
+            hidden_states = block(
                 hidden_states=hidden_states,
-
-                temb=temb,
+                pooled_temb=pooled_temb,
                 image_rotary_emb=image_rotary_emb,
                 attention_mask=attention_mask,
                 joint_attention_kwargs=joint_attention_kwargs,
             )

+        # TODO: Fuse controlnet into block forward
         # controlnet residual
         if controlnet_single_block_samples is not None:
             interval_control = len(self.single_transformer_blocks) / len(
                 controlnet_single_block_samples
             )
             interval_control = int(np.ceil(interval_control))
-            hidden_states = (
-                hidden_states
+            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                 + controlnet_single_block_samples[
                     index_block // interval_control
                 ]
             )

+    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
     temb = pooled_temb[:, -2:]
     hidden_states = self.norm_out(hidden_states, temb)
     output = self.proj_out(hidden_states)
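With the per-block slicing moved into the patched block forwards, __patch_single_forward__ now returns a single tensor and the transformer loop concatenates text and image tokens once before the single-block loop, then slices the image tokens back out afterwards via hidden_states[:, encoder_hidden_states.shape[1]:, ...]. A shape-only sketch of that concatenate-then-slice pattern with made-up sizes:

import torch

batch, txt_len, img_len, dim = 2, 16, 64, 8  # made-up sizes
encoder_hidden_states = torch.randn(batch, txt_len, dim)
hidden_states = torch.randn(batch, img_len, dim)

# Concatenate once before the single-block loop ...
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
assert hidden_states.shape == (batch, txt_len + img_len, dim)

# ... (each patched single block consumes and returns the joint sequence) ...

# ... then recover the image tokens after the loop.
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
assert hidden_states.shape == (batch, img_len, dim)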
cache_dit/cache_factory/patch_functors/functor_dit.py (new file)

@@ -0,0 +1,130 @@
+import torch
+import torch.nn.functional as F
+
+from typing import Optional, Dict, Any
+from diffusers.models.transformers.dit_transformer_2d import (
+    DiTTransformer2DModel,
+    Transformer2DModelOutput,
+)
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class DiTPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: DiTTransformer2DModel,
+        **kwargs,
+    ) -> DiTTransformer2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        transformer._norm1_emb = transformer.transformer_blocks[0].norm1.emb
+
+        is_patched = True
+
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+def __patch_transformer_forward__(
+    self: DiTTransformer2DModel,
+    hidden_states: torch.Tensor,
+    timestep: Optional[torch.LongTensor] = None,
+    class_labels: Optional[torch.LongTensor] = None,
+    cross_attention_kwargs: Dict[str, Any] = None,
+    return_dict: bool = True,
+):
+    height, width = (
+        hidden_states.shape[-2] // self.patch_size,
+        hidden_states.shape[-1] // self.patch_size,
+    )
+    hidden_states = self.pos_embed(hidden_states)
+
+    # 2. Blocks
+    for block in self.transformer_blocks:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states = self._gradient_checkpointing_func(
+                block,
+                hidden_states,
+                None,
+                None,
+                None,
+                timestep,
+                cross_attention_kwargs,
+                class_labels,
+            )
+        else:
+            hidden_states = block(
+                hidden_states,
+                attention_mask=None,
+                encoder_hidden_states=None,
+                encoder_attention_mask=None,
+                timestep=timestep,
+                cross_attention_kwargs=cross_attention_kwargs,
+                class_labels=class_labels,
+            )
+
+    # 3. Output
+    # conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
+    conditioning = self._norm1_emb(
+        timestep, class_labels, hidden_dtype=hidden_states.dtype
+    )
+    shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+    hidden_states = (
+        self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+    )
+    hidden_states = self.proj_out_2(hidden_states)
+
+    # unpatchify
+    height = width = int(hidden_states.shape[1] ** 0.5)
+    hidden_states = hidden_states.reshape(
+        shape=(
+            -1,
+            height,
+            width,
+            self.patch_size,
+            self.patch_size,
+            self.out_channels,
+        )
+    )
+    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+    output = hidden_states.reshape(
+        shape=(
+            -1,
+            self.out_channels,
+            height * self.patch_size,
+            width * self.patch_size,
+        )
+    )
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
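apply() stashes transformer.transformer_blocks[0].norm1.emb on the transformer as _norm1_emb, so the output conditioning in the patched forward (see the commented-out line above) no longer reaches into the block list at runtime, presumably so cache-dit can wrap or skip blocks without breaking the final projection. The unpatchify step mirrors the upstream DiTTransformer2DModel.forward; a quick shape check of the reshape/einsum round trip with made-up sizes:

import torch

batch, patch, channels = 2, 2, 4  # made-up sizes
height = width = 8                # tokens per spatial side
tokens = torch.randn(batch, height * width, patch * patch * channels)

x = tokens.reshape(-1, height, width, patch, patch, channels)
x = torch.einsum("nhwpqc->nchpwq", x)
output = x.reshape(-1, channels, height * patch, width * patch)

assert output.shape == (batch, channels, height * patch, width * patch)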