cache-dit 1.0.3-py3-none-any.whl → 1.0.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
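
Note on the layout change: the file list above shows everything under cache_dit/cache_factory/ moving to cache_dit/caching/ and cache_dit/custom_ops/ moving to cache_dit/kernels/, so 1.0.3-era import paths will break on 1.0.14. Below is a minimal, hedged sketch of a compatibility import; it assumes ForwardPattern is re-exported from the renamed cache_dit.caching package, which the rename alone does not confirm.

# Hedged sketch, not part of the diff: tolerate both the 1.0.3 and 1.0.14 layouts.
try:
    # 1.0.14 layout; assumes the symbol is re-exported from cache_dit.caching.
    from cache_dit.caching import ForwardPattern
except ImportError:
    # 1.0.3 layout, as used by the deleted modules shown below.
    from cache_dit.cache_factory import ForwardPattern

print(ForwardPattern.Pattern_0)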
cache_dit/cache_factory/cache_blocks/pattern_base.py
@@ -1,458 +0,0 @@
- import inspect
- import torch
- import torch.distributed as dist
-
- from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
- from cache_dit.cache_factory.cache_contexts.cache_manager import (
-     CachedContextManager,
-     CacheNotExistError,
- )
- from cache_dit.cache_factory import ForwardPattern
- from cache_dit.logger import init_logger
-
- logger = init_logger(__name__)
-
-
- class CachedBlocks_Pattern_Base(torch.nn.Module):
-     _supported_patterns = [
-         ForwardPattern.Pattern_0,
-         ForwardPattern.Pattern_1,
-         ForwardPattern.Pattern_2,
-     ]
-
-     def __init__(
-         self,
-         # 0. Transformer blocks configuration
-         transformer_blocks: torch.nn.ModuleList,
-         transformer: torch.nn.Module = None,
-         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-         check_forward_pattern: bool = True,
-         check_num_outputs: bool = True,
-         # 1. Cache context configuration
-         cache_prefix: str = None,  # maybe un-need.
-         cache_context: CachedContext | str = None,
-         cache_manager: CachedContextManager = None,
-         **kwargs,
-     ):
-         super().__init__()
-
-         # 0. Transformer blocks configuration
-         self.transformer = transformer
-         self.transformer_blocks = transformer_blocks
-         self.forward_pattern = forward_pattern
-         self.check_forward_pattern = check_forward_pattern
-         self.check_num_outputs = check_num_outputs
-         # 1. Cache context configuration
-         self.cache_prefix = cache_prefix
-         self.cache_context = cache_context
-         self.cache_manager = cache_manager
-
-         self._check_forward_pattern()
-         logger.info(
-             f"Match Cached Blocks: {self.__class__.__name__}, for "
-             f"{self.cache_prefix}, cache_context: {self.cache_context}, "
-             f"cache_manager: {self.cache_manager.name}."
-         )
-
-     def _check_forward_pattern(self):
-         if not self.check_forward_pattern:
-             logger.warning(
-                 f"Skipped Forward Pattern Check: {self.forward_pattern}"
-             )
-             return
-
-         assert (
-             self.forward_pattern.Supported
-             and self.forward_pattern in self._supported_patterns
-         ), f"Pattern {self.forward_pattern} is not supported now!"
-
-         if self.transformer_blocks is not None:
-             for block in self.transformer_blocks:
-                 # Special case for HiDreamBlock
-                 if hasattr(block, "block"):
-                     if isinstance(block.block, torch.nn.Module):
-                         block = block.block
-
-                 forward_parameters = set(
-                     inspect.signature(block.forward).parameters.keys()
-                 )
-
-                 if self.check_num_outputs:
-                     num_outputs = str(
-                         inspect.signature(block.forward).return_annotation
-                     ).count("torch.Tensor")
-
-                     if num_outputs > 0:
-                         assert len(self.forward_pattern.Out) == num_outputs, (
-                             f"The number of block's outputs is {num_outputs} don't not "
-                             f"match the number of the pattern: {self.forward_pattern}, "
-                             f"Out: {len(self.forward_pattern.Out)}."
-                         )
-
-                 for required_param in self.forward_pattern.In:
-                     assert (
-                         required_param in forward_parameters
-                     ), f"The input parameters must contains: {required_param}."
-
-     @torch.compiler.disable
-     def _check_cache_params(self):
-         assert self.cache_manager.Fn_compute_blocks() <= len(
-             self.transformer_blocks
-         ), (
-             f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
-             f"the number of transformer blocks {len(self.transformer_blocks)}"
-         )
-         assert self.cache_manager.Bn_compute_blocks() <= len(
-             self.transformer_blocks
-         ), (
-             f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
-             f"the number of transformer blocks {len(self.transformer_blocks)}"
-         )
-
-     def call_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         # Call all blocks to process the hidden states without cache.
-         for block in self.transformer_blocks:
-             hidden_states = block(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             if not isinstance(hidden_states, torch.Tensor):
-                 hidden_states, encoder_hidden_states = hidden_states
-                 if not self.forward_pattern.Return_H_First:
-                     hidden_states, encoder_hidden_states = (
-                         encoder_hidden_states,
-                         hidden_states,
-                     )
-
-         return hidden_states, encoder_hidden_states
-
-     @torch.compiler.disable
-     def _process_block_outputs(
-         self,
-         hidden_states: torch.Tensor | tuple,
-         encoder_hidden_states: torch.Tensor | None,
-     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-         if not isinstance(hidden_states, torch.Tensor):
-             hidden_states, encoder_hidden_states = hidden_states
-             if not self.forward_pattern.Return_H_First:
-                 hidden_states, encoder_hidden_states = (
-                     encoder_hidden_states,
-                     hidden_states,
-                 )
-         return hidden_states, encoder_hidden_states
-
-     @torch.compiler.disable
-     def _process_forward_outputs(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor | None,
-     ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor:
-         return (
-             hidden_states
-             if self.forward_pattern.Return_H_Only
-             else (
-                 (hidden_states, encoder_hidden_states)
-                 if self.forward_pattern.Return_H_First
-                 else (encoder_hidden_states, hidden_states)
-             )
-         )
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         # Use it's own cache context.
-         try:
-             self.cache_manager.set_context(self.cache_context)
-             self._check_cache_params()
-         except CacheNotExistError as e:
-             logger.warning(f"Cache context not exist: {e}, skip cache.")
-             # Call all blocks to process the hidden states.
-             hidden_states, encoder_hidden_states = self.call_blocks(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             return self._process_forward_outputs(
-                 hidden_states,
-                 encoder_hidden_states,
-             )
-
-         original_hidden_states = hidden_states
-         # Call first `n` blocks to process the hidden states for
-         # more stable diff calculation.
-         hidden_states, encoder_hidden_states = self.call_Fn_blocks(
-             hidden_states,
-             encoder_hidden_states,
-             *args,
-             **kwargs,
-         )
-
-         Fn_hidden_states_residual = hidden_states - original_hidden_states
-         del original_hidden_states
-
-         self.cache_manager.mark_step_begin()
-         # Residual L1 diff or Hidden States L1 diff
-         can_use_cache = self.cache_manager.can_cache(
-             (
-                 Fn_hidden_states_residual
-                 if not self.cache_manager.is_l1_diff_enabled()
-                 else hidden_states
-             ),
-             parallelized=self._is_parallelized(),
-             prefix=(
-                 f"{self.cache_prefix}_Fn_residual"
-                 if not self.cache_manager.is_l1_diff_enabled()
-                 else f"{self.cache_prefix}_Fn_hidden_states"
-             ),
-         )
-
-         torch._dynamo.graph_break()
-         if can_use_cache:
-             self.cache_manager.add_cached_step()
-             del Fn_hidden_states_residual
-             hidden_states, encoder_hidden_states = (
-                 self.cache_manager.apply_cache(
-                     hidden_states,
-                     encoder_hidden_states,
-                     prefix=(
-                         f"{self.cache_prefix}_Bn_residual"
-                         if self.cache_manager.is_cache_residual()
-                         else f"{self.cache_prefix}_Bn_hidden_states"
-                     ),
-                     encoder_prefix=(
-                         f"{self.cache_prefix}_Bn_residual"
-                         if self.cache_manager.is_encoder_cache_residual()
-                         else f"{self.cache_prefix}_Bn_hidden_states"
-                     ),
-                 )
-             )
-             torch._dynamo.graph_break()
-             # Call last `n` blocks to further process the hidden states
-             # for higher precision.
-             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-         else:
-             self.cache_manager.set_Fn_buffer(
-                 Fn_hidden_states_residual,
-                 prefix=f"{self.cache_prefix}_Fn_residual",
-             )
-             if self.cache_manager.is_l1_diff_enabled():
-                 # for hidden states L1 diff
-                 self.cache_manager.set_Fn_buffer(
-                     hidden_states,
-                     f"{self.cache_prefix}_Fn_hidden_states",
-                 )
-             del Fn_hidden_states_residual
-             torch._dynamo.graph_break()
-             (
-                 hidden_states,
-                 encoder_hidden_states,
-                 hidden_states_residual,
-                 encoder_hidden_states_residual,
-             ) = self.call_Mn_blocks(  # middle
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             torch._dynamo.graph_break()
-             if self.cache_manager.is_cache_residual():
-                 self.cache_manager.set_Bn_buffer(
-                     hidden_states_residual,
-                     prefix=f"{self.cache_prefix}_Bn_residual",
-                 )
-             else:
-                 self.cache_manager.set_Bn_buffer(
-                     hidden_states,
-                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                 )
-
-             if self.cache_manager.is_encoder_cache_residual():
-                 self.cache_manager.set_Bn_encoder_buffer(
-                     encoder_hidden_states_residual,
-                     prefix=f"{self.cache_prefix}_Bn_residual",
-                 )
-             else:
-                 self.cache_manager.set_Bn_encoder_buffer(
-                     encoder_hidden_states,
-                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                 )
-             torch._dynamo.graph_break()
-             # Call last `n` blocks to further process the hidden states
-             # for higher precision.
-             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-
-         # patch cached stats for blocks or remove it.
-         torch._dynamo.graph_break()
-
-         return self._process_forward_outputs(
-             hidden_states,
-             encoder_hidden_states,
-         )
-
-     @torch.compiler.disable
-     def _is_parallelized(self):
-         # Compatible with distributed inference.
-         return any(
-             (
-                 all(
-                     (
-                         self.transformer is not None,
-                         getattr(self.transformer, "_is_parallelized", False),
-                     )
-                 ),
-                 (dist.is_initialized() and dist.get_world_size() > 1),
-             )
-         )
-
-     @torch.compiler.disable
-     def _is_in_cache_step(self):
-         # Check if the current step is in cache steps.
-         # If so, we can skip some Bn blocks and directly
-         # use the cached values.
-         return (
-             self.cache_manager.get_current_step()
-             in self.cache_manager.get_cached_steps()
-         ) or (
-             self.cache_manager.get_current_step()
-             in self.cache_manager.get_cfg_cached_steps()
-         )
-
-     @torch.compiler.disable
-     def _Fn_blocks(self):
-         # Select first `n` blocks to process the hidden states for
-         # more stable diff calculation.
-         # Fn: [0,...,n-1]
-         selected_Fn_blocks = self.transformer_blocks[
-             : self.cache_manager.Fn_compute_blocks()
-         ]
-         return selected_Fn_blocks
-
-     @torch.compiler.disable
-     def _Mn_blocks(self):  # middle blocks
-         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-         if self.cache_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-             selected_Mn_blocks = self.transformer_blocks[
-                 self.cache_manager.Fn_compute_blocks() :
-             ]
-         else:
-             selected_Mn_blocks = self.transformer_blocks[
-                 self.cache_manager.Fn_compute_blocks() : -self.cache_manager.Bn_compute_blocks()
-             ]
-         return selected_Mn_blocks
-
-     @torch.compiler.disable
-     def _Bn_blocks(self):
-         # Bn: transformer_blocks [N-n+1,...,N-1]
-         selected_Bn_blocks = self.transformer_blocks[
-             -self.cache_manager.Bn_compute_blocks() :
-         ]
-         return selected_Bn_blocks
-
-     def call_Fn_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         for block in self._Fn_blocks():
-             hidden_states = block(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             hidden_states, encoder_hidden_states = self._process_block_outputs(
-                 hidden_states, encoder_hidden_states
-             )
-
-         return hidden_states, encoder_hidden_states
-
-     def call_Mn_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         original_hidden_states = hidden_states
-         original_encoder_hidden_states = encoder_hidden_states
-         for block in self._Mn_blocks():
-             hidden_states = block(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             hidden_states, encoder_hidden_states = self._process_block_outputs(
-                 hidden_states, encoder_hidden_states
-             )
-
-         # compute hidden_states residual
-         hidden_states = hidden_states.contiguous()
-
-         hidden_states_residual = hidden_states - original_hidden_states
-
-         if (
-             encoder_hidden_states is not None
-             and original_encoder_hidden_states is not None
-         ):
-             encoder_hidden_states = encoder_hidden_states.contiguous()
-             encoder_hidden_states_residual = (
-                 encoder_hidden_states - original_encoder_hidden_states
-             )
-         else:
-             encoder_hidden_states_residual = None
-
-         return (
-             hidden_states,
-             encoder_hidden_states,
-             hidden_states_residual,
-             encoder_hidden_states_residual,
-         )
-
-     def call_Bn_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         if self.cache_manager.Bn_compute_blocks() == 0:
-             return hidden_states, encoder_hidden_states
-
-         for block in self._Bn_blocks():
-             hidden_states = block(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             hidden_states, encoder_hidden_states = self._process_block_outputs(
-                 hidden_states, encoder_hidden_states
-             )
-
-         return hidden_states, encoder_hidden_states
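
The deleted pattern_base.py above splits the transformer blocks into a first group (Fn), a middle group (Mn), and a last group (Bn); the Bn_compute_blocks() == 0 special case exists because a slice like x[Fn:-0] is empty in Python. Below is a small standalone illustration of that slicing logic, with made-up values and no dependency on the cache_dit API.

# Standalone sketch of the Fn/Mn/Bn split used in pattern_base.py; values are hypothetical.
blocks = list(range(10))  # stand-in for transformer_blocks
Fn, Bn = 2, 3             # stand-ins for Fn_compute_blocks() / Bn_compute_blocks()

Fn_blocks = blocks[:Fn]                                 # first Fn blocks: [0, 1]
Mn_blocks = blocks[Fn:-Bn] if Bn > 0 else blocks[Fn:]   # middle blocks; note blocks[Fn:-0] would be []
Bn_blocks = blocks[-Bn:] if Bn > 0 else []              # last Bn blocks: [7, 8, 9]

assert Fn_blocks + Mn_blocks + Bn_blocks == blocks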
cache_dit/cache_factory/cache_blocks/pattern_utils.py
@@ -1,41 +0,0 @@
- import torch
-
- from typing import Any
- from cache_dit.cache_factory import CachedContext
- from cache_dit.cache_factory import CachedContextManager
-
-
- def patch_cached_stats(
-     module: torch.nn.Module | Any,
-     cache_context: CachedContext | str = None,
-     cache_manager: CachedContextManager = None,
- ):
-     # Patch the cached stats to the module, the cached stats
-     # will be reset for each calling of pipe.__call__(**kwargs).
-     if module is None or cache_manager is None:
-         return
-
-     if cache_context is not None:
-         cache_manager.set_context(cache_context)
-
-     # TODO: Patch more cached stats to the module
-     module._cached_steps = cache_manager.get_cached_steps()
-     module._residual_diffs = cache_manager.get_residual_diffs()
-     module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
-     module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
-
-
- def remove_cached_stats(
-     module: torch.nn.Module | Any,
- ):
-     if module is None:
-         return
-
-     if hasattr(module, "_cached_steps"):
-         del module._cached_steps
-     if hasattr(module, "_residual_diffs"):
-         del module._residual_diffs
-     if hasattr(module, "_cfg_cached_steps"):
-         del module._cfg_cached_steps
-     if hasattr(module, "_cfg_residual_diffs"):
-         del module._cfg_residual_diffs
cache_dit/cache_factory/cache_contexts/__init__.py
@@ -1,15 +0,0 @@
- from cache_dit.cache_factory.cache_contexts.calibrators import (
-     Calibrator,
-     CalibratorBase,
-     CalibratorConfig,
-     TaylorSeerCalibratorConfig,
-     FoCaCalibratorConfig,
- )
- from cache_dit.cache_factory.cache_contexts.cache_context import (
-     CachedContext,
-     BasicCacheConfig,
- )
- from cache_dit.cache_factory.cache_contexts.cache_manager import (
-     CachedContextManager,
-     CacheNotExistError,
- )
cache_dit/cache_factory/patch_functors/__init__.py
@@ -1,15 +0,0 @@
- from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
- from cache_dit.cache_factory.patch_functors.functor_dit import DiTPatchFunctor
- from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
- from cache_dit.cache_factory.patch_functors.functor_chroma import (
-     ChromaPatchFunctor,
- )
- from cache_dit.cache_factory.patch_functors.functor_hidream import (
-     HiDreamPatchFunctor,
- )
- from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
-     HunyuanDiTPatchFunctor,
- )
- from cache_dit.cache_factory.patch_functors.functor_qwen_image_controlnet import (
-     QwenImageControlNetPatchFunctor,
- )
cache_dit-1.0.3.dist-info/RECORD
@@ -1,58 +0,0 @@
- cache_dit/__init__.py,sha256=sHRg0swXZZiw6lvSQ53fcVtN9JRayx0az2lXAz5OOGI,1510
- cache_dit/_version.py,sha256=l8k828IdTfzXAlmx4oT8GsiIf2eeMAlFDALjoYk-jrU,704
- cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
- cache_dit/utils.py,sha256=AyYRwi5XBxYBH4GaXxOxv9-X24Te_IYOYwh54t_1d3A,10674
- cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
- cache_dit/cache_factory/__init__.py,sha256=vy9I6Ofkj9jWeUoOvh-cY5a9QlDDKfj2FVPlVTf7BeA,1390
- cache_dit/cache_factory/cache_interface.py,sha256=fJgsOSR_lP0cvNDrR0zMLLoZBZC6tLAQaPQs_oo2R1o,12577
- cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
- cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
- cache_dit/cache_factory/params_modifier.py,sha256=zYJJsInTYCaYHBZ7mZJOP-PZnkSg3iN1WPewNOayXos,3628
- cache_dit/cache_factory/utils.py,sha256=mm8JNu6XG_w6nMYvv53TmugSb-l3W7l3Y4rJ2xBgktY,1891
- cache_dit/cache_factory/block_adapters/__init__.py,sha256=vM3aDMzPY79Tw4L0hlV2PdA3MFYomnf0eo0BGBo9P78,18087
- cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=2TVK_KqiYXC7AKZ2s07fzdOzUoeUBc9P1SzQtLVzhf4,22249
- cache_dit/cache_factory/block_adapters/block_registers.py,sha256=2L7QeM4ygnaKQpC9PoJod0QRYyxidUKU2AYpysDCUwE,2572
- cache_dit/cache_factory/cache_adapters/__init__.py,sha256=py71WGD3JztQ1uk6qdLVbzYcQ1rvqFidNNaQYo7tqTo,79
- cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=HTyZdspd34G6QiJ2qPNoLmGwcxmAnCwpAf91NTIQtl4,21442
- cache_dit/cache_factory/cache_blocks/__init__.py,sha256=mivvm8YOfqT7YHs8y_MzGOGztPw8LxAqKGXuSRXxCv0,3032
- cache_dit/cache_factory/cache_blocks/offload_utils.py,sha256=wusgcqaCrwEjvv7Guy-6VXhNOgPPUrBV2sSVuRmGuvo,3513
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=mzs1S2YFwNAPMMTisTKbU6GA5m60J_20CAVy9OIWoMQ,10652
- cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=UeBYEz3hamO3CyVMj1KI7GnxRVQGBjQ5EJi90obVZyI,16306
- cache_dit/cache_factory/cache_blocks/pattern_utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
- cache_dit/cache_factory/cache_contexts/__init__.py,sha256=N3SxFnluXk5q09nhSqKIJCVzEGWzySJWm-vic6dH79E,412
- cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=FXvrR3XZr4iIsKSTBngzaRM6_WxiHkRNQ3wAJz40kbk,15798
- cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=X99XnmiY-Us8D2pqJGPKxWcXAhQQpk3xdEWOOOYXIZ4,30465
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py,sha256=mzYXO8tbytGpJJ9rpPu20kMoj1Iu_7Ym9tjfzV8rA98,5574
- cache_dit/cache_factory/cache_contexts/calibrators/base.py,sha256=mn6ZBkChGpGwN5csrHTUGMoX6BBPvqHXSLbIExiW-EU,748
- cache_dit/cache_factory/cache_contexts/calibrators/foca.py,sha256=nhHGs_hxwW1M942BQDMJb9-9IuHdnOxp774Jrna1bJI,891
- cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py,sha256=aGxr9SpytYznTepDWGPAxWDnuVMSuNyn6uNXnLh2acQ,4001
- cache_dit/cache_factory/patch_functors/__init__.py,sha256=IJZrvSkeHbR_xW-6IzY7sqEhApBsOfPyorQGJutvWH0,652
- cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
- cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=xD0Q96VArp1vYBLQ0pcjRIyFB1i_Y7muZ2q07Hz8Oqs,13430
- cache_dit/cache_factory/patch_functors/functor_dit.py,sha256=SDjhzCWa6PoFNN4_upoQEf6DHvW1yJ7zuXMS2VvyJco,3904
- cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=UMkyuEYjO7UO_zmXi9Djd-nD-XMgCUgE-qkYA3plWSM,9559
- cache_dit/cache_factory/patch_functors/functor_hidream.py,sha256=inf4T5UcIa06zVsoLWCNJbb1bEDmGeBGSyC7OL1zpuc,15309
- cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py,sha256=iSo5dD5uKnjQQeysDUIkKt0wdnK5bzXTc_F_lfHG70w,6401
- cache_dit/cache_factory/patch_functors/functor_qwen_image_controlnet.py,sha256=D5i1Rrq1FQ49liupLcV2DW04moBqLnW9TICzfnMMzIU,10519
- cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
- cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
- cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/metrics/__init__.py,sha256=UjPJ69DyyjZDfERTpKAjZKOxOTx58aWnkze7VfH3en8,673
- cache_dit/metrics/clip_score.py,sha256=ERNCFQFJKzJdbIX9OAg-1LiSPuXUVHLOFxbf2gcENpc,3938
- cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
- cache_dit/metrics/fid.py,sha256=ZM_FM0XERtpnkMUfphmw2aOdljrh1uba-pnYItu0q6M,18219
- cache_dit/metrics/image_reward.py,sha256=N8HalJo1T1js0dsNb2V1KRv4kIdcm3nhx7iOXJuqcns,5421
- cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
- cache_dit/metrics/lpips.py,sha256=hrHrmdM-f2B4TKDs0xLqJO5JFaYcCjq2qNIR8oCrVkc,811
- cache_dit/metrics/metrics.py,sha256=AZbQyoavE-djvyRUZ_EfCIrWSQbiWQFo7n2dhn7XptE,40466
- cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
- cache_dit/quantize/quantize_ao.py,sha256=Pr3u3Qr6qLvFkd8k-_rfcz4Mkjlg36U9BHG2t6Bl-6M,6301
- cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
- cache_dit-1.0.3.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
- cache_dit-1.0.3.dist-info/METADATA,sha256=gPY4pnvl4dvTTu7Twv6unzEesu1fXCDlGNMlSdFP3Lc,28103
- cache_dit-1.0.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- cache_dit-1.0.3.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
- cache_dit-1.0.3.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
- cache_dit-1.0.3.dist-info/RECORD,,
cache_dit-1.0.3.dist-info/licenses/LICENSE
@@ -1,53 +0,0 @@
- # License
-
- ## Acceptance
-
- By using the software, you agree to all of the terms and conditions below.
-
- ## Copyright License
-
- The licensor grants you a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable license to use, copy, distribute, make available, and prepare derivative works of the software, in each case subject to the limitations and conditions below.
-
- ## Limitations
-
- You may not provide the software to third parties as a hosted or managed service, where the service provides users with access to any substantial set of the features or functionality of the software.
-
- You may not move, change, disable, or circumvent the license key functionality in the software, and you may not remove or obscure any functionality in the software that is protected by the license key.
-
- You may not alter, remove, or obscure any licensing, copyright, or other notices of the licensor in the software. Any use of the licensor’s trademarks is subject to applicable law.
-
- ## Patents
-
- The licensor grants you a license, under any patent claims the licensor can license, or becomes able to license, to make, have made, use, sell, offer for sale, import and have imported the software, in each case subject to the limitations and conditions in this license. This license does not cover any patent claims that you cause to be infringed by modifications or additions to the software. If you or your company make any written claim that the software infringes or contributes to infringement of any patent, your patent license for the software granted under these terms ends immediately. If your company makes such a claim, your patent license ends immediately for work on behalf of your company.
-
- ## Notices
-
- You must ensure that anyone who gets a copy of any part of the software from you also gets a copy of these terms.
-
- If you modify the software, you must include in any modified copies of the software prominent notices stating that you have modified the software.
- No Other Rights
-
- These terms do not imply any licenses other than those expressly granted in these terms.
-
- ## Termination
-
- If you use the software in violation of these terms, such use is not licensed, and your licenses will automatically terminate. If the licensor provides you with a notice of your violation, and you cease all violation of this license no later than 30 days after you receive that notice, your licenses will be reinstated retroactively. However, if you violate these terms after such reinstatement, any additional violation of these terms will cause your licenses to terminate automatically and permanently.
-
- ## No Liability
-
- As far as the law allows, the software comes as is, without any warranty or condition, and the licensor will not be liable to you for any damages arising out of these terms or the use or nature of the software, under any kind of legal claim.
- Definitions
-
- The licensor is the entity offering these terms, and the software is the software the licensor makes available under these terms, including any portion of it.
-
- ## Definitions
-
- you refers to the individual or entity agreeing to these terms.
-
- your company is any legal entity, sole proprietorship, or other kind of organization that you work for, plus all organizations that have control over, are under the control of, or are under common control with that organization. control means ownership of substantially all the assets of an entity, or the power to direct its management and policies by vote, contract, or otherwise. Control can be direct or indirect.
-
- your licenses are all the licenses granted to you for the software under these terms.
-
- use means anything you do with the software requiring one of your licenses.
-
- trademark means trademarks, service marks, and similar rights.