cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
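The dominant change in this file list is the rename of the `cache_dit/cache_factory/` package to `cache_dit/caching/`, alongside new `parallelism/` and `quantize/backends/` trees and the `custom_ops` → `kernels` move. For downstream code, a minimal compatibility shim might look like the sketch below; the `ForwardPattern` re-export from the old `cache_factory` package is an assumption based on the removed `cache_factory/__init__.py`, while the new path is confirmed by the hunks that follow.

# Hypothetical import shim for code that must run against both layouts.
try:
    from cache_dit.caching import ForwardPattern  # 1.0.x layout (confirmed below)
except ImportError:
    from cache_dit.cache_factory import ForwardPattern  # 0.3.x layout (assumed re-export)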
cache_dit/caching/cache_blocks/__init__.py
@@ -0,0 +1,226 @@
+ import torch
+
+ from cache_dit.caching import ForwardPattern
+ from cache_dit.caching.cache_types import CacheType
+ from cache_dit.caching.cache_contexts.cache_context import CachedContext
+ from cache_dit.caching.cache_contexts.prune_context import PrunedContext
+ from cache_dit.caching.cache_contexts.cache_manager import (
+     CachedContextManager,
+ )
+ from cache_dit.caching.cache_contexts.prune_manager import (
+     PrunedContextManager,
+ )
+
+ from cache_dit.caching.cache_blocks.pattern_0_1_2 import (
+     CachedBlocks_Pattern_0_1_2,
+     PrunedBlocks_Pattern_0_1_2,
+ )
+ from cache_dit.caching.cache_blocks.pattern_3_4_5 import (
+     CachedBlocks_Pattern_3_4_5,
+     PrunedBlocks_Pattern_3_4_5,
+ )
+ from cache_dit.caching.cache_blocks.pattern_utils import apply_stats
+ from cache_dit.caching.cache_blocks.pattern_utils import remove_stats
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class UnifiedBlocks:
+     def __new__(
+         cls,
+         # 0. Transformer blocks configuration
+         transformer_blocks: torch.nn.ModuleList,
+         transformer: torch.nn.Module = None,
+         forward_pattern: ForwardPattern = None,
+         check_forward_pattern: bool = True,
+         check_num_outputs: bool = True,
+         # 1. Cache context configuration
+         # cache_prefix is usually the blocks attribute name, e.g.
+         # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+         # 'layers', 'single_stream_blocks' or 'double_stream_blocks'.
+         cache_prefix: str = None,  # may not always be needed
+         cache_context: CachedContext | PrunedContext | str = None,
+         context_manager: CachedContextManager | PrunedContextManager = None,
+         cache_type: CacheType = CacheType.DBCache,
+         **kwargs,
+     ):
+         if cache_type == CacheType.DBCache:
+             return CachedBlocks(
+                 # 0. Transformer blocks configuration
+                 transformer_blocks,
+                 transformer=transformer,
+                 forward_pattern=forward_pattern,
+                 check_forward_pattern=check_forward_pattern,
+                 check_num_outputs=check_num_outputs,
+                 # 1. Cache context configuration
+                 cache_prefix=cache_prefix,
+                 cache_context=cache_context,
+                 context_manager=context_manager,
+                 cache_type=cache_type,
+                 **kwargs,
+             )
+         elif cache_type == CacheType.DBPrune:
+             return PrunedBlocks(
+                 # 0. Transformer blocks configuration
+                 transformer_blocks,
+                 transformer=transformer,
+                 forward_pattern=forward_pattern,
+                 check_forward_pattern=check_forward_pattern,
+                 check_num_outputs=check_num_outputs,
+                 # 1. Cache context configuration
+                 cache_prefix=cache_prefix,
+                 cache_context=cache_context,
+                 context_manager=context_manager,
+                 cache_type=cache_type,
+                 **kwargs,
+             )
+         else:
+             raise ValueError(f"Cache type {cache_type} is not supported yet!")
+
+
+ class CachedBlocks:
+     def __new__(
+         cls,
+         # 0. Transformer blocks configuration
+         transformer_blocks: torch.nn.ModuleList,
+         transformer: torch.nn.Module = None,
+         forward_pattern: ForwardPattern = None,
+         check_forward_pattern: bool = True,
+         check_num_outputs: bool = True,
+         # 1. Cache context configuration
+         # cache_prefix is usually the blocks attribute name, e.g.
+         # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+         # 'layers', 'single_stream_blocks' or 'double_stream_blocks'.
+         cache_prefix: str = None,  # may not always be needed
+         cache_context: CachedContext | PrunedContext | str = None,
+         context_manager: CachedContextManager | PrunedContextManager = None,
+         cache_type: CacheType = CacheType.DBCache,
+         **kwargs,
+     ):
+         assert transformer is not None, "transformer can't be None."
+         assert forward_pattern is not None, "forward_pattern can't be None."
+         assert cache_context is not None, "cache_context can't be None."
+         assert context_manager is not None, "context_manager can't be None."
+         if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
+             if cache_type == CacheType.DBCache:
+                 assert isinstance(
+                     context_manager, CachedContextManager
+                 ), "context_manager must be CachedContextManager for DBCache."
+                 return CachedBlocks_Pattern_0_1_2(
+                     # 0. Transformer blocks configuration
+                     transformer_blocks,
+                     transformer=transformer,
+                     forward_pattern=forward_pattern,
+                     check_forward_pattern=check_forward_pattern,
+                     check_num_outputs=check_num_outputs,
+                     # 1. Cache context configuration
+                     cache_prefix=cache_prefix,
+                     cache_context=cache_context,
+                     context_manager=context_manager,
+                     cache_type=cache_type,
+                     **kwargs,
+                 )
+             else:
+                 raise ValueError(
+                     f"Cache type {cache_type} is not supported yet!"
+                 )
+         elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
+             if cache_type == CacheType.DBCache:
+                 assert isinstance(
+                     context_manager, CachedContextManager
+                 ), "context_manager must be CachedContextManager for DBCache."
+                 return CachedBlocks_Pattern_3_4_5(
+                     # 0. Transformer blocks configuration
+                     transformer_blocks,
+                     transformer=transformer,
+                     forward_pattern=forward_pattern,
+                     check_forward_pattern=check_forward_pattern,
+                     check_num_outputs=check_num_outputs,
+                     # 1. Cache context configuration
+                     cache_prefix=cache_prefix,
+                     cache_context=cache_context,
+                     context_manager=context_manager,
+                     cache_type=cache_type,
+                     **kwargs,
+                 )
+             else:
+                 raise ValueError(
+                     f"Cache type {cache_type} is not supported yet!"
+                 )
+         else:
+             raise ValueError(f"Pattern {forward_pattern} is not supported yet!")
+
+
+ class PrunedBlocks:
+     def __new__(
+         cls,
+         # 0. Transformer blocks configuration
+         transformer_blocks: torch.nn.ModuleList,
+         transformer: torch.nn.Module = None,
+         forward_pattern: ForwardPattern = None,
+         check_forward_pattern: bool = True,
+         check_num_outputs: bool = True,
+         # 1. Cache context configuration
+         # cache_prefix is usually the blocks attribute name, e.g.
+         # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+         # 'layers', 'single_stream_blocks' or 'double_stream_blocks'.
+         cache_prefix: str = None,  # may not always be needed
+         cache_context: CachedContext | PrunedContext | str = None,
+         context_manager: CachedContextManager | PrunedContextManager = None,
+         cache_type: CacheType = CacheType.DBCache,
+         **kwargs,
+     ):
+         assert transformer is not None, "transformer can't be None."
+         assert forward_pattern is not None, "forward_pattern can't be None."
+         assert cache_context is not None, "cache_context can't be None."
+         assert context_manager is not None, "context_manager can't be None."
+         if forward_pattern in PrunedBlocks_Pattern_0_1_2._supported_patterns:
+             if cache_type == CacheType.DBPrune:
+                 assert isinstance(
+                     context_manager, PrunedContextManager
+                 ), "context_manager must be PrunedContextManager for DBPrune."
+                 return PrunedBlocks_Pattern_0_1_2(
+                     # 0. Transformer blocks configuration
+                     transformer_blocks,
+                     transformer=transformer,
+                     forward_pattern=forward_pattern,
+                     check_forward_pattern=check_forward_pattern,
+                     check_num_outputs=check_num_outputs,
+                     # 1. Cache context configuration
+                     cache_prefix=cache_prefix,
+                     cache_context=cache_context,
+                     context_manager=context_manager,
+                     cache_type=cache_type,
+                     **kwargs,
+                 )
+             else:
+                 raise ValueError(
+                     f"Cache type {cache_type} is not supported yet!"
+                 )
+         elif forward_pattern in PrunedBlocks_Pattern_3_4_5._supported_patterns:
+             if cache_type == CacheType.DBPrune:
+                 assert isinstance(
+                     context_manager, PrunedContextManager
+                 ), "context_manager must be PrunedContextManager for DBPrune."
+                 return PrunedBlocks_Pattern_3_4_5(
+                     # 0. Transformer blocks configuration
+                     transformer_blocks,
+                     transformer=transformer,
+                     forward_pattern=forward_pattern,
+                     check_forward_pattern=check_forward_pattern,
+                     check_num_outputs=check_num_outputs,
+                     # 1. Cache context configuration
+                     cache_prefix=cache_prefix,
+                     cache_context=cache_context,
+                     context_manager=context_manager,
+                     cache_type=cache_type,
+                     **kwargs,
+                 )
+             else:
+                 raise ValueError(
+                     f"Cache type {cache_type} is not supported yet!"
+                 )
+         else:
+             raise ValueError(f"Pattern {forward_pattern} is not supported yet!")
cache_dit/caching/cache_blocks/offload_utils.py
@@ -0,0 +1,115 @@
+ import torch
+ import asyncio
+ import logging
+ from contextlib import contextmanager
+ from typing import Generator, Optional, List
+ from diffusers.hooks.group_offloading import _is_group_offload_enabled
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @torch.compiler.disable
+ @contextmanager
+ def maybe_onload(
+     block: torch.nn.Module,
+     reference_tensor: torch.Tensor,
+     pending_tasks: List[asyncio.Task] = [],  # mutable default; callers should pass their own list
+ ) -> Generator:
+     # No-op unless diffusers group offloading is enabled for this block.
+     if not _is_group_offload_enabled(block):
+         yield block
+         return
+
+     original_devices: Optional[List[torch.device]] = None
+     if hasattr(block, "parameters"):
+         params = list(block.parameters())
+         if params:
+             original_devices = [param.data.device for param in params]
+
+     target_device: torch.device = reference_tensor.device
+     move_task: Optional[asyncio.Task] = None
+     need_restore: bool = False
+
+     try:
+         if original_devices is not None:
+             unique_devices = list(set(original_devices))
+             if len(unique_devices) > 1 or unique_devices[0] != target_device:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         f"Onloading from {unique_devices} to {target_device}"
+                     )
+
+                 has_meta_params = any(
+                     dev.type == "meta" for dev in original_devices
+                 )
+                 if has_meta_params:  # compatible with sequential CPU offload
+                     block = block.to_empty(device=target_device)
+                 else:
+                     block = block.to(target_device, non_blocking=False)
+                 need_restore = True
+         yield block
+     finally:
+         if need_restore and original_devices:
+             # Restore parameters to their original devices in the background;
+             # maybe_offload() awaits these tasks later.
+             async def restore_device():
+                 for param, original_device in zip(
+                     block.parameters(), original_devices
+                 ):
+                     param.data = await asyncio.to_thread(
+                         lambda p, d: p.to(d, non_blocking=True),
+                         param.data,  # type: torch.Tensor
+                         original_device,  # type: torch.device
+                     )  # type: ignore[assignment]
+
+             loop = get_event_loop()
+             move_task = loop.create_task(restore_device())
+         if move_task:
+             pending_tasks.append(move_task)
+
+
+ def get_event_loop() -> asyncio.AbstractEventLoop:
+     # Reuse the running loop if any; otherwise create one and drive it from a
+     # dedicated daemon thread so background tasks can make progress.
+     try:
+         loop = asyncio.get_running_loop()
+     except RuntimeError:
+         try:
+             loop = asyncio.get_event_loop()
+         except RuntimeError:
+             loop = asyncio.new_event_loop()
+             asyncio.set_event_loop(loop)
+
+     if not loop.is_running():
+
+         def run_loop() -> None:
+             asyncio.set_event_loop(loop)
+             loop.run_forever()
+
+         import threading
+
+         if not any(t.name == "_my_loop" for t in threading.enumerate()):
+             threading.Thread(
+                 target=run_loop, name="_my_loop", daemon=True
+             ).start()
+
+     return loop
+
+
+ @torch.compiler.disable
+ def maybe_offload(
+     pending_tasks: List[asyncio.Task],
+ ) -> None:
+     if not pending_tasks:
+         return
+
+     loop = get_event_loop()
+
+     async def gather_tasks():
+         return await asyncio.gather(*pending_tasks)
+
+     # Block until all queued device-restore tasks finish (30 s timeout).
+     future = asyncio.run_coroutine_threadsafe(gather_tasks(), loop)
+     try:
+         future.result(timeout=30.0)
+     except Exception as e:
+         logger.error(f"maybe_offload error: {e}")
+
+     pending_tasks.clear()
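The call pattern implied by these signatures (the loop below is a sketch, not code from this file): wrap each block's forward pass in `maybe_onload` so a group-offloaded block is moved onto the activation's device on entry and its restore is queued on exit, then drain the queue once with `maybe_offload`.

import asyncio

pending: list[asyncio.Task] = []  # one shared restore queue per forward pass
for block in transformer_blocks:  # placeholder for the model's ModuleList
    with maybe_onload(block, hidden_states, pending) as onloaded_block:
        hidden_states = onloaded_block(hidden_states)
maybe_offload(pending)  # waits (up to 30 s) for all restores, then clears the queue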
cache_dit/caching/cache_blocks/pattern_0_1_2.py
@@ -0,0 +1,26 @@
+ from cache_dit.caching import ForwardPattern
+ from cache_dit.caching.cache_blocks.pattern_base import (
+     CachedBlocks_Pattern_Base,
+     PrunedBlocks_Pattern_Base,
+ )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
+     _supported_patterns = [
+         ForwardPattern.Pattern_0,
+         ForwardPattern.Pattern_1,
+         ForwardPattern.Pattern_2,
+     ]
+     ...
+
+
+ class PrunedBlocks_Pattern_0_1_2(PrunedBlocks_Pattern_Base):
+     _supported_patterns = [
+         ForwardPattern.Pattern_0,
+         ForwardPattern.Pattern_1,
+         ForwardPattern.Pattern_2,
+     ]
+     ...
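These subclasses add no behavior of their own (note the `...` bodies); they exist so the factories in `cache_blocks/__init__.py` can dispatch on `_supported_patterns` membership. A hypothetical sketch of how a further pattern family would plug in, reusing the imports of the module above; `Pattern_6` does not exist in this release.

class CachedBlocks_Pattern_6(CachedBlocks_Pattern_Base):
    # Hypothetical: listing a pattern here is all the membership check in
    # CachedBlocks.__new__ needs in order to route to this class.
    _supported_patterns = [ForwardPattern.Pattern_6]
    ...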