cache-dit 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. cache_dit/__init__.py +1 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +3 -6
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +8 -64
  5. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +47 -14
  7. cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
  8. cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
  9. cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
  10. cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
  11. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
  12. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
  13. cache_dit/cache_factory/cache_interface.py +128 -111
  14. cache_dit/cache_factory/params_modifier.py +87 -0
  15. cache_dit/metrics/__init__.py +3 -1
  16. cache_dit/utils.py +12 -21
  17. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/METADATA +78 -64
  18. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/RECORD +23 -28
  19. cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
  20. cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
  21. cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
  22. cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
  23. cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
  24. cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
  25. cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
  26. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
  27. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/WHEEL +0 -0
  28. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/entry_points.txt +0 -0
  29. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/licenses/LICENSE +0 -0
  30. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/top_level.txt +0 -0
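Note: items 11, 12, and 26 above move the calibrators package out of the v2 namespace (cache_contexts/v2/calibrators → cache_contexts/calibrators). A minimal migration sketch follows, assuming 0.3.2 re-exports the same public names (Calibrator, CalibratorBase, CalibratorConfig, as imported by the deleted module below) from the new path; this is an inference from the renames, not an API guarantee.

# 0.3.1 import path (removed in 0.3.2):
# from cache_dit.cache_factory.cache_contexts.v2.calibrators import (
#     Calibrator,
#     CalibratorBase,
#     CalibratorConfig,
# )

# 0.3.2 import path, per the renames above (assumed re-exports):
from cache_dit.cache_factory.cache_contexts.calibrators import (
    Calibrator,
    CalibratorBase,
    CalibratorConfig,
)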
cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py (deleted)
@@ -1,288 +0,0 @@
- import logging
- import dataclasses
- from collections import defaultdict
- from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
-
- import torch
-
- from cache_dit.cache_factory.cache_contexts.v2.calibrators import (
-     Calibrator,
-     CalibratorBase,
-     CalibratorConfig,
- )
- from cache_dit.logger import init_logger
-
- logger = init_logger(__name__)
-
-
- @dataclasses.dataclass
- class CachedContextV2:  # Internal CachedContext Impl class
-     name: str = "default"
-     # Dual Block Cache with flexible FnBn configuration.
-     Fn_compute_blocks: int = 1
-     Bn_compute_blocks: int = 0
-     # Non-compute blocks diff threshold; we don't skip the
-     # non-compute blocks if the diff >= threshold.
-     non_compute_blocks_diff_threshold: float = 0.08
-     max_Fn_compute_blocks: int = -1
-     max_Bn_compute_blocks: int = -1
-     # L1 hidden states or residual diff threshold for Fn
-     residual_diff_threshold: Union[torch.Tensor, float] = 0.05
-     l1_hidden_states_diff_threshold: Optional[float] = None
-     important_condition_threshold: float = 0.0
-
-     # Buffer for storing the residuals and other tensors
-     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
-     incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
-         default_factory=lambda: defaultdict(int),
-     )
-
-     # Other settings
-     downsample_factor: int = 1
-     num_inference_steps: int = -1  # for future use
-     max_warmup_steps: int = 0  # DON'T cache in warmup steps
-     # DON'T cache if the number of cached steps >= max_cached_steps
-     max_cached_steps: int = -1  # for both CFG and non-CFG
-     max_continuous_cached_steps: int = -1  # max continuous cached steps
-
-     # Record the steps that have been executed, both cached and non-cached.
-     executed_steps: int = 0  # cached + non-cached pipeline steps
-     # Steps for the transformer; with separate CFG,
-     # transformer_executed_steps will be double executed_steps.
-     transformer_executed_steps: int = 0
-
-     # Support calibrators in Dual Block Cache: TaylorSeer, FoCa, etc.
-     calibrator_config: Optional[CalibratorConfig] = None
-     calibrator: Optional[CalibratorBase] = None
-     encoder_calibrator: Optional[CalibratorBase] = None
-
-     # Support enable_separate_cfg for models such as Wan 2.1 and
-     # Qwen-Image. For models that fuse CFG and non-CFG into a single
-     # forward step, enable_separate_cfg should be set to False.
-     # For example: CogVideoX, HunyuanVideo, Mochi.
-     enable_separate_cfg: bool = False
-     # Whether to compute the CFG forward pass first. Default False, namely,
-     # 0, 2, 4, ... -> non-CFG steps; 1, 3, 5, ... -> CFG steps.
-     cfg_compute_first: bool = False
-     # Compute separate diff values for CFG and non-CFG steps,
-     # default True. If False, we reuse the diff computed in the
-     # current non-CFG transformer step for the current CFG step.
-     cfg_diff_compute_separate: bool = True
-     cfg_calibrator: Optional[CalibratorBase] = None
-     cfg_encoder_calibrator: Optional[CalibratorBase] = None
-
-     # CFG & non-CFG cached steps
-     cached_steps: List[int] = dataclasses.field(default_factory=list)
-     residual_diffs: DefaultDict[str, float] = dataclasses.field(
-         default_factory=lambda: defaultdict(float),
-     )
-     continuous_cached_steps: int = 0
-     cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
-     cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
-         default_factory=lambda: defaultdict(float),
-     )
-     cfg_continuous_cached_steps: int = 0
-
-     def __post_init__(self):
-         if logger.isEnabledFor(logging.DEBUG):
-             logger.info(f"Created _CacheContextV2: {self.name}")
-         # Some checks for settings
-         if self.enable_separate_cfg:
-             if self.cfg_diff_compute_separate:
-                 assert self.cfg_compute_first is False, (
-                     "cfg_compute_first must be False if "
-                     "cfg_diff_compute_separate is enabled."
-                 )
-
-         # Use the None-safe helpers here: calibrator_config defaults to
-         # None, so accessing its attributes directly would raise.
-         if self.enable_calibrator():
-             self.calibrator = Calibrator(self.calibrator_config)
-             if self.enable_separate_cfg:
-                 self.cfg_calibrator = Calibrator(self.calibrator_config)
-
-         if self.enable_encoder_calibrator():
-             self.encoder_calibrator = Calibrator(self.calibrator_config)
-             if self.enable_separate_cfg:
-                 self.cfg_encoder_calibrator = Calibrator(self.calibrator_config)
-
-     def enable_calibrator(self):
-         if self.calibrator_config is not None:
-             return self.calibrator_config.enable_calibrator
-         return False
-
-     def enable_encoder_calibrator(self):
-         if self.calibrator_config is not None:
-             return self.calibrator_config.enable_encoder_calibrator
-         return False
-
-     def calibrator_cache_type(self):
-         if self.calibrator_config is not None:
-             return self.calibrator_config.calibrator_cache_type
-         return "residual"
-
-     def get_residual_diff_threshold(self):
-         residual_diff_threshold = self.residual_diff_threshold
-         if self.l1_hidden_states_diff_threshold is not None:
-             # Use the L1 hidden states diff threshold if set
-             residual_diff_threshold = self.l1_hidden_states_diff_threshold
-         if isinstance(residual_diff_threshold, torch.Tensor):
-             residual_diff_threshold = residual_diff_threshold.item()
-         return residual_diff_threshold
-
-     def get_buffer(self, name):
-         return self.buffers.get(name)
-
-     def set_buffer(self, name, buffer):
-         self.buffers[name] = buffer
-
-     def remove_buffer(self, name):
-         if name in self.buffers:
-             del self.buffers[name]
-
-     def clear_buffers(self):
-         self.buffers.clear()
-
-     def mark_step_begin(self):
-         # Always increase transformer executed steps:
-         # prev 0 -> 1; prev 1 -> 2; current step = incremented step - 1.
-         self.transformer_executed_steps += 1
-         if not self.enable_separate_cfg:
-             self.executed_steps += 1
-         else:
-             # Transformer steps 0,1 -> pipeline step 0 + 1; 2,3 -> 1 + 1; ...
-             if not self.cfg_compute_first:
-                 if not self.is_separate_cfg_step():
-                     # transformer step: 0,2,4,...
-                     self.executed_steps += 1
-             else:
-                 if self.is_separate_cfg_step():
-                     # transformer step: 0,2,4,...
-                     self.executed_steps += 1
-
-         # Reset the cached steps and residual diffs at the beginning
-         # of each inference.
-         if self.get_current_transformer_step() == 0:
-             self.cached_steps.clear()
-             self.residual_diffs.clear()
-             self.cfg_cached_steps.clear()
-             self.cfg_residual_diffs.clear()
-             # Reset the calibrator caches at the beginning of each
-             # inference; reset_cache sets the calibrator's step to -1.
-             if self.enable_calibrator() or self.enable_encoder_calibrator():
-                 calibrator, encoder_calibrator = self.get_calibrators()
-                 if calibrator is not None:
-                     calibrator.reset_cache()
-                 if encoder_calibrator is not None:
-                     encoder_calibrator.reset_cache()
-                 cfg_calibrator, cfg_encoder_calibrator = (
-                     self.get_cfg_calibrators()
-                 )
-                 if cfg_calibrator is not None:
-                     cfg_calibrator.reset_cache()
-                 if cfg_encoder_calibrator is not None:
-                     cfg_encoder_calibrator.reset_cache()
-
-         # mark_step_begin of the calibrator must be called after the
-         # cache is reset.
-         if self.enable_calibrator() or self.enable_encoder_calibrator():
-             if self.enable_separate_cfg:
-                 # Assume non-CFG steps: 0, 2, 4, 6, ...
-                 if not self.is_separate_cfg_step():
-                     calibrator, encoder_calibrator = self.get_calibrators()
-                     if calibrator is not None:
-                         calibrator.mark_step_begin()
-                     if encoder_calibrator is not None:
-                         encoder_calibrator.mark_step_begin()
-                 else:
-                     cfg_calibrator, cfg_encoder_calibrator = (
-                         self.get_cfg_calibrators()
-                     )
-                     if cfg_calibrator is not None:
-                         cfg_calibrator.mark_step_begin()
-                     if cfg_encoder_calibrator is not None:
-                         cfg_encoder_calibrator.mark_step_begin()
-             else:
-                 calibrator, encoder_calibrator = self.get_calibrators()
-                 if calibrator is not None:
-                     calibrator.mark_step_begin()
-                 if encoder_calibrator is not None:
-                     encoder_calibrator.mark_step_begin()
-
-     def get_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
-         return self.calibrator, self.encoder_calibrator
-
-     def get_cfg_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
-         return self.cfg_calibrator, self.cfg_encoder_calibrator
-
-     def add_residual_diff(self, diff):
-         # step: executed_steps - 1, not transformer_steps - 1
-         step = str(self.get_current_step())
-         # Only add the diff if it is not already recorded for this step
-         if not self.is_separate_cfg_step():
-             if step not in self.residual_diffs:
-                 self.residual_diffs[step] = diff
-         else:
-             if step not in self.cfg_residual_diffs:
-                 self.cfg_residual_diffs[step] = diff
-
-     def get_residual_diffs(self):
-         return self.residual_diffs.copy()
-
-     def get_cfg_residual_diffs(self):
-         return self.cfg_residual_diffs.copy()
-
-     def add_cached_step(self):
-         curr_cached_step = self.get_current_step()
-         if not self.is_separate_cfg_step():
-             if self.cached_steps:
-                 prev_cached_step = self.cached_steps[-1]
-                 if curr_cached_step - prev_cached_step == 1:
-                     if self.continuous_cached_steps == 0:
-                         self.continuous_cached_steps += 2
-                     else:
-                         self.continuous_cached_steps += 1
-                 else:
-                     self.continuous_cached_steps += 1
-
-             self.cached_steps.append(curr_cached_step)
-         else:
-             if self.cfg_cached_steps:
-                 prev_cfg_cached_step = self.cfg_cached_steps[-1]
-                 if curr_cached_step - prev_cfg_cached_step == 1:
-                     if self.cfg_continuous_cached_steps == 0:
-                         self.cfg_continuous_cached_steps += 2
-                     else:
-                         self.cfg_continuous_cached_steps += 1
-                 else:
-                     self.cfg_continuous_cached_steps += 1
-
-             self.cfg_cached_steps.append(curr_cached_step)
-
-     def get_cached_steps(self):
-         return self.cached_steps.copy()
-
-     def get_cfg_cached_steps(self):
-         return self.cfg_cached_steps.copy()
-
-     def get_current_step(self):
-         return self.executed_steps - 1
-
-     def get_current_transformer_step(self):
-         return self.transformer_executed_steps - 1
-
-     def is_separate_cfg_step(self):
-         if not self.enable_separate_cfg:
-             return False
-         if self.cfg_compute_first:
-             # CFG steps: 0, 2, 4, 6, ...
-             return self.get_current_transformer_step() % 2 == 0
-         # CFG steps: 1, 3, 5, 7, ...
-         return self.get_current_transformer_step() % 2 != 0
-
-     def is_in_warmup(self):
-         return self.get_current_step() < self.max_warmup_steps
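The step bookkeeping in mark_step_begin and is_separate_cfg_step above is the subtlest part of the removed module. The following is a minimal standalone sketch of the parity rule (plain Python; the free function below is illustrative and not part of cache-dit):

# Illustration of CachedContextV2's step parity: with separate CFG, each
# pipeline step issues two transformer forwards, so the pipeline step is
# transformer_step // 2.

def is_separate_cfg_step(transformer_step: int, cfg_compute_first: bool) -> bool:
    # cfg_compute_first=False: 0, 2, 4, ... are non-CFG; 1, 3, 5, ... are CFG.
    # cfg_compute_first=True: the parity flips.
    if cfg_compute_first:
        return transformer_step % 2 == 0
    return transformer_step % 2 != 0

for t in range(6):
    kind = "CFG" if is_separate_cfg_step(t, cfg_compute_first=False) else "non-CFG"
    print(f"transformer step {t}: {kind}, pipeline step {t // 2}")
# Transformer steps 0 (non-CFG) and 1 (CFG) share pipeline step 0;
# 2 and 3 share step 1; 4 and 5 share step 2.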