cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
4
|
+
from diffusers.models.transformers.dit_transformer_2d import (
|
|
5
|
+
DiTTransformer2DModel,
|
|
6
|
+
)
|
|
7
|
+
from diffusers.models.attention_processor import (
|
|
8
|
+
Attention,
|
|
9
|
+
AttnProcessor2_0,
|
|
10
|
+
) # sdpa
|
|
11
|
+
|
|
12
|
+
try:
    from diffusers.models._modeling_parallel import (
        ContextParallelInput,
        ContextParallelOutput,
        ContextParallelModelPlan,
    )
except ImportError as e:
    # Re-raise with install guidance, chaining the original error so the
    # underlying cause (missing module vs. missing name) stays visible.
    raise ImportError(
        "Context parallelism requires 'diffusers>=0.36.dev0'. "
        "Please install the latest version of diffusers from source: \n"
        "pip3 install git+https://github.com/huggingface/diffusers.git"
    ) from e
from .cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)
from .cp_plan_pixart import (
    __patch_AttnProcessor2_0__call__,
    __patch_Attention_prepare_attention_mask__,
)


from cache_dit.logger import init_logger

logger = init_logger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@ContextParallelismPlannerRegister.register("DiT")
class DiTContextParallelismPlanner(ContextParallelismPlanner):
    """Context-parallelism (CP) planner for ``DiTTransformer2DModel``.

    DiT always uses the custom plan built in :meth:`apply`; the
    native-diffusers ``_cp_plan`` preference is forced off for this model.
    """

    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        """Patch the attention classes for CP and return the DiT CP plan.

        Args:
            transformer: The model to plan for; must be a
                ``DiTTransformer2DModel``.
            **kwargs: Ignored by this planner.

        Returns:
            The custom CP plan: a dict keyed by submodule path
            (``"transformer_blocks.0"``, ``"proj_out_2"``) whose values are
            the split/gather specs to apply there.

        Raises:
            AssertionError: If ``transformer`` is ``None`` or is not a
                ``DiTTransformer2DModel``.
        """
        assert transformer is not None, "Transformer must be provided."
        assert isinstance(
            transformer, DiTTransformer2DModel
        ), "Transformer must be an instance of DiTTransformer2DModel"

        # The native-diffusers preference is forced off unconditionally, so
        # DiT never returns ``transformer._cp_plan``. The flag is still
        # written on ``self`` so the planner's observable state is unchanged.
        self._cp_planner_preferred_native_diffusers = False

        # Monkey-patch at CLASS level: this affects every ``Attention`` and
        # ``AttnProcessor2_0`` instance in the process, not just this model.
        # Re-running it on later calls re-assigns the same functions.
        Attention.prepare_attention_mask = (
            __patch_Attention_prepare_attention_mask__
        )
        AttnProcessor2_0.__call__ = __patch_AttnProcessor2_0__call__
        # Give class-level defaults without clobbering existing values.
        # NOTE(review): presumably read by the patched ``__call__`` from
        # cp_plan_pixart — confirm there.
        if not hasattr(AttnProcessor2_0, "_parallel_config"):
            AttnProcessor2_0._parallel_config = None
        if not hasattr(AttnProcessor2_0, "_attention_backend"):
            AttnProcessor2_0._attention_backend = None

        # Custom CP plan; this may differ slightly from the native diffusers
        # implementation for some models.
        return {
            # transformer_blocks.0, split_output=False:
            #   un-split input -> split -> to_qkv/...
            #   -> all2all -> attn (local heads, full seqlen) -> all2all
            #   -> split output
            # Only ``hidden_states`` is split here, not
            # ``encoder_hidden_states``.
            "transformer_blocks.0": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            # proj_out_2 gathers the still-split output along dim 1:
            #   split input -> all gather -> un-split output
            "proj_out_2": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
4
|
+
from diffusers import FluxTransformer2DModel
|
|
5
|
+
|
|
6
|
+
try:
    from diffusers.models._modeling_parallel import (
        ContextParallelInput,
        ContextParallelOutput,
        ContextParallelModelPlan,
    )
except ImportError as e:
    # Re-raise with install guidance, chaining the original error so the
    # underlying cause (missing module vs. missing name) stays visible.
    raise ImportError(
        "Context parallelism requires 'diffusers>=0.36.dev0'. "
        "Please install the latest version of diffusers from source: \n"
        "pip3 install git+https://github.com/huggingface/diffusers.git"
    ) from e
from .cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)

from cache_dit.logger import init_logger

logger = init_logger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@ContextParallelismPlannerRegister.register("Flux")
class FluxContextParallelismPlanner(ContextParallelismPlanner):
    """Context-parallelism (CP) planner for ``FluxTransformer2DModel``."""

    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        """Return the CP plan for a Flux transformer.

        If a transformer is given and this planner prefers the native
        diffusers plan, the transformer's own ``_cp_plan`` is returned when it
        is set. In every other case the custom, transformer-level plan built
        below is returned.

        Args:
            transformer: The model to plan for (optional).
            **kwargs: Ignored by this planner.

        Returns:
            A dict keyed by submodule path (``""`` is the transformer itself,
            ``"proj_out"`` the final projection) mapping to split/gather specs.

        Raises:
            AssertionError: If the native plan is preferred and
                ``transformer`` is not a ``FluxTransformer2DModel``.
        """
        prefer_native = (
            transformer is not None
            and self._cp_planner_preferred_native_diffusers
        )
        if prefer_native:
            assert isinstance(
                transformer, FluxTransformer2DModel
            ), "Transformer must be an instance of FluxTransformer2DModel"
            native_plan = getattr(transformer, "_cp_plan", None)
            if native_plan is not None:
                return native_plan

        # Transformer-level plan: a single split hook on the transformer's
        # forward, then one gather after ``proj_out``.
        #
        # Inside the transformer_blocks / single_transformer_blocks, each
        # block sees already-split inputs:
        #   to_qkv/... -> all2all -> attn (local heads, full seqlen)
        #   -> all2all -> split output
        # so hidden_states / encoder_hidden_states stay split across all
        # blocks. img_ids / txt_ids are split once here and stay split; the
        # all2all only acts on the attention output, never on the ids.
        seq_split = dict(split_dim=1, expected_dims=3, split_output=False)
        ids_split = dict(split_dim=0, expected_dims=2, split_output=False)
        entry_inputs = {
            arg_name: ContextParallelInput(**spec)
            for arg_name, spec in (
                ("hidden_states", seq_split),
                ("encoder_hidden_states", seq_split),
                ("img_ids", ids_split),
                ("txt_ids", ids_split),
            )
        }

        # proj_out gathers the split output along dim 1:
        #   split input -> all gather -> un-split output
        return {
            "": entry_inputs,
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
|