cache-dit 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +166 -160
- cache_dit/cache_factory/block_adapters/block_adapters.py +195 -125
- cache_dit/cache_factory/block_adapters/block_registers.py +25 -13
- cache_dit/cache_factory/cache_adapters.py +209 -86
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +70 -67
- cache_dit/cache_factory/cache_blocks/utils.py +16 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +22 -10
- cache_dit/cache_factory/cache_interface.py +26 -14
- cache_dit/cache_factory/cache_types.py +5 -5
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -2
- cache_dit/cache_factory/patch_functors/functor_flux.py +3 -2
- cache_dit/utils.py +168 -55
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/METADATA +34 -55
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/RECORD +21 -21
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ except ImportError:
|
|
|
6
6
|
|
|
7
7
|
from cache_dit.cache_factory import load_options
|
|
8
8
|
from cache_dit.cache_factory import enable_cache
|
|
9
|
+
from cache_dit.cache_factory import disable_cache
|
|
9
10
|
from cache_dit.cache_factory import cache_type
|
|
10
11
|
from cache_dit.cache_factory import block_range
|
|
11
12
|
from cache_dit.cache_factory import CacheType
|
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.2.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 2,
|
|
31
|
+
__version__ = version = '0.2.30'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 2, 30)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -17,6 +17,7 @@ from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
|
17
17
|
from cache_dit.cache_factory.cache_adapters import CachedAdapter
|
|
18
18
|
|
|
19
19
|
from cache_dit.cache_factory.cache_interface import enable_cache
|
|
20
|
+
from cache_dit.cache_factory.cache_interface import disable_cache
|
|
20
21
|
from cache_dit.cache_factory.cache_interface import supported_pipelines
|
|
21
22
|
from cache_dit.cache_factory.cache_interface import get_adapter
|
|
22
23
|
|
|
@@ -9,23 +9,37 @@ from cache_dit.cache_factory.block_adapters.block_registers import (
|
|
|
9
9
|
@BlockAdapterRegistry.register("Flux")
|
|
10
10
|
def flux_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
11
11
|
from diffusers import FluxTransformer2DModel
|
|
12
|
-
from cache_dit.
|
|
12
|
+
from cache_dit.utils import is_diffusers_at_least_0_3_5
|
|
13
13
|
|
|
14
14
|
assert isinstance(pipe.transformer, FluxTransformer2DModel)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
15
|
+
if is_diffusers_at_least_0_3_5():
|
|
16
|
+
return BlockAdapter(
|
|
17
|
+
pipe=pipe,
|
|
18
|
+
transformer=pipe.transformer,
|
|
19
|
+
blocks=[
|
|
20
|
+
pipe.transformer.transformer_blocks,
|
|
21
|
+
pipe.transformer.single_transformer_blocks,
|
|
22
|
+
],
|
|
23
|
+
forward_pattern=[
|
|
24
|
+
ForwardPattern.Pattern_1,
|
|
25
|
+
ForwardPattern.Pattern_1,
|
|
26
|
+
],
|
|
27
|
+
**kwargs,
|
|
28
|
+
)
|
|
29
|
+
else:
|
|
30
|
+
return BlockAdapter(
|
|
31
|
+
pipe=pipe,
|
|
32
|
+
transformer=pipe.transformer,
|
|
33
|
+
blocks=[
|
|
34
|
+
pipe.transformer.transformer_blocks,
|
|
35
|
+
pipe.transformer.single_transformer_blocks,
|
|
36
|
+
],
|
|
37
|
+
forward_pattern=[
|
|
38
|
+
ForwardPattern.Pattern_1,
|
|
39
|
+
ForwardPattern.Pattern_3,
|
|
40
|
+
],
|
|
41
|
+
**kwargs,
|
|
42
|
+
)
|
|
29
43
|
|
|
30
44
|
|
|
31
45
|
@BlockAdapterRegistry.register("Mochi")
|
|
@@ -37,9 +51,8 @@ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
37
51
|
pipe=pipe,
|
|
38
52
|
transformer=pipe.transformer,
|
|
39
53
|
blocks=pipe.transformer.transformer_blocks,
|
|
40
|
-
blocks_name="transformer_blocks",
|
|
41
|
-
dummy_blocks_names=[],
|
|
42
54
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
55
|
+
**kwargs,
|
|
43
56
|
)
|
|
44
57
|
|
|
45
58
|
|
|
@@ -52,9 +65,8 @@ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
52
65
|
pipe=pipe,
|
|
53
66
|
transformer=pipe.transformer,
|
|
54
67
|
blocks=pipe.transformer.transformer_blocks,
|
|
55
|
-
blocks_name="transformer_blocks",
|
|
56
|
-
dummy_blocks_names=[],
|
|
57
68
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
69
|
+
**kwargs,
|
|
58
70
|
)
|
|
59
71
|
|
|
60
72
|
|
|
@@ -85,16 +97,12 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
85
97
|
pipe.transformer.blocks,
|
|
86
98
|
pipe.transformer_2.blocks,
|
|
87
99
|
],
|
|
88
|
-
blocks_name=[
|
|
89
|
-
"blocks",
|
|
90
|
-
"blocks",
|
|
91
|
-
],
|
|
92
100
|
forward_pattern=[
|
|
93
101
|
ForwardPattern.Pattern_2,
|
|
94
102
|
ForwardPattern.Pattern_2,
|
|
95
103
|
],
|
|
96
|
-
dummy_blocks_names=[],
|
|
97
104
|
has_separate_cfg=True,
|
|
105
|
+
**kwargs,
|
|
98
106
|
)
|
|
99
107
|
else:
|
|
100
108
|
# Wan 2.1
|
|
@@ -102,10 +110,9 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
102
110
|
pipe=pipe,
|
|
103
111
|
transformer=pipe.transformer,
|
|
104
112
|
blocks=pipe.transformer.blocks,
|
|
105
|
-
blocks_name="blocks",
|
|
106
|
-
dummy_blocks_names=[],
|
|
107
113
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
108
114
|
has_separate_cfg=True,
|
|
115
|
+
**kwargs,
|
|
109
116
|
)
|
|
110
117
|
|
|
111
118
|
|
|
@@ -116,13 +123,18 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
116
123
|
assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
|
|
117
124
|
return BlockAdapter(
|
|
118
125
|
pipe=pipe,
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
+
transformer=pipe.transformer,
|
|
127
|
+
blocks=[
|
|
128
|
+
pipe.transformer.transformer_blocks,
|
|
129
|
+
pipe.transformer.single_transformer_blocks,
|
|
130
|
+
],
|
|
131
|
+
forward_pattern=[
|
|
132
|
+
ForwardPattern.Pattern_0,
|
|
133
|
+
ForwardPattern.Pattern_0,
|
|
134
|
+
],
|
|
135
|
+
# The type hint in diffusers is wrong
|
|
136
|
+
check_num_outputs=False,
|
|
137
|
+
**kwargs,
|
|
126
138
|
)
|
|
127
139
|
|
|
128
140
|
|
|
@@ -135,10 +147,9 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
135
147
|
pipe=pipe,
|
|
136
148
|
transformer=pipe.transformer,
|
|
137
149
|
blocks=pipe.transformer.transformer_blocks,
|
|
138
|
-
blocks_name="transformer_blocks",
|
|
139
|
-
dummy_blocks_names=[],
|
|
140
150
|
forward_pattern=ForwardPattern.Pattern_1,
|
|
141
151
|
has_separate_cfg=True,
|
|
152
|
+
**kwargs,
|
|
142
153
|
)
|
|
143
154
|
|
|
144
155
|
|
|
@@ -151,9 +162,8 @@ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
151
162
|
pipe=pipe,
|
|
152
163
|
transformer=pipe.transformer,
|
|
153
164
|
blocks=pipe.transformer.transformer_blocks,
|
|
154
|
-
blocks_name="transformer_blocks",
|
|
155
|
-
dummy_blocks_names=[],
|
|
156
165
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
166
|
+
**kwargs,
|
|
157
167
|
)
|
|
158
168
|
|
|
159
169
|
|
|
@@ -166,9 +176,8 @@ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
166
176
|
pipe=pipe,
|
|
167
177
|
transformer=pipe.transformer,
|
|
168
178
|
blocks=pipe.transformer.transformer_blocks,
|
|
169
|
-
blocks_name="transformer_blocks",
|
|
170
|
-
dummy_blocks_names=[],
|
|
171
179
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
180
|
+
**kwargs,
|
|
172
181
|
)
|
|
173
182
|
|
|
174
183
|
|
|
@@ -181,9 +190,8 @@ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
181
190
|
pipe=pipe,
|
|
182
191
|
transformer=pipe.transformer,
|
|
183
192
|
blocks=pipe.transformer.transformer_blocks,
|
|
184
|
-
blocks_name="transformer_blocks",
|
|
185
|
-
dummy_blocks_names=[],
|
|
186
193
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
194
|
+
**kwargs,
|
|
187
195
|
)
|
|
188
196
|
|
|
189
197
|
|
|
@@ -196,10 +204,9 @@ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
196
204
|
pipe=pipe,
|
|
197
205
|
transformer=pipe.transformer,
|
|
198
206
|
blocks=pipe.transformer.transformer_blocks,
|
|
199
|
-
blocks_name="transformer_blocks",
|
|
200
|
-
dummy_blocks_names=[],
|
|
201
207
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
202
208
|
has_separate_cfg=True,
|
|
209
|
+
**kwargs,
|
|
203
210
|
)
|
|
204
211
|
|
|
205
212
|
|
|
@@ -212,10 +219,9 @@ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
212
219
|
pipe=pipe,
|
|
213
220
|
transformer=pipe.transformer,
|
|
214
221
|
blocks=pipe.transformer.transformer_blocks,
|
|
215
|
-
blocks_name="transformer_blocks",
|
|
216
|
-
dummy_blocks_names=[],
|
|
217
222
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
218
223
|
has_separate_cfg=True,
|
|
224
|
+
**kwargs,
|
|
219
225
|
)
|
|
220
226
|
|
|
221
227
|
|
|
@@ -228,9 +234,8 @@ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
228
234
|
pipe=pipe,
|
|
229
235
|
transformer=pipe.transformer,
|
|
230
236
|
blocks=pipe.transformer.transformer_blocks,
|
|
231
|
-
blocks_name="transformer_blocks",
|
|
232
|
-
dummy_blocks_names=[],
|
|
233
237
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
238
|
+
**kwargs,
|
|
234
239
|
)
|
|
235
240
|
|
|
236
241
|
|
|
@@ -243,10 +248,9 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
243
248
|
pipe=pipe,
|
|
244
249
|
transformer=pipe.transformer,
|
|
245
250
|
blocks=pipe.transformer.blocks,
|
|
246
|
-
blocks_name="blocks",
|
|
247
|
-
dummy_blocks_names=[],
|
|
248
251
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
249
252
|
has_separate_cfg=True,
|
|
253
|
+
**kwargs,
|
|
250
254
|
)
|
|
251
255
|
|
|
252
256
|
|
|
@@ -259,9 +263,8 @@ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
259
263
|
pipe=pipe,
|
|
260
264
|
transformer=pipe.transformer,
|
|
261
265
|
blocks=pipe.transformer.transformer_blocks,
|
|
262
|
-
blocks_name="transformer_blocks",
|
|
263
|
-
dummy_blocks_names=[],
|
|
264
266
|
forward_pattern=ForwardPattern.Pattern_1,
|
|
267
|
+
**kwargs,
|
|
265
268
|
)
|
|
266
269
|
|
|
267
270
|
|
|
@@ -274,9 +277,8 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
274
277
|
pipe=pipe,
|
|
275
278
|
transformer=pipe.transformer,
|
|
276
279
|
blocks=pipe.transformer.transformer_blocks,
|
|
277
|
-
blocks_name="transformer_blocks",
|
|
278
|
-
dummy_blocks_names=[],
|
|
279
280
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
281
|
+
**kwargs,
|
|
280
282
|
)
|
|
281
283
|
|
|
282
284
|
|
|
@@ -289,9 +291,8 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
289
291
|
pipe=pipe,
|
|
290
292
|
transformer=pipe.transformer,
|
|
291
293
|
blocks=pipe.transformer.transformer_blocks,
|
|
292
|
-
blocks_name="transformer_blocks",
|
|
293
|
-
dummy_blocks_names=[],
|
|
294
294
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
295
|
+
**kwargs,
|
|
295
296
|
)
|
|
296
297
|
|
|
297
298
|
|
|
@@ -304,9 +305,8 @@ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
304
305
|
pipe=pipe,
|
|
305
306
|
transformer=pipe.transformer,
|
|
306
307
|
blocks=pipe.transformer.transformer_layers,
|
|
307
|
-
blocks_name="transformer_layers",
|
|
308
|
-
dummy_blocks_names=[],
|
|
309
308
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
309
|
+
**kwargs,
|
|
310
310
|
)
|
|
311
311
|
|
|
312
312
|
|
|
@@ -318,51 +318,20 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
318
318
|
return BlockAdapter(
|
|
319
319
|
pipe=pipe,
|
|
320
320
|
transformer=pipe.transformer,
|
|
321
|
-
blocks=
|
|
322
|
-
pipe.transformer.transformer_blocks
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
@BlockAdapterRegistry.register("HunyuanDiT")
|
|
332
|
-
def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
333
|
-
from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
|
|
334
|
-
|
|
335
|
-
assert isinstance(
|
|
336
|
-
pipe.transformer,
|
|
337
|
-
(HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
|
|
338
|
-
)
|
|
339
|
-
return BlockAdapter(
|
|
340
|
-
pipe=pipe,
|
|
341
|
-
transformer=pipe.transformer,
|
|
342
|
-
blocks=pipe.transformer.blocks,
|
|
343
|
-
blocks_name="blocks",
|
|
344
|
-
dummy_blocks_names=[],
|
|
345
|
-
forward_pattern=ForwardPattern.Pattern_3,
|
|
346
|
-
)
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
@BlockAdapterRegistry.register("HunyuanDiTPAG")
|
|
350
|
-
def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
351
|
-
from diffusers import HunyuanDiT2DModel
|
|
352
|
-
|
|
353
|
-
assert isinstance(pipe.transformer, HunyuanDiT2DModel)
|
|
354
|
-
return BlockAdapter(
|
|
355
|
-
pipe=pipe,
|
|
356
|
-
transformer=pipe.transformer,
|
|
357
|
-
blocks=pipe.transformer.blocks,
|
|
358
|
-
blocks_name="blocks",
|
|
359
|
-
dummy_blocks_names=[],
|
|
360
|
-
forward_pattern=ForwardPattern.Pattern_3,
|
|
321
|
+
blocks=[
|
|
322
|
+
pipe.transformer.transformer_blocks,
|
|
323
|
+
pipe.transformer.single_transformer_blocks,
|
|
324
|
+
],
|
|
325
|
+
forward_pattern=[
|
|
326
|
+
ForwardPattern.Pattern_0,
|
|
327
|
+
ForwardPattern.Pattern_0,
|
|
328
|
+
],
|
|
329
|
+
**kwargs,
|
|
361
330
|
)
|
|
362
331
|
|
|
363
332
|
|
|
364
333
|
@BlockAdapterRegistry.register("Lumina")
|
|
365
|
-
def lumina_adapter(pipe) -> BlockAdapter:
|
|
334
|
+
def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
366
335
|
from diffusers import LuminaNextDiT2DModel
|
|
367
336
|
|
|
368
337
|
assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
|
|
@@ -370,9 +339,8 @@ def lumina_adapter(pipe) -> BlockAdapter:
|
|
|
370
339
|
pipe=pipe,
|
|
371
340
|
transformer=pipe.transformer,
|
|
372
341
|
blocks=pipe.transformer.layers,
|
|
373
|
-
blocks_name="layers",
|
|
374
|
-
dummy_blocks_names=[],
|
|
375
342
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
343
|
+
**kwargs,
|
|
376
344
|
)
|
|
377
345
|
|
|
378
346
|
|
|
@@ -385,9 +353,8 @@ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
385
353
|
pipe=pipe,
|
|
386
354
|
transformer=pipe.transformer,
|
|
387
355
|
blocks=pipe.transformer.layers,
|
|
388
|
-
blocks_name="layers",
|
|
389
|
-
dummy_blocks_names=[],
|
|
390
356
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
357
|
+
**kwargs,
|
|
391
358
|
)
|
|
392
359
|
|
|
393
360
|
|
|
@@ -400,9 +367,8 @@ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
400
367
|
pipe=pipe,
|
|
401
368
|
transformer=pipe.transformer,
|
|
402
369
|
blocks=pipe.transformer.layers,
|
|
403
|
-
blocks_name="layers",
|
|
404
|
-
dummy_blocks_names=[],
|
|
405
370
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
371
|
+
**kwargs,
|
|
406
372
|
)
|
|
407
373
|
|
|
408
374
|
|
|
@@ -415,39 +381,24 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
415
381
|
pipe=pipe,
|
|
416
382
|
transformer=pipe.transformer,
|
|
417
383
|
blocks=pipe.transformer.transformer_blocks,
|
|
418
|
-
blocks_name="transformer_blocks",
|
|
419
|
-
dummy_blocks_names=[],
|
|
420
384
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
385
|
+
**kwargs,
|
|
421
386
|
)
|
|
422
387
|
|
|
423
388
|
|
|
424
|
-
@BlockAdapterRegistry.register("Sana")
|
|
389
|
+
@BlockAdapterRegistry.register("Sana", supported=False)
|
|
425
390
|
def sana_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
426
391
|
from diffusers import SanaTransformer2DModel
|
|
427
392
|
|
|
393
|
+
# TODO: fix -> got multiple values for argument 'encoder_hidden_states'
|
|
394
|
+
|
|
428
395
|
assert isinstance(pipe.transformer, SanaTransformer2DModel)
|
|
429
396
|
return BlockAdapter(
|
|
430
397
|
pipe=pipe,
|
|
431
398
|
transformer=pipe.transformer,
|
|
432
399
|
blocks=pipe.transformer.transformer_blocks,
|
|
433
|
-
blocks_name="transformer_blocks",
|
|
434
|
-
dummy_blocks_names=[],
|
|
435
|
-
forward_pattern=ForwardPattern.Pattern_3,
|
|
436
|
-
)
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
@BlockAdapterRegistry.register("ShapE")
|
|
440
|
-
def shape_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
441
|
-
from diffusers import PriorTransformer
|
|
442
|
-
|
|
443
|
-
assert isinstance(pipe.prior, PriorTransformer)
|
|
444
|
-
return BlockAdapter(
|
|
445
|
-
pipe=pipe,
|
|
446
|
-
transformer=pipe.prior,
|
|
447
|
-
blocks=pipe.prior.transformer_blocks,
|
|
448
|
-
blocks_name="transformer_blocks",
|
|
449
|
-
dummy_blocks_names=[],
|
|
450
400
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
401
|
+
**kwargs,
|
|
451
402
|
)
|
|
452
403
|
|
|
453
404
|
|
|
@@ -460,30 +411,45 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
460
411
|
pipe=pipe,
|
|
461
412
|
transformer=pipe.transformer,
|
|
462
413
|
blocks=pipe.transformer.transformer_blocks,
|
|
463
|
-
blocks_name="transformer_blocks",
|
|
464
|
-
dummy_blocks_names=[],
|
|
465
414
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
415
|
+
**kwargs,
|
|
466
416
|
)
|
|
467
417
|
|
|
468
418
|
|
|
469
419
|
@BlockAdapterRegistry.register("VisualCloze")
|
|
470
420
|
def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
471
421
|
from diffusers import FluxTransformer2DModel
|
|
472
|
-
from cache_dit.
|
|
422
|
+
from cache_dit.utils import is_diffusers_at_least_0_3_5
|
|
473
423
|
|
|
474
424
|
assert isinstance(pipe.transformer, FluxTransformer2DModel)
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
425
|
+
if is_diffusers_at_least_0_3_5():
|
|
426
|
+
return BlockAdapter(
|
|
427
|
+
pipe=pipe,
|
|
428
|
+
transformer=pipe.transformer,
|
|
429
|
+
blocks=[
|
|
430
|
+
pipe.transformer.transformer_blocks,
|
|
431
|
+
pipe.transformer.single_transformer_blocks,
|
|
432
|
+
],
|
|
433
|
+
forward_pattern=[
|
|
434
|
+
ForwardPattern.Pattern_1,
|
|
435
|
+
ForwardPattern.Pattern_1,
|
|
436
|
+
],
|
|
437
|
+
**kwargs,
|
|
438
|
+
)
|
|
439
|
+
else:
|
|
440
|
+
return BlockAdapter(
|
|
441
|
+
pipe=pipe,
|
|
442
|
+
transformer=pipe.transformer,
|
|
443
|
+
blocks=[
|
|
444
|
+
pipe.transformer.transformer_blocks,
|
|
445
|
+
pipe.transformer.single_transformer_blocks,
|
|
446
|
+
],
|
|
447
|
+
forward_pattern=[
|
|
448
|
+
ForwardPattern.Pattern_1,
|
|
449
|
+
ForwardPattern.Pattern_3,
|
|
450
|
+
],
|
|
451
|
+
**kwargs,
|
|
452
|
+
)
|
|
487
453
|
|
|
488
454
|
|
|
489
455
|
@BlockAdapterRegistry.register("AuraFlow")
|
|
@@ -494,19 +460,9 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
494
460
|
return BlockAdapter(
|
|
495
461
|
pipe=pipe,
|
|
496
462
|
transformer=pipe.transformer,
|
|
497
|
-
blocks=
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
pipe.transformer.single_transformer_blocks,
|
|
501
|
-
],
|
|
502
|
-
blocks_name=[
|
|
503
|
-
# "joint_transformer_blocks",
|
|
504
|
-
"single_transformer_blocks",
|
|
505
|
-
],
|
|
506
|
-
forward_pattern=[
|
|
507
|
-
# ForwardPattern.Pattern_1,
|
|
508
|
-
ForwardPattern.Pattern_3,
|
|
509
|
-
],
|
|
463
|
+
blocks=pipe.transformer.single_transformer_blocks,
|
|
464
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
465
|
+
**kwargs,
|
|
510
466
|
)
|
|
511
467
|
|
|
512
468
|
|
|
@@ -522,20 +478,36 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
522
478
|
pipe.transformer.transformer_blocks,
|
|
523
479
|
pipe.transformer.single_transformer_blocks,
|
|
524
480
|
],
|
|
525
|
-
blocks_name=[
|
|
526
|
-
"transformer_blocks",
|
|
527
|
-
"single_transformer_blocks",
|
|
528
|
-
],
|
|
529
481
|
forward_pattern=[
|
|
530
482
|
ForwardPattern.Pattern_1,
|
|
531
483
|
ForwardPattern.Pattern_3,
|
|
532
484
|
],
|
|
533
485
|
has_separate_cfg=True,
|
|
486
|
+
**kwargs,
|
|
487
|
+
)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
@BlockAdapterRegistry.register("ShapE")
|
|
491
|
+
def shape_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
492
|
+
from diffusers import PriorTransformer
|
|
493
|
+
|
|
494
|
+
assert isinstance(pipe.prior, PriorTransformer)
|
|
495
|
+
return BlockAdapter(
|
|
496
|
+
pipe=pipe,
|
|
497
|
+
transformer=pipe.prior,
|
|
498
|
+
blocks=pipe.prior.transformer_blocks,
|
|
499
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
500
|
+
**kwargs,
|
|
534
501
|
)
|
|
535
502
|
|
|
536
503
|
|
|
537
|
-
@BlockAdapterRegistry.register("HiDream")
|
|
504
|
+
@BlockAdapterRegistry.register("HiDream", supported=True)
|
|
538
505
|
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
506
|
+
# NOTE: Need to patch Transformer forward to fully support
|
|
507
|
+
# double_stream_blocks and single_stream_blocks, namely, need
|
|
508
|
+
# to remove the logics inside the blocks forward loop:
|
|
509
|
+
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
|
|
510
|
+
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
|
|
539
511
|
from diffusers import HiDreamImageTransformer2DModel
|
|
540
512
|
|
|
541
513
|
assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
|
|
@@ -543,13 +515,47 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
543
515
|
pipe=pipe,
|
|
544
516
|
transformer=pipe.transformer,
|
|
545
517
|
blocks=[
|
|
546
|
-
pipe.transformer.double_stream_blocks,
|
|
518
|
+
# pipe.transformer.double_stream_blocks,
|
|
547
519
|
pipe.transformer.single_stream_blocks,
|
|
548
520
|
],
|
|
549
|
-
dummy_blocks_names=[],
|
|
550
521
|
forward_pattern=[
|
|
551
|
-
ForwardPattern.Pattern_4,
|
|
522
|
+
# ForwardPattern.Pattern_4,
|
|
552
523
|
ForwardPattern.Pattern_3,
|
|
553
524
|
],
|
|
525
|
+
# The type hint in diffusers is wrong
|
|
554
526
|
check_num_outputs=False,
|
|
527
|
+
**kwargs,
|
|
528
|
+
)
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
@BlockAdapterRegistry.register("HunyuanDiT", supported=False)
|
|
532
|
+
def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
533
|
+
# TODO: Patch Transformer forward
|
|
534
|
+
from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
|
|
535
|
+
|
|
536
|
+
assert isinstance(
|
|
537
|
+
pipe.transformer,
|
|
538
|
+
(HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
|
|
539
|
+
)
|
|
540
|
+
return BlockAdapter(
|
|
541
|
+
pipe=pipe,
|
|
542
|
+
transformer=pipe.transformer,
|
|
543
|
+
blocks=pipe.transformer.blocks,
|
|
544
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
545
|
+
**kwargs,
|
|
546
|
+
)
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
@BlockAdapterRegistry.register("HunyuanDiTPAG", supported=False)
|
|
550
|
+
def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
551
|
+
# TODO: Patch Transformer forward
|
|
552
|
+
from diffusers import HunyuanDiT2DModel
|
|
553
|
+
|
|
554
|
+
assert isinstance(pipe.transformer, HunyuanDiT2DModel)
|
|
555
|
+
return BlockAdapter(
|
|
556
|
+
pipe=pipe,
|
|
557
|
+
transformer=pipe.transformer,
|
|
558
|
+
blocks=pipe.transformer.blocks,
|
|
559
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
560
|
+
**kwargs,
|
|
555
561
|
)
|