cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
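
For orientation: most of the 104 entries above come from the cache_factory → caching rename, plus the new parallelism, quantization backend, and summary modules. A minimal usage sketch follows (not part of this diff); it assumes the top-level helpers re-exported from cache_dit/__init__.py keep the names documented in the project README (enable_cache, summary), and the model id and dtype are only illustrative.

# Minimal usage sketch; helper names assumed from the project README,
# model id and dtype chosen only for illustration.
import torch
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
cache_dit.enable_cache(pipe)      # enable the cached transformer forward

# ... run inference with `pipe` as usual ...

stats = cache_dit.summary(pipe)   # summary.py: report cache statistics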
Two of the new context-parallelism planner modules appear below in full.

cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py (new file)
@@ -0,0 +1,94 @@
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.dit_transformer_2d import (
+     DiTTransformer2DModel,
+ )
+ from diffusers.models.attention_processor import (
+     Attention,
+     AttnProcessor2_0,
+ )  # sdpa
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+ from .cp_plan_pixart import (
+     __patch_AttnProcessor2_0__call__,
+     __patch_Attention_prepare_attention_mask__,
+ )
+
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("DiT")
+ class DiTContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         assert transformer is not None, "Transformer must be provided."
+         assert isinstance(
+             transformer, DiTTransformer2DModel
+         ), "Transformer must be an instance of DiTTransformer2DModel"
+
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Apply monkey patches at class level to fix attention mask preparation.
+         Attention.prepare_attention_mask = (
+             __patch_Attention_prepare_attention_mask__
+         )
+         AttnProcessor2_0.__call__ = __patch_AttnProcessor2_0__call__
+         if not hasattr(AttnProcessor2_0, "_parallel_config"):
+             AttnProcessor2_0._parallel_config = None
+         if not hasattr(AttnProcessor2_0, "_attention_backend"):
+             AttnProcessor2_0._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here; it may differ
+         # slightly from the native diffusers implementation
+         # for some models.
+
+         _cp_plan = {
+             # Pattern of transformer_blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #   -> all2all
+             #   -> attn (local heads, full seqlen)
+             #   -> all2all
+             #   -> split output
+             # (only hidden_states is split, not encoder_hidden_states)
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Then, the final proj_out_2 gathers the split output:
+             #   split input (previous split output)
+             #   -> all gather
+             #   -> un-split output
+             "proj_out_2": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
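
To make the split/gather semantics of a plan like the one above concrete, here is a small shape walk-through (not cache-dit code): plain torch.chunk and torch.cat stand in for the scatter and all-gather that the context-parallel hooks perform, and the batch, sequence, and hidden sizes are made up.

import torch

# Hypothetical sizes for a DiT latent token sequence; world_size is the size
# of the context-parallel group.
world_size = 2
B, S, D = 2, 4096, 1152
hidden_states = torch.randn(B, S, D)

# ContextParallelInput(split_dim=1, expected_dims=3, split_output=False) on
# "transformer_blocks.0": each rank works on one shard of the sequence dim.
shards = torch.chunk(hidden_states, world_size, dim=1)
assert shards[0].shape == (B, S // world_size, D)

# ContextParallelOutput(gather_dim=1, expected_dims=3) on "proj_out_2":
# the shards are gathered back into the full-length sequence at the end.
gathered = torch.cat(shards, dim=1)
assert gathered.shape == (B, S, D)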
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py (new file)
@@ -0,0 +1,88 @@
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers import FluxTransformer2DModel
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("Flux")
+ class FluxContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, FluxTransformer2DModel
+             ), "Transformer must be an instance of FluxTransformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Otherwise, use the custom CP plan defined here; it may differ
+         # slightly from the native diffusers implementation
+         # for some models.
+         _cp_plan = {
+             # This is a transformer-level CP plan for Flux: it applies a single
+             # split hook (pre_forward) on the transformer's forward and gathers
+             # the output after the transformer forward.
+             # Pattern of the transformer forward, split_output=False:
+             #   un-split input -> split input (inside the transformer)
+             # Pattern of transformer_blocks / single_transformer_blocks:
+             #   split input (previous split output) -> to_qkv/...
+             #   -> all2all
+             #   -> attn (local heads, full seqlen)
+             #   -> all2all
+             #   -> split output
+             # `hidden_states` and `encoder_hidden_states` stay split after each
+             # block forward (i.e., they are re-split automatically by the all2all
+             # comm op after attention) for all blocks.
+             # `img_ids` and `txt_ids` are split only once, at the very beginning,
+             # and stay split through the whole transformer forward; the all2all
+             # comm op only touches the `out` tensor after the local attention,
+             # not `img_ids` and `txt_ids`.
+             "": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "img_ids": ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=False
+                 ),
+                 "txt_ids": ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=False
+                 ),
+             },
+             # Then, the final proj_out gathers the split output:
+             #   split input (previous split output)
+             #   -> all gather
+             #   -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
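
Both planners above follow the same extension pattern, so adding support for another architecture would look roughly like the sketch below (not part of the diff): subclass ContextParallelismPlanner, register it under a model name with ContextParallelismPlannerRegister.register, and return a ContextParallelModelPlan from apply(). The model name "MyTransformer" and the module keys are hypothetical; the import path assumes the package layout listed above, and diffusers>=0.36.dev0 is required for the _modeling_parallel types.

from typing import Optional

import torch
from diffusers.models._modeling_parallel import (  # requires diffusers>=0.36.dev0
    ContextParallelInput,
    ContextParallelModelPlan,
    ContextParallelOutput,
)
from diffusers.models.modeling_utils import ModelMixin

from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)


@ContextParallelismPlannerRegister.register("MyTransformer")  # hypothetical model name
class MyTransformerContextParallelismPlanner(ContextParallelismPlanner):
    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        # Split the sequence dim of the first block's input across ranks, and
        # gather the final projection output (module names are hypothetical).
        return {
            "transformer_blocks.0": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }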