cache-dit 0.2.30__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +14 -9
- cache_dit/cache_factory/block_adapters/block_adapters.py +19 -0
- cache_dit/cache_factory/cache_adapters.py +1 -0
- cache_dit/cache_factory/cache_blocks/__init__.py +3 -0
- cache_dit/cache_factory/cache_blocks/pattern_base.py +13 -0
- cache_dit/cache_factory/patch_functors/__init__.py +6 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -2
- cache_dit/cache_factory/patch_functors/functor_flux.py +3 -2
- cache_dit/cache_factory/patch_functors/functor_hidream.py +412 -0
- cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py +213 -0
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/METADATA +15 -9
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/RECORD +17 -15
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.30'
-__version_tuple__ = version_tuple = (0, 2, 30)
+__version__ = version = '0.2.31'
+__version_tuple__ = version_tuple = (0, 2, 31)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/block_adapters/__init__.py
CHANGED
@@ -501,7 +501,7 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("HiDream"
+@BlockAdapterRegistry.register("HiDream")
 def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
     # NOTE: Need to patch Transformer forward to fully support
     # double_stream_blocks and single_stream_blocks, namely, need
@@ -509,29 +509,32 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
     # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
     # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
     from diffusers import HiDreamImageTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
 
     assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
     return BlockAdapter(
         pipe=pipe,
         transformer=pipe.transformer,
         blocks=[
-
+            pipe.transformer.double_stream_blocks,
             pipe.transformer.single_stream_blocks,
         ],
         forward_pattern=[
-
+            ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_3,
         ],
-
-
+        patch_functor=HiDreamPatchFunctor(),
+        # NOTE: The type hint in diffusers is wrong
+        check_forward_pattern=True,
+        check_num_outputs=True,
         **kwargs,
     )
 
 
-@BlockAdapterRegistry.register("HunyuanDiT"
+@BlockAdapterRegistry.register("HunyuanDiT")
 def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
-    # TODO: Patch Transformer forward
     from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
+    from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
 
     assert isinstance(
         pipe.transformer,
@@ -542,14 +545,15 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=HunyuanDiTPatchFunctor(),
         **kwargs,
     )
 
 
-@BlockAdapterRegistry.register("HunyuanDiTPAG"
+@BlockAdapterRegistry.register("HunyuanDiTPAG")
 def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
-    # TODO: Patch Transformer forward
     from diffusers import HunyuanDiT2DModel
+    from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
 
     assert isinstance(pipe.transformer, HunyuanDiT2DModel)
     return BlockAdapter(
@@ -557,5 +561,6 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=HunyuanDiTPatchFunctor(),
         **kwargs,
    )
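With HiDream now registered through a `BlockAdapter` plus `HiDreamPatchFunctor`, the usual one-line unified API path should cover HiDream pipelines. A minimal hedged sketch follows; only `cache_dit.enable_cache` and the adapter/functor names are taken from this release, while the checkpoint id, dtype, and prompt are illustrative assumptions (HiDream checkpoints may also require extra text-encoder components that are omitted here).

```python
# Hedged sketch: enabling cache-dit on a HiDream pipeline after this release.
# Only cache_dit.enable_cache and the "HiDream" adapter are taken from the
# diff; the checkpoint id, dtype and prompt below are assumptions.
import torch
from diffusers import HiDreamImagePipeline

import cache_dit

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

# The registered "HiDream" BlockAdapter applies HiDreamPatchFunctor, so both
# double_stream_blocks and single_stream_blocks become cacheable.
cache_dit.enable_cache(pipe)

image = pipe("a cabin on a snowy mountain at dusk").images[0]
image.save("hidream_cached.png")
```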
cache_dit/cache_factory/block_adapters/block_adapters.py
CHANGED
@@ -75,6 +75,7 @@ class BlockAdapter:
         List[List[ParamsModifier]],
     ] = None
 
+    check_forward_pattern: bool = True
     check_num_outputs: bool = False
 
     # Pipeline Level Flags
@@ -391,11 +392,20 @@ class BlockAdapter:
         forward_pattern: ForwardPattern,
         **kwargs,
     ) -> bool:
+
+        if not kwargs.get("check_forward_pattern", True):
+            return True
+
         assert (
             forward_pattern.Supported
             and forward_pattern in ForwardPattern.supported_patterns()
         ), f"Pattern {forward_pattern} is not support now!"
 
+        # NOTE: Special case for HiDreamBlock
+        if hasattr(block, "block"):
+            if isinstance(block.block, torch.nn.Module):
+                block = block.block
+
         forward_parameters = set(
             inspect.signature(block.forward).parameters.keys()
         )
@@ -425,6 +435,14 @@ class BlockAdapter:
         logging: bool = True,
         **kwargs,
     ) -> bool:
+
+        if not kwargs.get("check_forward_pattern", True):
+            if logging:
+                logger.warning(
+                    f"Skipped Forward Pattern Check: {forward_pattern}"
+                )
+            return True
+
         assert (
             forward_pattern.Supported
             and forward_pattern in ForwardPattern.supported_patterns()
@@ -531,6 +549,7 @@ class BlockAdapter:
                 blocks,
                 forward_pattern=forward_pattern,
                 check_num_outputs=adapter.check_num_outputs,
+                check_forward_pattern=adapter.check_forward_pattern,
             ), (
                 "No block forward pattern matched, "
                 f"supported lists: {ForwardPattern.supported_patterns()}"
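The new `check_forward_pattern` flag threads from `BlockAdapter` down to the cached blocks, so the strict forward-signature assertion can be turned into a logged warning. Below is a hedged sketch of how a custom adapter might opt out; `make_lenient_adapter` is a hypothetical helper and the import paths are inferred from this wheel's file layout.

```python
# Hedged sketch: opting out of the forward-pattern signature check via the
# new check_forward_pattern flag. make_lenient_adapter is a hypothetical
# helper; import paths are inferred from this wheel's file layout.
from cache_dit.cache_factory.block_adapters import BlockAdapter
from cache_dit.cache_factory.forward_pattern import ForwardPattern


def make_lenient_adapter(pipe) -> BlockAdapter:
    """Wrap a pipeline whose blocks' forward() signatures diverge from the
    built-in patterns; the check is skipped and a warning is logged instead."""
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=False,  # skip the assert, log a warning
        check_num_outputs=False,
    )
```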
cache_dit/cache_factory/cache_adapters.py
CHANGED
@@ -345,6 +345,7 @@ class CachedAdapter:
                     block_adapter.blocks[i][j],
                     transformer=block_adapter.transformer[i],
                     forward_pattern=block_adapter.forward_pattern[i][j],
+                    check_forward_pattern=block_adapter.check_forward_pattern,
                     check_num_outputs=block_adapter.check_num_outputs,
                     # 1. Cache context configuration
                     cache_prefix=block_adapter.blocks_name[i][j],
cache_dit/cache_factory/cache_blocks/__init__.py
CHANGED
@@ -25,6 +25,7 @@ class CachedBlocks:
         transformer_blocks: torch.nn.ModuleList,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = None,
+        check_forward_pattern: bool = True,
         check_num_outputs: bool = True,
         # 1. Cache context configuration
         # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
@@ -45,6 +46,7 @@ class CachedBlocks:
                 transformer_blocks,
                 transformer=transformer,
                 forward_pattern=forward_pattern,
+                check_forward_pattern=check_forward_pattern,
                 check_num_outputs=check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=cache_prefix,
@@ -58,6 +60,7 @@ class CachedBlocks:
                 transformer_blocks,
                 transformer=transformer,
                 forward_pattern=forward_pattern,
+                check_forward_pattern=check_forward_pattern,
                 check_num_outputs=check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=cache_prefix,
cache_dit/cache_factory/cache_blocks/pattern_base.py
CHANGED
@@ -25,6 +25,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         transformer_blocks: torch.nn.ModuleList,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
         check_num_outputs: bool = True,
         # 1. Cache context configuration
         cache_prefix: str = None,  # maybe un-need.
@@ -38,6 +39,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
         self.forward_pattern = forward_pattern
+        self.check_forward_pattern = check_forward_pattern
         self.check_num_outputs = check_num_outputs
         # 1. Cache context configuration
         self.cache_prefix = cache_prefix
@@ -52,6 +54,12 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         )
 
     def _check_forward_pattern(self):
+        if not self.check_forward_pattern:
+            logger.warning(
+                f"Skipped Forward Pattern Check: {self.forward_pattern}"
+            )
+            return
+
         assert (
             self.forward_pattern.Supported
             and self.forward_pattern in self._supported_patterns
@@ -59,6 +67,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
 
         if self.transformer_blocks is not None:
             for block in self.transformer_blocks:
+                # Special case for HiDreamBlock
+                if hasattr(block, "block"):
+                    if isinstance(block.block, torch.nn.Module):
+                        block = block.block
+
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
cache_dit/cache_factory/patch_functors/__init__.py
CHANGED
@@ -3,3 +3,9 @@ from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_chroma import (
     ChromaPatchFunctor,
 )
+from cache_dit.cache_factory.patch_functors.functor_hidream import (
+    HiDreamPatchFunctor,
+)
+from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
+    HunyuanDiTPatchFunctor,
+)
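The functors are exported at the package level and can also be applied to a bare transformer, which is what the block adapters do internally. A hedged sketch is shown below; `HiDreamPatchFunctor.apply` comes from this release, while the checkpoint id and the `subfolder` argument are assumptions of the example.

```python
# Hedged sketch: applying a patch functor directly to a transformer instance.
# HiDreamPatchFunctor.apply is part of this release; loading the transformer
# from "HiDream-ai/HiDream-I1-Full" (subfolder="transformer") is an assumption.
from diffusers import HiDreamImageTransformer2DModel

from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full", subfolder="transformer"
)

# Rebinds HiDreamBlock.forward plus the inner double/single block forwards and
# the transformer forward, so llama31_encoder_hidden_states is passed per block.
transformer = HiDreamPatchFunctor().apply(transformer)
assert getattr(transformer, "_is_patched", False)
```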
cache_dit/cache_factory/patch_functors/functor_chroma.py
CHANGED
@@ -46,8 +46,10 @@ class ChromaPatchFunctor(PatchFunctor):
             block.forward = __patch_single_forward__.__get__(block)
             is_patched = True
 
+        cls_name = transformer.__class__.__name__
+
         if is_patched:
-            logger.warning("Patched
+            logger.warning(f"Patched {cls_name} for cache-dit.")
             assert not getattr(transformer, "_is_parallelized", False), (
                 "Please call `cache_dit.enable_cache` before Parallelize, "
                 "the __patch_transformer_forward__ will overwrite the "
@@ -59,7 +61,6 @@ class ChromaPatchFunctor(PatchFunctor):
 
         transformer._is_patched = is_patched  # True or False
 
-        cls_name = transformer.__class__.__name__
         logger.info(
             f"Applied {self.__class__.__name__} for {cls_name}, "
             f"Patch: {is_patched}."
cache_dit/cache_factory/patch_functors/functor_flux.py
CHANGED
@@ -47,8 +47,10 @@ class FluxPatchFunctor(PatchFunctor):
             block.forward = __patch_single_forward__.__get__(block)
             is_patched = True
 
+        cls_name = transformer.__class__.__name__
+
         if is_patched:
-            logger.warning("Patched
+            logger.warning(f"Patched {cls_name} for cache-dit.")
             assert not getattr(transformer, "_is_parallelized", False), (
                 "Please call `cache_dit.enable_cache` before Parallelize, "
                 "the __patch_transformer_forward__ will overwrite the "
@@ -60,7 +62,6 @@ class FluxPatchFunctor(PatchFunctor):
 
         transformer._is_patched = is_patched  # True or False
 
-        cls_name = transformer.__class__.__name__
         logger.info(
             f"Applied {self.__class__.__name__} for {cls_name}, "
             f"Patch: {is_patched}."
cache_dit/cache_factory/patch_functors/functor_hidream.py
ADDED
@@ -0,0 +1,412 @@
+import torch
+from typing import Tuple, Optional, Dict, Any, Union, List
+from diffusers import HiDreamImageTransformer2DModel
+from diffusers.models.transformers.transformer_hidream_image import (
+    HiDreamBlock,
+    HiDreamImageTransformerBlock,
+    HiDreamImageSingleTransformerBlock,
+    Transformer2DModelOutput,
+)
+from diffusers.utils import (
+    deprecate,
+    USE_PEFT_BACKEND,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class HiDreamPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: HiDreamImageTransformer2DModel,
+        **kwargs,
+    ) -> HiDreamImageTransformer2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        _block_id = 0
+        for block in transformer.double_stream_blocks:
+            assert isinstance(block, HiDreamBlock)
+            block.forward = __patch_block_forward__.__get__(block)
+            # NOTE: Patch Inner block and block_id
+            _block = block.block
+            assert isinstance(_block, HiDreamImageTransformerBlock)
+            _block._block_id = _block_id
+            _block.forward = __patch_double_forward__.__get__(_block)
+            _block_id += 1
+
+        for block in transformer.single_stream_blocks:
+            assert isinstance(block, HiDreamBlock)
+            block.forward = __patch_block_forward__.__get__(block)
+            # NOTE: Patch Inner block and block_id
+            _block = block.block
+            assert isinstance(_block, HiDreamImageSingleTransformerBlock)
+            _block._block_id = _block_id
+            _block.forward = __patch_single_forward__.__get__(_block)
+            _block_id += 1
+
+        is_patched = True
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_double_forward__(
+    self: HiDreamImageTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,  # initial_encoder_hidden_states
+    hidden_states_masks: Optional[torch.Tensor] = None,
+    temb: Optional[torch.Tensor] = None,
+    image_rotary_emb: torch.Tensor = None,
+    llama31_encoder_hidden_states: List[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Assume block_id was patched in transformer forward:
+    # for i, block in enumerate(blocks): block._block_id = i;
+    block_id = self._block_id
+    initial_encoder_hidden_states_seq_len = encoder_hidden_states.shape[1]
+    cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
+    cur_encoder_hidden_states = torch.cat(
+        [encoder_hidden_states, cur_llama31_encoder_hidden_states],
+        dim=1,
+    )
+    encoder_hidden_states = cur_encoder_hidden_states
+
+    wtype = hidden_states.dtype
+    (
+        shift_msa_i,
+        scale_msa_i,
+        gate_msa_i,
+        shift_mlp_i,
+        scale_mlp_i,
+        gate_mlp_i,
+        shift_msa_t,
+        scale_msa_t,
+        gate_msa_t,
+        shift_mlp_t,
+        scale_mlp_t,
+        gate_mlp_t,
+    ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
+
+    # 1. MM-Attention
+    norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+    norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(
+        dtype=wtype
+    )
+    norm_encoder_hidden_states = (
+        norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
+    )
+
+    attn_output_i, attn_output_t = self.attn1(
+        norm_hidden_states,
+        hidden_states_masks,
+        norm_encoder_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+    )
+
+    hidden_states = gate_msa_i * attn_output_i + hidden_states
+    encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
+
+    # 2. Feed-forward
+    norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+    norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(
+        dtype=wtype
+    )
+    norm_encoder_hidden_states = (
+        norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
+    )
+
+    ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+    ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
+    hidden_states = ff_output_i + hidden_states
+    encoder_hidden_states = ff_output_t + encoder_hidden_states
+
+    initial_encoder_hidden_states = encoder_hidden_states[
+        :, :initial_encoder_hidden_states_seq_len
+    ]
+
+    return hidden_states, initial_encoder_hidden_states
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_single_forward__(
+    self: HiDreamImageSingleTransformerBlock,
+    hidden_states: torch.Tensor,
+    hidden_states_masks: Optional[torch.Tensor] = None,
+    temb: Optional[torch.Tensor] = None,
+    image_rotary_emb: torch.Tensor = None,
+    llama31_encoder_hidden_states: List[torch.Tensor] = None,
+) -> torch.Tensor:
+    # Assume block_id was patched in transformer forward:
+    # for i, block in enumerate(blocks): block._block_id = i;
+    block_id = self._block_id
+    hidden_states_seq_len = hidden_states.shape[1]
+    cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
+    hidden_states = torch.cat(
+        [hidden_states, cur_llama31_encoder_hidden_states], dim=1
+    )
+
+    wtype = hidden_states.dtype
+    (
+        shift_msa_i,
+        scale_msa_i,
+        gate_msa_i,
+        shift_mlp_i,
+        scale_mlp_i,
+        gate_mlp_i,
+    ) = self.adaLN_modulation(temb)[:, None].chunk(6, dim=-1)
+
+    # 1. MM-Attention
+    norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+    attn_output_i = self.attn1(
+        norm_hidden_states,
+        hidden_states_masks,
+        image_rotary_emb=image_rotary_emb,
+    )
+    hidden_states = gate_msa_i * attn_output_i + hidden_states
+
+    # 2. Feed-forward
+    norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+    ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+    hidden_states = ff_output_i + hidden_states
+
+    hidden_states = hidden_states[:, :hidden_states_seq_len]
+
+    return hidden_states
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_block_forward__(
+    self: HiDreamBlock,
+    hidden_states: torch.Tensor,
+    *args,
+    **kwargs,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    return self.block(hidden_states, *args, **kwargs)
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_transformer_forward__(
+    self: HiDreamImageTransformer2DModel,
+    hidden_states: torch.Tensor,
+    timesteps: torch.LongTensor = None,
+    encoder_hidden_states_t5: torch.Tensor = None,
+    encoder_hidden_states_llama3: torch.Tensor = None,
+    pooled_embeds: torch.Tensor = None,
+    img_ids: Optional[torch.Tensor] = None,
+    img_sizes: Optional[List[Tuple[int, int]]] = None,
+    hidden_states_masks: Optional[torch.Tensor] = None,
+    attention_kwargs: Optional[Dict[str, Any]] = None,
+    return_dict: bool = True,
+    **kwargs,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+    if encoder_hidden_states is not None:
+        deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+        deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
+        encoder_hidden_states_t5 = encoder_hidden_states[0]
+        encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+    if (
+        img_ids is not None
+        and img_sizes is not None
+        and hidden_states_masks is None
+    ):
+        deprecation_message = "Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
+        deprecate("img_ids", "0.35.0", deprecation_message)
+
+    if hidden_states_masks is not None and (
+        img_ids is None or img_sizes is None
+    ):
+        raise ValueError(
+            "if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed."
+        )
+    elif hidden_states_masks is not None and hidden_states.ndim != 3:
+        raise ValueError(
+            "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+        )
+
+    if attention_kwargs is not None:
+        attention_kwargs = attention_kwargs.copy()
+        lora_scale = attention_kwargs.pop("scale", 1.0)
+    else:
+        lora_scale = 1.0
+
+    if USE_PEFT_BACKEND:
+        # weight the lora layers by setting `lora_scale` for each PEFT layer
+        scale_lora_layers(self, lora_scale)
+    else:
+        if (
+            attention_kwargs is not None
+            and attention_kwargs.get("scale", None) is not None
+        ):
+            logger.warning(
+                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+            )
+
+    # spatial forward
+    batch_size = hidden_states.shape[0]
+    hidden_states_type = hidden_states.dtype
+
+    # Patchify the input
+    if hidden_states_masks is None:
+        hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(
+            hidden_states
+        )
+
+    # Embed the hidden states
+    hidden_states = self.x_embedder(hidden_states)
+
+    # 0. time
+    timesteps = self.t_embedder(timesteps, hidden_states_type)
+    p_embedder = self.p_embedder(pooled_embeds)
+    temb = timesteps + p_embedder
+
+    encoder_hidden_states = [
+        encoder_hidden_states_llama3[k] for k in self.config.llama_layers
+    ]
+
+    if self.caption_projection is not None:
+        new_encoder_hidden_states = []
+        for i, enc_hidden_state in enumerate(encoder_hidden_states):
+            enc_hidden_state = self.caption_projection[i](enc_hidden_state)
+            enc_hidden_state = enc_hidden_state.view(
+                batch_size, -1, hidden_states.shape[-1]
+            )
+            new_encoder_hidden_states.append(enc_hidden_state)
+        encoder_hidden_states = new_encoder_hidden_states
+        encoder_hidden_states_t5 = self.caption_projection[-1](
+            encoder_hidden_states_t5
+        )
+        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
+            batch_size, -1, hidden_states.shape[-1]
+        )
+        encoder_hidden_states.append(encoder_hidden_states_t5)
+
+    txt_ids = torch.zeros(
+        batch_size,
+        encoder_hidden_states[-1].shape[1]
+        + encoder_hidden_states[-2].shape[1]
+        + encoder_hidden_states[0].shape[1],
+        3,
+        device=img_ids.device,
+        dtype=img_ids.dtype,
+    )
+    ids = torch.cat((img_ids, txt_ids), dim=1)
+    image_rotary_emb = self.pe_embedder(ids)
+
+    # 2. Blocks
+    # NOTE: block_id is no-need anymore.
+    initial_encoder_hidden_states = torch.cat(
+        [encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1
+    )
+    llama31_encoder_hidden_states = encoder_hidden_states
+    for bid, block in enumerate(self.double_stream_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states, initial_encoder_hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    initial_encoder_hidden_states,
+                    hidden_states_masks,
+                    temb,
+                    image_rotary_emb,
+                    llama31_encoder_hidden_states,
+                )
+            )
+        else:
+            hidden_states, initial_encoder_hidden_states = block(
+                hidden_states,
+                initial_encoder_hidden_states,  # encoder_hidden_states
+                hidden_states_masks=hidden_states_masks,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                llama31_encoder_hidden_states=llama31_encoder_hidden_states,
+            )
+
+    image_tokens_seq_len = hidden_states.shape[1]
+    hidden_states = torch.cat(
+        [hidden_states, initial_encoder_hidden_states], dim=1
+    )
+    if hidden_states_masks is not None:
+        # NOTE: Patched
+        cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[
+            self.double_stream_blocks[-1].block._block_id
+        ]
+        encoder_attention_mask_ones = torch.ones(
+            (
+                batch_size,
+                initial_encoder_hidden_states.shape[1]
+                + cur_llama31_encoder_hidden_states.shape[1],
+            ),
+            device=hidden_states_masks.device,
+            dtype=hidden_states_masks.dtype,
+        )
+        hidden_states_masks = torch.cat(
+            [hidden_states_masks, encoder_attention_mask_ones], dim=1
+        )
+
+    for bid, block in enumerate(self.single_stream_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states = self._gradient_checkpointing_func(
+                block,
+                hidden_states,
+                hidden_states_masks,
+                temb,
+                image_rotary_emb,
+                llama31_encoder_hidden_states,
+            )
+        else:
+            hidden_states = block(
+                hidden_states,
+                hidden_states_masks=hidden_states_masks,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                llama31_encoder_hidden_states=llama31_encoder_hidden_states,
+            )
+
+    hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
+    output = self.final_layer(hidden_states, temb)
+    output = self.unpatchify(output, img_sizes, self.training)
+    if hidden_states_masks is not None:
+        hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+    return Transformer2DModelOutput(sample=output)
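All of these functors rely on the same Python idiom: `function.__get__(instance)` produces a bound method, so assigning it to `instance.forward` overrides `forward` for that one module without touching its class. A standalone toy sketch of that mechanism (not cache-dit code) is shown below.

```python
# Toy sketch of the instance-level patching idiom used by the functors above:
# func.__get__(obj) returns a bound method, and torch.nn.Module.__call__ looks
# up self.forward, so the instance attribute shadows the class-level forward.
import torch


def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    # `self` is the specific module instance the function was bound to.
    return self.linear(x) * 2.0


class TinyBlock(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


block = TinyBlock()
block.forward = patched_forward.__get__(block)  # override only this instance
out = block(torch.randn(1, 4))  # dispatches to patched_forward
```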
cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py
ADDED
@@ -0,0 +1,213 @@
+import torch
+from typing import Optional, Union, List
+from diffusers import HunyuanDiT2DModel
+from diffusers.models.transformers.hunyuan_transformer_2d import (
+    HunyuanDiTBlock,
+    Transformer2DModelOutput,
+)
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class HunyuanDiTPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: HunyuanDiT2DModel,
+        **kwargs,
+    ) -> HunyuanDiT2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        num_layers = transformer.config.num_layers
+        layer_id = 0
+        for block in transformer.blocks:
+            assert isinstance(block, HunyuanDiTBlock)
+            block._num_layers = num_layers
+            block._layer_id = layer_id
+            block.forward = __patch_block_forward__.__get__(block)
+            layer_id += 1
+
+        is_patched = True
+
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+def __patch_block_forward__(
+    self: HunyuanDiTBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    temb: Optional[torch.Tensor] = None,
+    image_rotary_emb: torch.Tensor = None,
+    controlnet_block_samples: torch.Tensor = None,
+    skips: List[torch.Tensor] = [],
+) -> torch.Tensor:
+    # Notice that normalization is always applied before the real computation in the following blocks.
+    # 0. Long Skip Connection
+    num_layers = self._num_layers
+    layer_id = self._layer_id
+
+    if layer_id > num_layers // 2:
+        if controlnet_block_samples is not None:
+            skip = skips.pop() + controlnet_block_samples.pop()
+        else:
+            skip = skips.pop()
+    else:
+        skip = None
+
+    if self.skip_linear is not None:
+        cat = torch.cat([hidden_states, skip], dim=-1)
+        cat = self.skip_norm(cat)
+        hidden_states = self.skip_linear(cat)
+
+    # 1. Self-Attention
+    norm_hidden_states = self.norm1(
+        hidden_states, temb
+    )  # checked: self.norm1 is correct
+    attn_output = self.attn1(
+        norm_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+    )
+    hidden_states = hidden_states + attn_output
+
+    # 2. Cross-Attention
+    hidden_states = hidden_states + self.attn2(
+        self.norm2(hidden_states),
+        encoder_hidden_states=encoder_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+    )
+
+    # FFN Layer
+    mlp_inputs = self.norm3(hidden_states)
+    hidden_states = hidden_states + self.ff(mlp_inputs)
+
+    if layer_id < (num_layers // 2 - 1):
+        skips.append(hidden_states)
+
+    return hidden_states
+
+
+def __patch_transformer_forward__(
+    self: HunyuanDiT2DModel,
+    hidden_states,
+    timestep,
+    encoder_hidden_states=None,
+    text_embedding_mask=None,
+    encoder_hidden_states_t5=None,
+    text_embedding_mask_t5=None,
+    image_meta_size=None,
+    style=None,
+    image_rotary_emb=None,
+    controlnet_block_samples=None,
+    return_dict=True,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    height, width = hidden_states.shape[-2:]
+
+    hidden_states = self.pos_embed(hidden_states)
+
+    temb = self.time_extra_emb(
+        timestep,
+        encoder_hidden_states_t5,
+        image_meta_size,
+        style,
+        hidden_dtype=timestep.dtype,
+    )  # [B, D]
+
+    # text projection
+    batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
+    encoder_hidden_states_t5 = self.text_embedder(
+        encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
+    )
+    encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
+        batch_size, sequence_length, -1
+    )
+
+    encoder_hidden_states = torch.cat(
+        [encoder_hidden_states, encoder_hidden_states_t5], dim=1
+    )
+    text_embedding_mask = torch.cat(
+        [text_embedding_mask, text_embedding_mask_t5], dim=-1
+    )
+    text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
+
+    encoder_hidden_states = torch.where(
+        text_embedding_mask, encoder_hidden_states, self.text_embedding_padding
+    )
+
+    skips = []
+    for layer, block in enumerate(self.blocks):
+        hidden_states = block(
+            hidden_states,
+            temb=temb,
+            encoder_hidden_states=encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            controlnet_block_samples=controlnet_block_samples,
+            skips=skips,
+        )  # (N, L, D)
+
+    if (
+        controlnet_block_samples is not None
+        and len(controlnet_block_samples) != 0
+    ):
+        raise ValueError(
+            "The number of controls is not equal to the number of skip connections."
+        )
+
+    # final layer
+    hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
+    hidden_states = self.proj_out(hidden_states)
+    # (N, L, patch_size ** 2 * out_channels)
+
+    # unpatchify: (N, out_channels, H, W)
+    patch_size = self.pos_embed.patch_size
+    height = height // patch_size
+    width = width // patch_size
+
+    hidden_states = hidden_states.reshape(
+        shape=(
+            hidden_states.shape[0],
+            height,
+            width,
+            patch_size,
+            patch_size,
+            self.out_channels,
+        )
+    )
+    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+    output = hidden_states.reshape(
+        shape=(
+            hidden_states.shape[0],
+            self.out_channels,
+            height * patch_size,
+            width * patch_size,
+        )
+    )
+    if not return_dict:
+        return (output,)
+    return Transformer2DModelOutput(sample=output)
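With `HunyuanDiTPatchFunctor` wired into the `HunyuanDiT` and `HunyuanDiTPAG` adapters, the unified API path should now cover these pipelines as well. A hedged sketch; only `cache_dit.enable_cache` is taken from this release, and the checkpoint id and dtype are assumptions.

```python
# Hedged sketch: enabling cache-dit on HunyuanDiT after this release.
# Only cache_dit.enable_cache comes from the diff; checkpoint id and dtype
# are assumptions.
import torch
from diffusers import HunyuanDiTPipeline

import cache_dit

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",  # assumed checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

cache_dit.enable_cache(pipe)  # registry resolves the "HunyuanDiT" adapter

image = pipe(prompt="a small courtyard garden after rain").images[0]
image.save("hunyuandit_cached.png")
```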
{cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.30
+Version: 0.2.31
 Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -59,18 +59,21 @@ Dynamic: requires-python
 🔥<b><a href="#unified">Unified Cache APIs</a> | <a href="#dbcache">DBCache</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a></b>🔥
 </p>
 <p align="center">
-🎉Now, <b>cache-dit</b> covers <b>
-🔥<
+🎉Now, <b>cache-dit</b> covers <b>most</b> mainstream Diffusers' <b>DiT</b> Pipelines🎉<br>
+🔥<a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1</a> | <a href="#supported"> Wan 2.2 </a> | <a href="#supported">HunyuanVideo</a>🔥<br>
+🔥<a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">Mochi</a> | <a href="#supported"> CogVideoX </a> | <a href="#supported">CogVideoX1.5</a>🔥<br>
+🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">Chroma</a> | <a href="#supported"> LTXVideo </a> | <a href="#supported">PixArt</a>🔥<br>
+🔥<a href="#supported">Cosmos</a> | <a href="#supported">SkyReelsV2</a> | <a href="#supported">VisualCloze</a> | <a href="#supported"> ... </a> | <a href="#supported">Lumina2</a>🔥
 </p>
 </div>
 <div align='center'>
-<img src
-<img src
-<img src
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=160px>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q1_fp8_w8a8_dq_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
 <p><b>🔥Wan2.2 MoE</b> Baseline | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~2.0x↑🎉</b> | +FP8 DQ:<b>~2.4x↑🎉</b></p>
-<img src
-<img src
-<img src
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C0_Q0_NONE.png width=160px>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q1_fp8_w8a8_dq_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S18.png width=160px>
 <p><b>🔥Qwen-Image</b> Baseline | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~1.8x↑🎉</b> | +FP8 DQ:<b>~2.2x↑🎉</b><br>♥️ Please consider to leave a <b>⭐️ Star</b> to support us ~ ♥️</p>
 </p>
 </div>
@@ -141,7 +144,10 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 - [🚀Wan2.2-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀mochi-1-preview](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀HunyuanDiT](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀HiDream](https://github.com/vipshop/cache-dit/raw/main/examples)
 
 </details>
 
{cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/RECORD
CHANGED
@@ -1,30 +1,32 @@
 cache_dit/__init__.py,sha256=kX9V-FegZG4c8LMwI4PTmMqH794MEW0pzDArdhC0cJw,1241
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=cMx3p02rk8iaGjj6X7bw0aOcGW7d-iY_EBO9S_9o-b4,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
 cache_dit/utils.py,sha256=WK7eqgH6gCYNHXNLmWyxBDU0XSHTPg7CfOcyXlGXBqE,10510
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=Iw6-iJLFbdzCsIDZXXOw371L-HPmoeZO_P9a3sDjP5s,1103
-cache_dit/cache_factory/cache_adapters.py,sha256=
+cache_dit/cache_factory/cache_adapters.py,sha256=6YbBSfKEGdWi9oY1ceuxi-MpHcaDYoQ-t6NTaLZITR4,17938
 cache_dit/cache_factory/cache_interface.py,sha256=y1nY6R3MucRmAnG2UJRI_tIKrRk27FktGWLbfckf3zE,8543
 cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
 cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
 cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
-cache_dit/cache_factory/block_adapters/__init__.py,sha256=
-cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=
+cache_dit/cache_factory/block_adapters/__init__.py,sha256=x2ivShzOy2z3p1WUArzoChR4jaLHhNXkXMSk-RPzR3g,17534
+cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=EQBiJYyoInKU1ND69wTm7M0n5Ja4I8QW01SgRpBjSn8,21671
 cache_dit/cache_factory/block_adapters/block_registers.py,sha256=ZeN2wGPmuf2u3puSsBx8x-rl3wRo8-cWcuWNcrssVfA,2553
-cache_dit/cache_factory/cache_blocks/__init__.py,sha256=
+cache_dit/cache_factory/cache_blocks/__init__.py,sha256=08Ox7kD05lkRKCOsVTdEZeKAWBheqpxfrAT1Nz7eclI,2916
 cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
 cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=nf2f5wdxp6tfq9AhFyMyBeKiZfxh63WG1g8q-c2BBSg,10182
-cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=
+cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=_sajtb-Cz8yrCRBRSiJREzFG7h6265K9pXeAz5i1meY,20814
 cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
 cache_dit/cache_factory/cache_contexts/__init__.py,sha256=rqnJ5__zqnpVHK5A1OqWILpNh5Ss-0ZDTGgtxZMKGGo,250
 cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=N88WLdd4KE9DuMWmpX8URcF55E2zWNwcKMxgVYkxMJY,13691
 cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=_NUXcMYYEIVfDHpc4HJr9RUjU5RUEkZmAgFGE8bh5Wc,34883
 cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
-cache_dit/cache_factory/patch_functors/__init__.py,sha256=
+cache_dit/cache_factory/patch_functors/__init__.py,sha256=06zdddrjvSCgBzJ0a8niRHd3ucF2qsbzlbL00d4aCvk,451
 cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
-cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=
-cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=
+cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=2iLxlsc-1dDHRveqCXaC07E9CeMNOuBNkvpJ1atpK7E,10048
+cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=UMkyuEYjO7UO_zmXi9Djd-nD-XMgCUgE-qkYA3plWSM,9559
+cache_dit/cache_factory/patch_functors/functor_hidream.py,sha256=pi_vvpDy1lsgQHxu3eK3v93rdJL7oNwkt3WakRP8pbw,15375
+cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py,sha256=iSo5dD5uKnjQQeysDUIkKt0wdnK5bzXTc_F_lfHG70w,6401
 cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
 cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
 cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -39,9 +41,9 @@ cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70
 cache_dit/quantize/quantize_ao.py,sha256=mGspqYgQtenl3QnKPtsSYsSD7LbVX93f1M940bhXKLU,6066
 cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
 cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
+cache_dit-0.2.31.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.31.dist-info/METADATA,sha256=MrRvt7HL8pNm0ZsBxKO25pBcCJhHPG7HddwjT_euy_I,23198
+cache_dit-0.2.31.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.31.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.31.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.31.dist-info/RECORD,,
{cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/WHEEL
File without changes
{cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/entry_points.txt
File without changes
{cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/licenses/LICENSE
File without changes
{cache_dit-0.2.30.dist-info → cache_dit-0.2.31.dist-info}/top_level.txt
File without changes