cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py

@@ -6,22 +6,32 @@ from collections.abc import Iterable
 
 from typing import Any, Tuple, List, Optional, Union
 
-from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.patch_functors import PatchFunctor
-from cache_dit.cache_factory.forward_pattern import ForwardPattern
-from cache_dit.cache_factory.params_modifier import ParamsModifier
+from diffusers import DiffusionPipeline, ModelMixin
+from cache_dit.caching.patch_functors import PatchFunctor
+from cache_dit.caching.forward_pattern import ForwardPattern
+from cache_dit.caching.params_modifier import ParamsModifier
 
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
+class FakeDiffusionPipeline:
+    # A placeholder for pipelines when pipe is None.
+    def __init__(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+    ):
+        self.transformer = transformer  # Reference only
+
+
 @dataclasses.dataclass
 class BlockAdapter:
 
     # Transformer configurations.
     pipe: Union[
         DiffusionPipeline,
+        FakeDiffusionPipeline,
         Any,
     ] = None
 
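The new FakeDiffusionPipeline placeholder keeps code paths that expect pipe.transformer working when only a bare transformer module is available. A minimal sketch of what the wrapper does, using a toy torch module rather than a real DiT transformer (the toy module is illustrative only):

import torch

from cache_dit.caching.block_adapters.block_adapters import FakeDiffusionPipeline

# Any torch.nn.Module (or diffusers ModelMixin) can stand in for a transformer here;
# a Linear layer is used purely for illustration.
toy_transformer = torch.nn.Linear(8, 8)

fake_pipe = FakeDiffusionPipeline(transformer=toy_transformer)
assert fake_pipe.transformer is toy_transformer  # the wrapper only stores a reference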
@@ -73,7 +83,7 @@ class BlockAdapter:
         ]
     ] = None
 
-    check_forward_pattern: bool = True
+    check_forward_pattern: Optional[bool] = None
     check_num_outputs: bool = False
 
     # Pipeline Level Flags
@@ -110,9 +120,53 @@
     def __post_init__(self):
         if self.skip_post_init:
             return
+
+        self.maybe_fake_pipe()
         if any((self.pipe is not None, self.transformer is not None)):
             self.maybe_fill_attrs()
             self.maybe_patchify()
+            self.maybe_skip_checks()
+
+    def maybe_fake_pipe(self):
+        if self.pipe is None:
+            self.pipe = FakeDiffusionPipeline()
+            logger.warning("pipe is None, use FakeDiffusionPipeline instead.")
+
+    def maybe_skip_checks(self):
+        if self.check_forward_pattern is None:
+            if self.transformer is not None:
+                if self.nested_depth(self.transformer) == 0:
+                    transformer = self.transformer
+                elif self.nested_depth(self.transformer) == 1:
+                    transformer = self.transformer[0]
+                else:
+                    raise ValueError(
+                        "transformer nested depth can't more than 1, "
+                        f"current is: {self.nested_depth(self.transformer)}"
+                    )
+                if transformer.__module__.startswith("diffusers"):
+                    self.check_forward_pattern = True
+                    logger.info(
+                        f"Found transformer from diffusers: {transformer.__module__} "
+                        "enable check_forward_pattern by default."
+                    )
+                else:
+                    self.check_forward_pattern = False
+                    logger.info(
+                        f"Found transformer NOT from diffusers: {transformer.__module__} "
+                        "disable check_forward_pattern by default."
+                    )
+
+        if getattr(self.transformer, "_hf_hook", None) is not None:
+            logger.warning("_hf_hook is not None, force skip pattern check!")
+            self.check_forward_pattern = False
+            self.check_num_outputs = False
+        elif getattr(self.transformer, "_diffusers_hook", None) is not None:
+            logger.warning(
+                "_diffusers_hook is not None, force skip pattern check!"
+            )
+            self.check_forward_pattern = False
+            self.check_num_outputs = False
 
     def maybe_fill_attrs(self):
         # NOTE: This func should be call before normalize.
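With check_forward_pattern now defaulting to None, maybe_skip_checks() picks the behavior per adapter: pattern checking stays on for transformers defined inside diffusers, is turned off for custom modules, and is force-disabled when accelerate (_hf_hook) or diffusers (_diffusers_hook) hooks are attached. A standalone restatement of the module-origin heuristic, not part of the package itself:

import torch

# Same test maybe_skip_checks() applies to pick the default:
# True -> forward-pattern checking enabled, False -> skipped.
def default_pattern_check(transformer: torch.nn.Module) -> bool:
    return transformer.__module__.startswith("diffusers")

print(default_pattern_check(torch.nn.Linear(4, 4)))  # False: custom/torch modules skip the check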
@@ -195,7 +249,10 @@ class BlockAdapter:
         if self.transformer is not None:
             self.patch_functor.apply(self.transformer, *args, **kwargs)
         else:
-            assert hasattr(self.pipe, "transformer")
+            assert hasattr(self.pipe, "transformer"), (
+                "pipe.transformer can not be None when patch_functor "
+                "is provided and transformer is None."
+            )
             self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
 
     @staticmethod
@@ -211,6 +268,10 @@
             adapter.forward_pattern is not None
         ), "adapter.forward_pattern can not be None."
         pipe = adapter.pipe
+        if isinstance(pipe, FakeDiffusionPipeline):
+            raise ValueError(
+                "Can not auto block adapter for FakeDiffusionPipeline."
+            )
 
         assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
 
@@ -476,6 +537,7 @@
     @staticmethod
     def normalize(
         adapter: "BlockAdapter",
+        unique: bool = True,
     ) -> "BlockAdapter":
 
         if getattr(adapter, "_is_normalized", False):
@@ -510,7 +572,10 @@
         adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
         adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
         adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
-        BlockAdapter.unique(adapter)
+        # Some times, the cache_config will be None.
+        # So we do not perform unique check here.
+        if unique:
+            BlockAdapter.unique(adapter)
 
         adapter._is_normalized = True
 
@@ -558,6 +623,10 @@
         if not getattr(adapter, "_is_normalized", False):
             raise RuntimeError("block_adapter must be normailzed.")
 
+    @classmethod
+    def is_normalized(cls, adapter: "BlockAdapter") -> bool:
+        return getattr(adapter, "_is_normalized", False)
+
     @classmethod
     def is_cached(cls, adapter: Any) -> bool:
         if isinstance(adapter, cls):
@@ -579,6 +648,21 @@
         else:
             return getattr(adapter, "_is_cached", False)
 
+    @classmethod
+    def is_parallelized(cls, adapter: Any) -> bool:
+        if isinstance(adapter, cls):
+            cls.assert_normalized(adapter)
+            return getattr(adapter.transformer[0], "_is_parallelized", False)
+        elif isinstance(adapter, DiffusionPipeline):
+            return getattr(adapter.transformer, "_is_parallelized", False)
+        elif isinstance(adapter, torch.nn.Module):
+            return getattr(adapter, "_is_parallelized", False)
+        elif isinstance(adapter, list):  # [TRN_0,...]
+            assert isinstance(adapter[0], torch.nn.Module)
+            return getattr(adapter[0], "_is_parallelized", False)
+        else:
+            return getattr(adapter, "_is_parallelized", False)
+
     @classmethod
     def nested_depth(cls, obj: Any):
         # str: 0; List[str]: 1; List[List[str]]: 2
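A normalized BlockAdapter stores its attributes as (possibly nested) lists, which is why is_parallelized reads adapter.transformer[0] and why nested_depth documents the contract str → 0, List[str] → 1, List[List[str]] → 2. An illustrative helper matching that documented contract (the package's own nested_depth implementation is context here, not shown in this hunk):

from typing import Any

def nested_depth(obj: Any) -> int:
    # str/module -> 0, List[...] -> 1, List[List[...]] -> 2, following the
    # contract documented above; empty lists count as depth 1 here.
    if isinstance(obj, (list, tuple)):
        return 1 + (nested_depth(obj[0]) if len(obj) else 0)
    return 0

assert nested_depth("blocks") == 0
assert nested_depth(["blocks"]) == 1
assert nested_depth([["blocks"], ["single_blocks"]]) == 2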
cache_dit/caching/block_adapters/block_registers.py

@@ -0,0 +1,118 @@
+import torch
+from typing import Any, Tuple, List, Dict, Callable, Union
+
+from diffusers import DiffusionPipeline
+from cache_dit.caching.block_adapters.block_adapters import (
+    BlockAdapter,
+    FakeDiffusionPipeline,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class BlockAdapterRegistry:
+    _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
+    _predefined_adapters_has_separate_cfg: List[str] = [
+        "QwenImage",
+        "Wan",
+        "CogView4",
+        "Cosmos",
+        "SkyReelsV2",
+        "Chroma",
+        "Lumina2",
+        "Kandinsky5",
+    ]
+
+    @classmethod
+    def register(cls, name: str, supported: bool = True):
+        def decorator(
+            func: Callable[..., BlockAdapter]
+        ) -> Callable[..., BlockAdapter]:
+            if supported:
+                cls._adapters[name] = func
+            return func
+
+        return decorator
+
+    @classmethod
+    def get_adapter(
+        cls,
+        pipe_or_module: DiffusionPipeline | torch.nn.Module | str | Any,
+        **kwargs,
+    ) -> BlockAdapter | None:
+        if not isinstance(pipe_or_module, str):
+            cls_name: str = pipe_or_module.__class__.__name__
+        else:
+            cls_name = pipe_or_module
+
+        for name in cls._adapters:
+            if cls_name.startswith(name):
+                if not isinstance(pipe_or_module, DiffusionPipeline):
+                    assert isinstance(pipe_or_module, torch.nn.Module)
+                    # NOTE: Make pre-registered adapters support Transformer-only case.
+                    # WARN: This branch is not officially supported and only for testing
+                    # purpose. We construct a fake diffusion pipeline that contains the
+                    # given transformer module. Currently, only works for DiT models which
+                    # only have one transformer module. Case like multiple transformers
+                    # is not supported, e.g, Wan2.2. Please use BlockAdapter directly for
+                    # such cases.
+                    return cls._adapters[name](
+                        FakeDiffusionPipeline(pipe_or_module), **kwargs
+                    )
+                else:
+                    return cls._adapters[name](pipe_or_module, **kwargs)
+
+        return None
+
+    @classmethod
+    def has_separate_cfg(
+        cls,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            FakeDiffusionPipeline,
+            BlockAdapter,
+            Any,
+        ],
+    ) -> bool:
+
+        # Prefer custom setting from block adapter.
+        if isinstance(pipe_or_adapter, BlockAdapter):
+            return pipe_or_adapter.has_separate_cfg
+
+        has_separate_cfg = False
+        if isinstance(pipe_or_adapter, FakeDiffusionPipeline):
+            return False
+
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            adapter = cls.get_adapter(
+                pipe_or_adapter,
+                skip_post_init=True,  # check cfg setting only
+            )
+            if adapter is not None:
+                has_separate_cfg = adapter.has_separate_cfg
+
+        if has_separate_cfg:
+            return True
+
+        pipe_cls_name = pipe_or_adapter.__class__.__name__
+        for name in cls._predefined_adapters_has_separate_cfg:
+            if pipe_cls_name.startswith(name):
+                return True
+
+        return False
+
+    @classmethod
+    def is_supported(cls, pipe_or_module) -> bool:
+        cls_name: str = pipe_or_module.__class__.__name__
+
+        for name in cls._adapters:
+            if cls_name.startswith(name):
+                return True
+        return False
+
+    @classmethod
+    def supported_pipelines(cls, **kwargs) -> Tuple[int, List[str]]:
+        val_pipelines = cls._adapters.keys()
+        return len(val_pipelines), [p for p in val_pipelines]
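BlockAdapterRegistry resolves adapters by matching a class-name prefix, and get_adapter now wraps a bare transformer module in FakeDiffusionPipeline before invoking the registered factory. A hedged sketch of registering and resolving a custom adapter; the MyVideo* names and the trivial factory body are invented for illustration, and a real factory would also configure the transformer blocks and forward pattern:

import torch

from cache_dit.caching.block_adapters.block_adapters import (
    BlockAdapter,
    FakeDiffusionPipeline,
)
from cache_dit.caching.block_adapters.block_registers import BlockAdapterRegistry


class MyVideoTransformer3DModel(torch.nn.Module):
    # Hypothetical transformer whose class name starts with the registered prefix.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Identity()])


@BlockAdapterRegistry.register("MyVideo")
def my_video_adapter(pipe, **kwargs) -> BlockAdapter:
    # Placeholder factory kept minimal for the sketch.
    return BlockAdapter(pipe=pipe, skip_post_init=True, **kwargs)


# Transformer-only lookup: get_adapter wraps the module in FakeDiffusionPipeline
# before calling the factory registered under the matching prefix.
adapter = BlockAdapterRegistry.get_adapter(MyVideoTransformer3DModel())
assert isinstance(adapter.pipe, FakeDiffusionPipeline)

count, names = BlockAdapterRegistry.supported_pipelines()
print(f"{count} adapter prefixes currently registered, e.g. {names[:3]}")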
cache_dit/caching/cache_adapters/__init__.py

@@ -0,0 +1 @@
+from cache_dit.caching.cache_adapters.cache_adapter import CachedAdapter