cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +8 -6
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +17 -4
- cache_dit/cache_factory/block_adapters/__init__.py +555 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +262 -938
- cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
- cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
- cache_dit/cache_factory/cache_blocks/utils.py +16 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
- cache_dit/cache_factory/cache_interface.py +31 -31
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
- cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
- cache_dit/quantize/quantize_ao.py +1 -0
- cache_dit/utils.py +26 -26
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
- cache_dit-0.2.28.dist-info/RECORD +47 -0
- cache_dit/cache_factory/cache_context.py +0 -1155
- cache_dit-0.2.26.dist-info/RECORD +0 -42
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
|
@@ -10,9 +10,11 @@ from cache_dit.cache_factory import cache_type
|
|
|
10
10
|
from cache_dit.cache_factory import block_range
|
|
11
11
|
from cache_dit.cache_factory import CacheType
|
|
12
12
|
from cache_dit.cache_factory import BlockAdapter
|
|
13
|
+
from cache_dit.cache_factory import ParamsModifier
|
|
13
14
|
from cache_dit.cache_factory import ForwardPattern
|
|
14
15
|
from cache_dit.cache_factory import PatchFunctor
|
|
15
16
|
from cache_dit.cache_factory import supported_pipelines
|
|
17
|
+
from cache_dit.cache_factory import get_adapter
|
|
16
18
|
from cache_dit.compile import set_compile_configs
|
|
17
19
|
from cache_dit.quantize import quantize
|
|
18
20
|
from cache_dit.utils import summary
|
|
@@ -22,9 +24,9 @@ from cache_dit.logger import init_logger
|
|
|
22
24
|
NONE = CacheType.NONE
|
|
23
25
|
DBCache = CacheType.DBCache
|
|
24
26
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
27
|
+
Pattern_0 = ForwardPattern.Pattern_0
|
|
28
|
+
Pattern_1 = ForwardPattern.Pattern_1
|
|
29
|
+
Pattern_2 = ForwardPattern.Pattern_2
|
|
30
|
+
Pattern_3 = ForwardPattern.Pattern_3
|
|
31
|
+
Pattern_4 = ForwardPattern.Pattern_4
|
|
32
|
+
Pattern_5 = ForwardPattern.Pattern_5
|
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.2.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 2,
|
|
31
|
+
__version__ = version = '0.2.28'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 2, 28)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -1,10 +1,23 @@
|
|
|
1
|
-
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
2
1
|
from cache_dit.cache_factory.cache_types import CacheType
|
|
3
2
|
from cache_dit.cache_factory.cache_types import cache_type
|
|
4
3
|
from cache_dit.cache_factory.cache_types import block_range
|
|
5
|
-
|
|
6
|
-
from cache_dit.cache_factory.
|
|
4
|
+
|
|
5
|
+
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
6
|
+
|
|
7
|
+
from cache_dit.cache_factory.patch_functors import PatchFunctor
|
|
8
|
+
|
|
9
|
+
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
10
|
+
from cache_dit.cache_factory.block_adapters import ParamsModifier
|
|
11
|
+
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
12
|
+
|
|
13
|
+
from cache_dit.cache_factory.cache_contexts import CachedContext
|
|
14
|
+
from cache_dit.cache_factory.cache_contexts import CachedContextManager
|
|
15
|
+
from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
16
|
+
|
|
17
|
+
from cache_dit.cache_factory.cache_adapters import CachedAdapter
|
|
18
|
+
|
|
7
19
|
from cache_dit.cache_factory.cache_interface import enable_cache
|
|
8
20
|
from cache_dit.cache_factory.cache_interface import supported_pipelines
|
|
9
|
-
from cache_dit.cache_factory.
|
|
21
|
+
from cache_dit.cache_factory.cache_interface import get_adapter
|
|
22
|
+
|
|
10
23
|
from cache_dit.cache_factory.utils import load_options
|
|
@@ -0,0 +1,555 @@
|
|
|
1
|
+
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
2
|
+
from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
|
|
3
|
+
from cache_dit.cache_factory.block_adapters.block_adapters import ParamsModifier
|
|
4
|
+
from cache_dit.cache_factory.block_adapters.block_registers import (
|
|
5
|
+
BlockAdapterRegistry,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@BlockAdapterRegistry.register("Flux")
def flux_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Flux pipelines.

    The dual-stream ``transformer_blocks`` and the ``single_transformer_blocks``
    are concatenated into a single block sequence (``ForwardPattern.Pattern_1``),
    with ``single_transformer_blocks`` listed as a dummy block name. A
    ``FluxPatchFunctor`` is attached, and the optional ``disable_patch``
    keyword (default ``False``) is forwarded to the adapter.
    """
    from diffusers import FluxTransformer2DModel
    from cache_dit.cache_factory.patch_functors import FluxPatchFunctor

    transformer = pipe.transformer
    assert isinstance(transformer, FluxTransformer2DModel)

    joint_blocks = transformer.transformer_blocks
    single_blocks = transformer.single_transformer_blocks
    skip_patch = kwargs.pop("disable_patch", False)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=joint_blocks + single_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=["single_transformer_blocks"],
        patch_functor=FluxPatchFunctor(),
        forward_pattern=ForwardPattern.Pattern_1,
        disable_patch=skip_patch,
    )
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@BlockAdapterRegistry.register("Mochi")
def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Mochi pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0``; there are no dummy blocks.
    """
    from diffusers import MochiTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, MochiTransformer3DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@BlockAdapterRegistry.register("CogVideoX")
def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for CogVideoX pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0``; there are no dummy blocks.
    """
    from diffusers import CogVideoXTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, CogVideoXTransformer3DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
    )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@BlockAdapterRegistry.register("Wan")
def wan_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Wan pipelines.

    If the pipeline carries a second transformer (``pipe.transformer_2``, the
    Wan 2.2 MoE layout), both transformers are adapted together, each through
    its ``blocks`` attribute with ``ForwardPattern.Pattern_2``. Otherwise
    (Wan 2.1) only ``pipe.transformer`` is adapted. Both layouts set
    ``has_separate_cfg=True``.

    Raises:
        AssertionError: if a transformer is not a Wan / WanVACE model.
    """
    from diffusers import (
        WanTransformer3DModel,
        WanVACETransformer3DModel,
    )

    wan_types = (WanTransformer3DModel, WanVACETransformer3DModel)
    transformer = pipe.transformer
    assert isinstance(transformer, wan_types)

    # `transformer_2` may be missing or None. Compare against None explicitly
    # instead of relying on the truthiness of a module object.
    transformer_2 = getattr(pipe, "transformer_2", None)
    if transformer_2 is None:
        # Wan 2.1
        return BlockAdapter(
            pipe=pipe,
            transformer=transformer,
            blocks=transformer.blocks,
            blocks_name="blocks",
            dummy_blocks_names=[],
            forward_pattern=ForwardPattern.Pattern_2,
            has_separate_cfg=True,
        )

    # Wan 2.2 MoE
    assert isinstance(transformer_2, wan_types)
    return BlockAdapter(
        pipe=pipe,
        transformer=[
            transformer,
            transformer_2,
        ],
        blocks=[
            transformer.blocks,
            transformer_2.blocks,
        ],
        blocks_name=[
            "blocks",
            "blocks",
        ],
        forward_pattern=[
            ForwardPattern.Pattern_2,
            ForwardPattern.Pattern_2,
        ],
        dummy_blocks_names=[],
        has_separate_cfg=True,
    )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@BlockAdapterRegistry.register("HunyuanVideo")
def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for HunyuanVideo pipelines.

    ``transformer_blocks`` and ``single_transformer_blocks`` are concatenated
    into a single block sequence (``ForwardPattern.Pattern_0``), with
    ``single_transformer_blocks`` listed as a dummy block name.
    """
    from diffusers import HunyuanVideoTransformer3DModel

    assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        # Pass the transformer explicitly, as every other single-transformer
        # adapter in this module does, rather than leaving it unset.
        transformer=pipe.transformer,
        blocks=(
            pipe.transformer.transformer_blocks
            + pipe.transformer.single_transformer_blocks
        ),
        blocks_name="transformer_blocks",
        dummy_blocks_names=["single_transformer_blocks"],
        forward_pattern=ForwardPattern.Pattern_0,
    )
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@BlockAdapterRegistry.register("QwenImage")
def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for QwenImage pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_1`` and ``has_separate_cfg=True``.
    """
    from diffusers import QwenImageTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, QwenImageTransformer2DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_1,
        has_separate_cfg=True,
    )
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@BlockAdapterRegistry.register("LTXVideo")
def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for LTX-Video pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_2``; there are no dummy blocks.
    """
    from diffusers import LTXVideoTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, LTXVideoTransformer3DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_2,
    )
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@BlockAdapterRegistry.register("Allegro")
def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Allegro pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_2``; there are no dummy blocks.
    """
    from diffusers import AllegroTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, AllegroTransformer3DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_2,
    )
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@BlockAdapterRegistry.register("CogView3Plus")
def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for CogView3-Plus pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0``; there are no dummy blocks.
    """
    from diffusers import CogView3PlusTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, CogView3PlusTransformer2DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
    )
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@BlockAdapterRegistry.register("CogView4")
def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for CogView4 pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0`` and ``has_separate_cfg=True``.
    """
    from diffusers import CogView4Transformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, CogView4Transformer2DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
        has_separate_cfg=True,
    )
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@BlockAdapterRegistry.register("Cosmos")
def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Cosmos pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_2`` and ``has_separate_cfg=True``.
    """
    from diffusers import CosmosTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, CosmosTransformer3DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_2,
        has_separate_cfg=True,
    )
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@BlockAdapterRegistry.register("EasyAnimate")
def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for EasyAnimate pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0``; there are no dummy blocks.
    """
    from diffusers import EasyAnimateTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, EasyAnimateTransformer3DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
    )
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
@BlockAdapterRegistry.register("SkyReelsV2")
def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for SkyReels-V2 pipelines.

    Adapts ``pipe.transformer.blocks`` with ``ForwardPattern.Pattern_2`` and
    ``has_separate_cfg=True``.
    """
    from diffusers import SkyReelsV2Transformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, SkyReelsV2Transformer3DModel)

    blocks_attr = "blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_2,
        has_separate_cfg=True,
    )
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
@BlockAdapterRegistry.register("SD3")
def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Stable Diffusion 3 pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_1``; there are no dummy blocks.
    """
    from diffusers import SD3Transformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, SD3Transformer2DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_1,
    )
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@BlockAdapterRegistry.register("ConsisID")
def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for ConsisID pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0``; there are no dummy blocks.
    """
    from diffusers import ConsisIDTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, ConsisIDTransformer3DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
    )
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
@BlockAdapterRegistry.register("DiT")
def dit_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for DiT pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_3``; there are no dummy blocks.
    """
    from diffusers import DiTTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, DiTTransformer2DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
@BlockAdapterRegistry.register("Amused")
def amused_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for aMUSEd pipelines.

    The transformer is a ``UVit2DModel``; its ``transformer_layers`` are
    adapted with ``ForwardPattern.Pattern_3``.
    """
    from diffusers import UVit2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, UVit2DModel)

    blocks_attr = "transformer_layers"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
@BlockAdapterRegistry.register("Bria")
def bria_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Bria pipelines.

    ``transformer_blocks`` and ``single_transformer_blocks`` are concatenated
    into a single block sequence (``ForwardPattern.Pattern_0``), with
    ``single_transformer_blocks`` listed as a dummy block name.
    """
    from diffusers import BriaTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, BriaTransformer2DModel)

    joint_blocks = transformer.transformer_blocks
    single_blocks = transformer.single_transformer_blocks
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=joint_blocks + single_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=["single_transformer_blocks"],
        forward_pattern=ForwardPattern.Pattern_0,
    )
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
@BlockAdapterRegistry.register("HunyuanDiT")
def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for HunyuanDiT pipelines.

    Accepts either a plain ``HunyuanDiT2DModel`` or the ControlNet variant and
    adapts its ``blocks`` with ``ForwardPattern.Pattern_3``.
    """
    from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel

    accepted_types = (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel)
    transformer = pipe.transformer
    assert isinstance(transformer, accepted_types)

    blocks_attr = "blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
@BlockAdapterRegistry.register("HunyuanDiTPAG")
def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for HunyuanDiT PAG pipelines.

    Adapts ``pipe.transformer.blocks`` with ``ForwardPattern.Pattern_3``.
    """
    from diffusers import HunyuanDiT2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, HunyuanDiT2DModel)

    blocks_attr = "blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
@BlockAdapterRegistry.register("Lumina")
def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Lumina-Next pipelines.

    Adapts ``pipe.transformer.layers`` with ``ForwardPattern.Pattern_3``.

    Extra keyword arguments are accepted (and ignored) so this adapter can be
    invoked with the same ``(pipe, **kwargs)`` call shape as every other
    registered adapter. Without them, any forwarded option such as
    ``disable_patch`` would raise ``TypeError``.
    """
    from diffusers import LuminaNextDiT2DModel

    assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.layers,
        blocks_name="layers",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
@BlockAdapterRegistry.register("Lumina2")
def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Lumina2 pipelines.

    Adapts ``pipe.transformer.layers`` with ``ForwardPattern.Pattern_3``.
    """
    from diffusers import Lumina2Transformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, Lumina2Transformer2DModel)

    blocks_attr = "layers"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
@BlockAdapterRegistry.register("OmniGen")
def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for OmniGen pipelines.

    Adapts ``pipe.transformer.layers`` with ``ForwardPattern.Pattern_3``.
    """
    from diffusers import OmniGenTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, OmniGenTransformer2DModel)

    blocks_attr = "layers"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
@BlockAdapterRegistry.register("PixArt")
def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for PixArt pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_3``; there are no dummy blocks.
    """
    from diffusers import PixArtTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, PixArtTransformer2DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
@BlockAdapterRegistry.register("Sana")
def sana_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Sana pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_3``; there are no dummy blocks.
    """
    from diffusers import SanaTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, SanaTransformer2DModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
@BlockAdapterRegistry.register("ShapE")
def shape_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Shap-E pipelines.

    Unlike the other adapters, the model lives at ``pipe.prior`` (a
    ``PriorTransformer``); its ``transformer_blocks`` are adapted with
    ``ForwardPattern.Pattern_3``.
    """
    from diffusers import PriorTransformer

    prior = pipe.prior
    assert isinstance(prior, PriorTransformer)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=prior,
        blocks=getattr(prior, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
@BlockAdapterRegistry.register("StableAudio")
def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Stable Audio pipelines.

    Adapts ``pipe.transformer.transformer_blocks`` (a ``StableAudioDiTModel``)
    with ``ForwardPattern.Pattern_3``.

    NOTE: the function name keeps its historical spelling so existing
    references to it continue to resolve.
    """
    from diffusers import StableAudioDiTModel

    transformer = pipe.transformer
    assert isinstance(transformer, StableAudioDiTModel)

    blocks_attr = "transformer_blocks"
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=getattr(transformer, blocks_attr),
        blocks_name=blocks_attr,
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
@BlockAdapterRegistry.register("VisualCloze")
def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for VisualCloze pipelines.

    VisualCloze runs on a ``FluxTransformer2DModel``, so it is adapted the
    same way as Flux. The concatenated ``transformer_blocks`` +
    ``single_transformer_blocks`` are used with ``FluxPatchFunctor`` and
    ``ForwardPattern.Pattern_1``, with ``single_transformer_blocks`` listed
    as a dummy block name.

    The optional ``disable_patch`` keyword (default ``False``) is forwarded,
    mirroring ``flux_adapter``, so the option is honored here instead of
    being silently ignored.
    """
    from diffusers import FluxTransformer2DModel
    from cache_dit.cache_factory.patch_functors import FluxPatchFunctor

    assert isinstance(pipe.transformer, FluxTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=(
            pipe.transformer.transformer_blocks
            + pipe.transformer.single_transformer_blocks
        ),
        blocks_name="transformer_blocks",
        dummy_blocks_names=["single_transformer_blocks"],
        patch_functor=FluxPatchFunctor(),
        forward_pattern=ForwardPattern.Pattern_1,
        disable_patch=kwargs.pop("disable_patch", False),
    )
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
@BlockAdapterRegistry.register("AuraFlow")
def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for AuraFlow pipelines.

    Only ``single_transformer_blocks`` are adapted (``ForwardPattern.Pattern_3``);
    the joint transformer blocks are deliberately left out.
    """
    from diffusers import AuraFlowTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, AuraFlowTransformer2DModel)

    # There are only 4 joint blocks, so no cache is applied to them: the
    # ("joint_transformer_blocks", ForwardPattern.Pattern_1) group is
    # intentionally omitted from the lists below.
    block_attrs = ["single_transformer_blocks"]
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=[getattr(transformer, attr) for attr in block_attrs],
        blocks_name=block_attrs,
        forward_pattern=[ForwardPattern.Pattern_3],
    )
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
@BlockAdapterRegistry.register("Chroma")
def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for Chroma pipelines.

    Adapts two block groups as separate entries: ``transformer_blocks``
    (``ForwardPattern.Pattern_1``) and ``single_transformer_blocks``
    (``ForwardPattern.Pattern_3``), with ``has_separate_cfg=True``.
    """
    from diffusers import ChromaTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, ChromaTransformer2DModel)

    block_attrs = ["transformer_blocks", "single_transformer_blocks"]
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=[getattr(transformer, attr) for attr in block_attrs],
        blocks_name=block_attrs,
        forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_3],
        has_separate_cfg=True,
    )
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
@BlockAdapterRegistry.register("HiDream")
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
    """Build the BlockAdapter for HiDream pipelines.

    Adapts two block groups as separate entries: ``double_stream_blocks``
    (``ForwardPattern.Pattern_4``) and ``single_stream_blocks``
    (``ForwardPattern.Pattern_3``). ``check_num_outputs=False`` is passed
    through to the adapter.
    """
    from diffusers import HiDreamImageTransformer2DModel

    assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=[
            pipe.transformer.double_stream_blocks,
            pipe.transformer.single_stream_blocks,
        ],
        # Name each block group explicitly, as the other multi-group adapters
        # (Wan 2.2, AuraFlow, Chroma) do; these are exactly the attributes the
        # `blocks` list above reads from.
        blocks_name=[
            "double_stream_blocks",
            "single_stream_blocks",
        ],
        dummy_blocks_names=[],
        forward_pattern=[
            ForwardPattern.Pattern_4,
            ForwardPattern.Pattern_3,
        ],
        check_num_outputs=False,
    )
|