cache-dit 1.0.9__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +37 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +51 -3
- cache_dit/cache_factory/block_adapters/block_registers.py +41 -14
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +68 -30
- cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
- cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
- cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
- cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
- cache_dit/cache_factory/cache_interface.py +29 -3
- cache_dit/cache_factory/forward_pattern.py +14 -14
- cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -61
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
- cache_dit/parallelism/parallel_backend.py +2 -0
- cache_dit/parallelism/parallel_config.py +8 -1
- cache_dit/parallelism/parallel_interface.py +9 -4
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/utils.py +22 -2
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/METADATA +22 -13
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED

@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '1.0.9'
-__version_tuple__ = version_tuple = (1, 0, 9)
+__version__ = version = '1.0.10'
+__version_tuple__ = version_tuple = (1, 0, 10)
 
 __commit_id__ = commit_id = None
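The only change to _version.py is the release bump itself. A quick sanity check from an installed environment (a minimal sketch; both names come straight from the generated _version.py shown above):

from cache_dit._version import __version__, version_tuple

assert __version__ == "1.0.10"
assert version_tuple[:3] == (1, 0, 10)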
cache_dit/cache_factory/__init__.py
CHANGED

@@ -8,6 +8,7 @@ from cache_dit.cache_factory.patch_functors import PatchFunctor
 
 from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
+from cache_dit.cache_factory.block_adapters import FakeDiffusionPipeline
 
 from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
 from cache_dit.cache_factory.cache_contexts import DBCacheConfig
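With this re-export, the placeholder pipeline added in this release is importable directly from cache_dit.cache_factory. A minimal sketch (torch.nn.Identity() stands in for a real transformer module):

import torch

from cache_dit.cache_factory import FakeDiffusionPipeline

# The placeholder only holds a reference to the module you pass in.
fake_pipe = FakeDiffusionPipeline(transformer=torch.nn.Identity())
assert fake_pipe.transformer is not None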
cache_dit/cache_factory/block_adapters/__init__.py
CHANGED

@@ -1,5 +1,8 @@
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters.block_adapters import (
+    FakeDiffusionPipeline,
+)
 from cache_dit.cache_factory.block_adapters.block_adapters import ParamsModifier
 from cache_dit.cache_factory.block_adapters.block_registers import (
     BlockAdapterRegistry,
@@ -27,6 +30,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_1,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
     else:
@@ -41,6 +45,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_3,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
 
@@ -55,6 +60,7 @@ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -69,6 +75,7 @@ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -104,6 +111,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_2,
             ForwardPattern.Pattern_2,
         ],
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -114,6 +122,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_2,
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -135,6 +144,7 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_0,
         ],
+        check_forward_pattern=True,
         # The type hint in diffusers is wrong
         check_num_outputs=False,
         **kwargs,
@@ -159,6 +169,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
             blocks=pipe.transformer.transformer_blocks,
             forward_pattern=ForwardPattern.Pattern_1,
             patch_functor=QwenImageControlNetPatchFunctor(),
+            check_forward_pattern=True,
             has_separate_cfg=True,
         )
     else:
@@ -167,6 +178,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
             transformer=pipe.transformer,
             blocks=pipe.transformer.transformer_blocks,
             forward_pattern=ForwardPattern.Pattern_1,
+            check_forward_pattern=True,
             has_separate_cfg=True,
             **kwargs,
         )
@@ -182,6 +194,7 @@ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_2,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -196,6 +209,7 @@ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_2,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -210,6 +224,7 @@ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -224,6 +239,7 @@ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -239,6 +255,7 @@ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_2,
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -254,6 +271,7 @@ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -271,6 +289,7 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
         # encoder_hidden_states will never change in the blocks
         # forward loop.
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -286,6 +305,7 @@ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_1,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -300,6 +320,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -316,6 +337,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
         patch_functor=DiTPatchFunctor(),
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -330,6 +352,7 @@ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_layers,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -350,6 +373,7 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_0,
         ],
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -367,6 +391,7 @@ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.layers,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -381,6 +406,7 @@ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.layers,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -395,6 +421,7 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -409,6 +436,7 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -423,6 +451,7 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -445,6 +474,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_1,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
     else:
@@ -459,6 +489,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_3,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
 
@@ -473,6 +504,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.single_transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -495,6 +527,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_3,
         ],
         patch_functor=ChromaPatchFunctor(),
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -510,6 +543,7 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.prior,
         blocks=pipe.prior.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -559,6 +593,7 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
         patch_functor=HunyuanDiTPatchFunctor(),
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -575,6 +610,7 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
         patch_functor=HunyuanDiTPatchFunctor(),
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -613,6 +649,7 @@ def prx_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         check_num_outputs=False,
         **kwargs,
     )
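Every bundled adapter above now passes check_forward_pattern=True explicitly instead of relying on the old class-level default (which becomes an auto-resolved None in this release, see block_adapters.py below). A custom adapter can follow the same pattern; this is an illustrative sketch only, where my_adapter and the pipe.transformer attribute layout are assumptions, not part of the wheel:

from cache_dit.cache_factory.block_adapters import BlockAdapter
from cache_dit.cache_factory.forward_pattern import ForwardPattern

def my_adapter(pipe, **kwargs) -> BlockAdapter:
    # Mirror the bundled adapters: opt into the forward-pattern check
    # explicitly rather than leaving it to the new auto default.
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_0,
        check_forward_pattern=True,
        **kwargs,
    )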
cache_dit/cache_factory/block_adapters/block_adapters.py
CHANGED

@@ -6,7 +6,7 @@ from collections.abc import Iterable
 
 from typing import Any, Tuple, List, Optional, Union
 
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, ModelMixin
 from cache_dit.cache_factory.patch_functors import PatchFunctor
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.params_modifier import ParamsModifier
@@ -16,12 +16,22 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
+class FakeDiffusionPipeline:
+    # A placeholder for pipelines when pipe is None.
+    def __init__(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+    ):
+        self.transformer = transformer  # Reference only
+
+
 @dataclasses.dataclass
 class BlockAdapter:
 
     # Transformer configurations.
     pipe: Union[
         DiffusionPipeline,
+        FakeDiffusionPipeline,
         Any,
     ] = None
 
@@ -73,7 +83,7 @@ class BlockAdapter:
         ]
     ] = None
 
-    check_forward_pattern: bool = True
+    check_forward_pattern: Optional[bool] = None
    check_num_outputs: bool = False
 
     # Pipeline Level Flags
@@ -110,12 +120,43 @@ class BlockAdapter:
     def __post_init__(self):
         if self.skip_post_init:
             return
+
+        self.maybe_fake_pipe()
         if any((self.pipe is not None, self.transformer is not None)):
             self.maybe_fill_attrs()
             self.maybe_patchify()
             self.maybe_skip_checks()
 
+    def maybe_fake_pipe(self):
+        if self.pipe is None:
+            self.pipe = FakeDiffusionPipeline()
+            logger.warning("pipe is None, use FakeDiffusionPipeline instead.")
+
     def maybe_skip_checks(self):
+        if self.check_forward_pattern is None:
+            if self.transformer is not None:
+                if self.nested_depth(self.transformer) == 0:
+                    transformer = self.transformer
+                elif self.nested_depth(self.transformer) == 1:
+                    transformer = self.transformer[0]
+                else:
+                    raise ValueError(
+                        "transformer nested depth can't more than 1, "
+                        f"current is: {self.nested_depth(self.transformer)}"
+                    )
+                if transformer.__module__.startswith("diffusers"):
+                    self.check_forward_pattern = True
+                    logger.info(
+                        f"Found transformer from diffusers: {transformer.__module__} "
+                        "enable check_forward_pattern by default."
+                    )
+                else:
+                    self.check_forward_pattern = False
+                    logger.info(
+                        f"Found transformer NOT from diffusers: {transformer.__module__} "
+                        "disable check_forward_pattern by default."
+                    )
+
         if getattr(self.transformer, "_hf_hook", None) is not None:
             logger.warning("_hf_hook is not None, force skip pattern check!")
             self.check_forward_pattern = False
@@ -208,7 +249,10 @@ class BlockAdapter:
         if self.transformer is not None:
             self.patch_functor.apply(self.transformer, *args, **kwargs)
         else:
-            assert hasattr(self.pipe, "transformer")
+            assert hasattr(self.pipe, "transformer"), (
+                "pipe.transformer can not be None when patch_functor "
+                "is provided and transformer is None."
+            )
             self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
 
     @staticmethod
@@ -224,6 +268,10 @@ class BlockAdapter:
             adapter.forward_pattern is not None
         ), "adapter.forward_pattern can not be None."
         pipe = adapter.pipe
+        if isinstance(pipe, FakeDiffusionPipeline):
+            raise ValueError(
+                "Can not auto block adapter for FakeDiffusionPipeline."
+            )
 
         assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
 
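Two behaviours introduced above are worth spelling out: __post_init__ now substitutes a FakeDiffusionPipeline when pipe is None, and an unset check_forward_pattern is resolved from the transformer's module path. A minimal sketch of the module-path rule (MyBlock is a made-up class; FluxTransformer2DModel is used only for its __module__ string):

import torch
from diffusers import FluxTransformer2DModel

class MyBlock(torch.nn.Module):
    pass

# Non-diffusers transformer -> check_forward_pattern resolves to False.
print(MyBlock.__module__.startswith("diffusers"))                  # False
# diffusers transformer -> check_forward_pattern resolves to True.
print(FluxTransformer2DModel.__module__.startswith("diffusers"))   # True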
cache_dit/cache_factory/block_adapters/block_registers.py
CHANGED

@@ -1,7 +1,11 @@
-
+import torch
+from typing import Any, Tuple, List, Dict, Callable, Union
 
 from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters.block_adapters import (
+    BlockAdapter,
+    FakeDiffusionPipeline,
+)
 
 from cache_dit.logger import init_logger
 
@@ -35,24 +39,42 @@ class BlockAdapterRegistry:
     @classmethod
     def get_adapter(
         cls,
-
+        pipe_or_module: DiffusionPipeline | torch.nn.Module | str | Any,
         **kwargs,
     ) -> BlockAdapter | None:
-        if not isinstance(
-
+        if not isinstance(pipe_or_module, str):
+            cls_name: str = pipe_or_module.__class__.__name__
         else:
-
+            cls_name = pipe_or_module
 
         for name in cls._adapters:
-            if
-
+            if cls_name.startswith(name):
+                if not isinstance(pipe_or_module, DiffusionPipeline):
+                    assert isinstance(pipe_or_module, torch.nn.Module)
+                    # NOTE: Make pre-registered adapters support Transformer-only case.
+                    # WARN: This branch is not officially supported and only for testing
+                    # purpose. We construct a fake diffusion pipeline that contains the
+                    # given transformer module. Currently, only works for DiT models which
+                    # only have one transformer module. Case like multiple transformers
+                    # is not supported, e.g, Wan2.2. Please use BlockAdapter directly for
+                    # such cases.
+                    return cls._adapters[name](
+                        FakeDiffusionPipeline(pipe_or_module), **kwargs
+                    )
+                else:
+                    return cls._adapters[name](pipe_or_module, **kwargs)
 
         return None
 
     @classmethod
     def has_separate_cfg(
         cls,
-        pipe_or_adapter:
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            FakeDiffusionPipeline,
+            BlockAdapter,
+            Any,
+        ],
     ) -> bool:
 
         # Prefer custom setting from block adapter.
@@ -60,11 +82,16 @@ class BlockAdapterRegistry:
             return pipe_or_adapter.has_separate_cfg
 
         has_separate_cfg = False
+        if isinstance(pipe_or_adapter, FakeDiffusionPipeline):
+            return False
+
         if isinstance(pipe_or_adapter, DiffusionPipeline):
-
+            adapter = cls.get_adapter(
                 pipe_or_adapter,
                 skip_post_init=True,  # check cfg setting only
-            )
+            )
+            if adapter is not None:
+                has_separate_cfg = adapter.has_separate_cfg
 
         if has_separate_cfg:
             return True
@@ -77,11 +104,11 @@ class BlockAdapterRegistry:
         return False
 
     @classmethod
-    def is_supported(cls,
-
+    def is_supported(cls, pipe_or_module) -> bool:
+        cls_name: str = pipe_or_module.__class__.__name__
 
         for name in cls._adapters:
-            if
+            if cls_name.startswith(name):
                 return True
         return False
 
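Per the NOTE/WARN comments in the hunk above, get_adapter() can now be handed a bare transformer module; it wraps the module in a FakeDiffusionPipeline before invoking the registered factory. A hedged sketch of that testing-only path follows; the tiny FluxTransformer2DModel config and the assumption that the Flux registry key is a prefix of the transformer class name are mine, not documented behaviour of the wheel:

from diffusers import FluxTransformer2DModel
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry

# Deliberately tiny config so the example constructs quickly; real use
# would load the full model via from_pretrained(...).
transformer = FluxTransformer2DModel(
    num_layers=1,
    num_single_layers=1,
    num_attention_heads=2,
    attention_head_dim=8,
)

# Matches via cls_name.startswith(name); the module is wrapped in a
# FakeDiffusionPipeline internally before the registered adapter is built.
adapter = BlockAdapterRegistry.get_adapter(transformer)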
cache_dit/cache_factory/cache_adapters/cache_adapter.py
CHANGED

@@ -5,10 +5,11 @@ import functools
 from contextlib import ExitStack
 from typing import Dict, List, Tuple, Any, Union, Callable, Optional
 
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, ModelMixin
 
 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters import FakeDiffusionPipeline
 from cache_dit.cache_factory.block_adapters import ParamsModifier
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 from cache_dit.cache_factory.cache_contexts import ContextManager
@@ -32,6 +33,9 @@ class CachedAdapter:
         pipe_or_adapter: Union[
             DiffusionPipeline,
             BlockAdapter,
+            # Transformer-only
+            torch.nn.Module,
+            ModelMixin,
         ],
         **context_kwargs,
     ) -> Union[
@@ -42,7 +46,9 @@ class CachedAdapter:
             pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"
 
-        if isinstance(
+        if isinstance(
+            pipe_or_adapter, (DiffusionPipeline, torch.nn.Module, ModelMixin)
+        ):
             if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                 logger.info(
                     f"{pipe_or_adapter.__class__.__name__} is officially "
@@ -62,10 +68,12 @@ class CachedAdapter:
             ):
                 block_adapter.params_modifiers = params_modifiers
 
-
-
-
-
+            block_adapter = cls.cachify(block_adapter, **context_kwargs)
+            if isinstance(pipe_or_adapter, DiffusionPipeline):
+                return block_adapter.pipe
+
+            return block_adapter.transformer
+
         else:
             raise ValueError(
                 f"{pipe_or_adapter.__class__.__name__} is not officially supported "
@@ -182,8 +190,6 @@ class CachedAdapter:
         context_kwargs = cls.check_context_kwargs(
             block_adapter, **context_kwargs
         )
-        # Apply cache on pipeline: wrap cache context
-        pipe_cls_name = block_adapter.pipe.__class__.__name__
 
         # Each Pipeline should have it's own context manager instance.
         # Different transformers (Wan2.2, etc) should shared the same
@@ -193,38 +199,58 @@ class CachedAdapter:
             "cache_config", None
         )
         assert cache_config is not None, "cache_config can not be None."
+        # Apply cache on pipeline: wrap cache context
+        pipe_cls_name = block_adapter.pipe.__class__.__name__
         context_manager = ContextManager(
             name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
             cache_type=cache_config.cache_type,
+            # Force use persistent_context for FakeDiffusionPipeline
+            persistent_context=isinstance(
+                block_adapter.pipe, FakeDiffusionPipeline
+            ),
         )
-        block_adapter.pipe._context_manager = context_manager  # instance level
-
         flatten_contexts, contexts_kwargs = cls.modify_context_params(
             block_adapter, **context_kwargs
         )
-        original_call = block_adapter.pipe.__class__.__call__
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+        block_adapter.pipe._context_manager = context_manager  # instance level
+
+        if not context_manager.persistent_context:
+
+            original_call = block_adapter.pipe.__class__.__call__
+
+            @functools.wraps(original_call)
+            def new_call(self, *args, **kwargs):
+                with ExitStack() as stack:
+                    # cache context will be reset for each pipe inference
+                    for context_name, context_kwargs in zip(
+                        flatten_contexts, contexts_kwargs
+                    ):
+                        stack.enter_context(
+                            context_manager.enter_context(
+                                context_manager.reset_context(
+                                    context_name,
+                                    **context_kwargs,
+                                ),
+                            )
                         )
-            )
-
-
-
+                    outputs = original_call(self, *args, **kwargs)
+                    cls.apply_stats_hooks(block_adapter)
+                    return outputs
+
+            block_adapter.pipe.__class__.__call__ = new_call
+            block_adapter.pipe.__class__._original_call = original_call
+
+        else:
+            # Init persistent cache context for transformer
+            for context_name, context_kwargs in zip(
+                flatten_contexts, contexts_kwargs
+            ):
+                context_manager.reset_context(
+                    context_name,
+                    **context_kwargs,
+                )
 
-        block_adapter.pipe.__class__.__call__ = new_call
-        block_adapter.pipe.__class__._original_call = original_call
         block_adapter.pipe.__class__._is_cached = True
 
         cls.apply_params_hooks(block_adapter, contexts_kwargs)
@@ -353,6 +379,7 @@ class CachedAdapter:
             blocks_name,
             unique_blocks_name,
             dummy_blocks_names,
+            block_adapter,
         )
 
         return block_adapter.transformer
@@ -365,6 +392,7 @@ class CachedAdapter:
         blocks_name: List[str],
         unique_blocks_name: List[str],
         dummy_blocks_names: List[str],
+        block_adapter: BlockAdapter,
     ) -> torch.nn.Module:
         dummy_blocks = torch.nn.ModuleList()
 
@@ -391,6 +419,8 @@ class CachedAdapter:
         # re-apply hooks to transformer after cache applied.
        # from diffusers.hooks.hooks import HookFunctionReference, HookRegistry
        # from diffusers.hooks.group_offloading import apply_group_offloading
+        context_manager: ContextManager = block_adapter.pipe._context_manager
+        assert isinstance(context_manager, ContextManager._supported_managers)
 
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
@@ -410,6 +440,13 @@ class CachedAdapter:
                        )
                    )
                 outputs = original_forward(*args, **kwargs)
+
+                if (
+                    context_manager.persistent_context
+                    and context_manager.is_pre_refreshed()
+                ):
+                    cls.apply_stats_hooks(block_adapter)
+
                 return outputs
 
         def new_forward_with_hf_hook(self, *args, **kwargs):
@@ -513,6 +550,7 @@ class CachedAdapter:
         params_shift += len(blocks)
 
     @classmethod
+    @torch.compiler.disable
     def apply_stats_hooks(
         cls,
         block_adapter: BlockAdapter,