cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,226 @@
1
+ import torch
2
+
3
+ from cache_dit.caching import ForwardPattern
4
+ from cache_dit.caching.cache_types import CacheType
5
+ from cache_dit.caching.cache_contexts.cache_context import CachedContext
6
+ from cache_dit.caching.cache_contexts.prune_context import PrunedContext
7
+ from cache_dit.caching.cache_contexts.cache_manager import (
8
+ CachedContextManager,
9
+ )
10
+ from cache_dit.caching.cache_contexts.prune_manager import (
11
+ PrunedContextManager,
12
+ )
13
+
14
+ from cache_dit.caching.cache_blocks.pattern_0_1_2 import (
15
+ CachedBlocks_Pattern_0_1_2,
16
+ PrunedBlocks_Pattern_0_1_2,
17
+ )
18
+ from cache_dit.caching.cache_blocks.pattern_3_4_5 import (
19
+ CachedBlocks_Pattern_3_4_5,
20
+ PrunedBlocks_Pattern_3_4_5,
21
+ )
22
+ from cache_dit.caching.cache_blocks.pattern_utils import apply_stats
23
+ from cache_dit.caching.cache_blocks.pattern_utils import remove_stats
24
+
25
+ from cache_dit.logger import init_logger
26
+
27
+ logger = init_logger(__name__)
28
+
29
+
30
class UnifiedBlocks:
    """Entry point that picks the block factory matching ``cache_type``.

    ``CacheType.DBCache`` is routed to ``CachedBlocks`` and
    ``CacheType.DBPrune`` to ``PrunedBlocks``. ``__new__`` forwards every
    argument unchanged and returns whatever the chosen factory returns.

    Raises:
        ValueError: if ``cache_type`` is neither ``CacheType.DBCache`` nor
            ``CacheType.DBPrune``.
    """

    def __new__(
        cls,
        # --- transformer blocks ---
        transformer_blocks: torch.nn.ModuleList,
        transformer: torch.nn.Module = None,
        forward_pattern: ForwardPattern = None,
        check_forward_pattern: bool = True,
        check_num_outputs: bool = True,
        # --- cache context ---
        # cache_prefix is typically a blocks attribute name such as
        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
        # 'layers', 'single_stream_blocks' or 'double_stream_blocks';
        # it may not be needed at all.
        cache_prefix: str = None,
        cache_context: CachedContext | PrunedContext | str = None,
        context_manager: CachedContextManager | PrunedContextManager = None,
        cache_type: CacheType = CacheType.DBCache,
        **kwargs,
    ):
        # Resolve the target factory before building anything.
        if cache_type == CacheType.DBCache:
            factory = CachedBlocks
        elif cache_type == CacheType.DBPrune:
            factory = PrunedBlocks
        else:
            raise ValueError(f"Cache type {cache_type} is not supported now!")

        # Named parameters can never appear in **kwargs, so merging the
        # extras into this dict cannot overwrite any of the explicit entries.
        forwarded = dict(
            transformer=transformer,
            forward_pattern=forward_pattern,
            check_forward_pattern=check_forward_pattern,
            check_num_outputs=check_num_outputs,
            cache_prefix=cache_prefix,
            cache_context=cache_context,
            context_manager=context_manager,
            cache_type=cache_type,
        )
        forwarded.update(kwargs)
        return factory(transformer_blocks, **forwarded)
81
+
82
+
83
class CachedBlocks:
    """Factory for DBCache block wrappers, selected by ``forward_pattern``.

    ``CachedBlocks(...)`` returns the object built by
    ``CachedBlocks_Pattern_0_1_2`` or ``CachedBlocks_Pattern_3_4_5``,
    whichever lists ``forward_pattern`` in its ``_supported_patterns``.
    The 0/1/2 class is checked first. Every argument is forwarded unchanged.

    Raises:
        AssertionError: if ``transformer``, ``forward_pattern``,
            ``cache_context`` or ``context_manager`` is ``None``, or if
            ``context_manager`` is not a ``CachedContextManager``. These
            are ``assert`` statements, so they are skipped under ``-O``.
        ValueError: if ``forward_pattern`` is not supported, or if
            ``cache_type`` is anything other than ``CacheType.DBCache``.
    """

    def __new__(
        cls,
        # --- transformer blocks ---
        transformer_blocks: torch.nn.ModuleList,
        transformer: torch.nn.Module = None,
        forward_pattern: ForwardPattern = None,
        check_forward_pattern: bool = True,
        check_num_outputs: bool = True,
        # --- cache context ---
        # cache_prefix is typically a blocks attribute name such as
        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
        # 'layers', 'single_stream_blocks' or 'double_stream_blocks';
        # it may not be needed at all.
        cache_prefix: str = None,
        cache_context: CachedContext | PrunedContext | str = None,
        context_manager: CachedContextManager | PrunedContextManager = None,
        cache_type: CacheType = CacheType.DBCache,
        **kwargs,
    ):
        assert transformer is not None, "transformer can't be None."
        assert forward_pattern is not None, "forward_pattern can't be None."
        assert cache_context is not None, "cache_context can't be None."
        assert context_manager is not None, "context_manager can't be None."

        # Choose the implementation first, so an unknown pattern is reported
        # before an unsupported cache type.
        if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
            impl = CachedBlocks_Pattern_0_1_2
        elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
            impl = CachedBlocks_Pattern_3_4_5
        else:
            raise ValueError(f"Pattern {forward_pattern} is not supported now!")

        if cache_type != CacheType.DBCache:
            raise ValueError(f"Cache type {cache_type} is not supported now!")

        assert isinstance(
            context_manager, CachedContextManager
        ), "context_manager must be CachedContextManager for DBCache."

        return impl(
            transformer_blocks,
            transformer=transformer,
            forward_pattern=forward_pattern,
            check_forward_pattern=check_forward_pattern,
            check_num_outputs=check_num_outputs,
            cache_prefix=cache_prefix,
            cache_context=cache_context,
            context_manager=context_manager,
            cache_type=cache_type,
            **kwargs,
        )
154
+
155
+
156
class PrunedBlocks:
    """Factory for DBPrune block wrappers, selected by ``forward_pattern``.

    ``PrunedBlocks(...)`` returns the object built by
    ``PrunedBlocks_Pattern_0_1_2`` or ``PrunedBlocks_Pattern_3_4_5``,
    whichever lists ``forward_pattern`` in its ``_supported_patterns``.
    The 0/1/2 class is checked first. Every argument is forwarded unchanged.

    ``cache_type`` defaults to ``CacheType.DBPrune`` because that is the only
    cache type this factory accepts. A ``CacheType.DBCache`` default would
    make every call that omits ``cache_type`` fail.

    Raises:
        AssertionError: if ``transformer``, ``forward_pattern``,
            ``cache_context`` or ``context_manager`` is ``None``, or if
            ``context_manager`` is not a ``PrunedContextManager``. These
            are ``assert`` statements, so they are skipped under ``-O``.
        ValueError: if ``forward_pattern`` is not supported, or if
            ``cache_type`` is anything other than ``CacheType.DBPrune``.
    """

    def __new__(
        cls,
        # --- transformer blocks ---
        transformer_blocks: torch.nn.ModuleList,
        transformer: torch.nn.Module = None,
        forward_pattern: ForwardPattern = None,
        check_forward_pattern: bool = True,
        check_num_outputs: bool = True,
        # --- cache context ---
        # cache_prefix is typically a blocks attribute name such as
        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
        # 'layers', 'single_stream_blocks' or 'double_stream_blocks';
        # it may not be needed at all.
        cache_prefix: str = None,
        cache_context: CachedContext | PrunedContext | str = None,
        context_manager: CachedContextManager | PrunedContextManager = None,
        # DBPrune is the only value accepted below, so it is the default.
        cache_type: CacheType = CacheType.DBPrune,
        **kwargs,
    ):
        assert transformer is not None, "transformer can't be None."
        assert forward_pattern is not None, "forward_pattern can't be None."
        assert cache_context is not None, "cache_context can't be None."
        assert context_manager is not None, "context_manager can't be None."

        # Choose the implementation first, so an unknown pattern is reported
        # before an unsupported cache type.
        if forward_pattern in PrunedBlocks_Pattern_0_1_2._supported_patterns:
            impl = PrunedBlocks_Pattern_0_1_2
        elif forward_pattern in PrunedBlocks_Pattern_3_4_5._supported_patterns:
            impl = PrunedBlocks_Pattern_3_4_5
        else:
            raise ValueError(f"Pattern {forward_pattern} is not supported now!")

        if cache_type != CacheType.DBPrune:
            raise ValueError(f"Cache type {cache_type} is not supported now!")

        assert isinstance(
            context_manager, PrunedContextManager
        ), "context_manager must be PrunedContextManager for DBPrune."

        return impl(
            transformer_blocks,
            transformer=transformer,
            forward_pattern=forward_pattern,
            check_forward_pattern=check_forward_pattern,
            check_num_outputs=check_num_outputs,
            cache_prefix=cache_prefix,
            cache_context=cache_context,
            context_manager=context_manager,
            cache_type=cache_type,
            **kwargs,
        )
@@ -0,0 +1,26 @@
1
+ from cache_dit.caching import ForwardPattern
2
+ from cache_dit.caching.cache_blocks.pattern_base import (
3
+ CachedBlocks_Pattern_Base,
4
+ PrunedBlocks_Pattern_Base,
5
+ )
6
+ from cache_dit.logger import init_logger
7
+
8
+ logger = init_logger(__name__)
9
+
10
+
11
class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
    """Cached-blocks implementation for forward patterns 0, 1 and 2.

    All behaviour is inherited from ``CachedBlocks_Pattern_Base``; this
    subclass only declares the ``ForwardPattern`` values it handles.
    """

    # ForwardPattern members handled by this class.
    _supported_patterns = [
        ForwardPattern.Pattern_0,
        ForwardPattern.Pattern_1,
        ForwardPattern.Pattern_2,
    ]
18
+
19
+
20
class PrunedBlocks_Pattern_0_1_2(PrunedBlocks_Pattern_Base):
    """Pruned-blocks implementation for forward patterns 0, 1 and 2.

    All behaviour is inherited from ``PrunedBlocks_Pattern_Base``; this
    subclass only declares the ``ForwardPattern`` values it handles.
    """

    # ForwardPattern members handled by this class.
    _supported_patterns = [
        ForwardPattern.Pattern_0,
        ForwardPattern.Pattern_1,
        ForwardPattern.Pattern_2,
    ]