cache-dit 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -6
- cache_dit/cache_factory/block_adapters/block_adapters.py +21 -64
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +82 -21
- cache_dit/cache_factory/cache_blocks/__init__.py +4 -0
- cache_dit/cache_factory/cache_blocks/offload_utils.py +115 -0
- cache_dit/cache_factory/cache_blocks/pattern_base.py +3 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
- cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
- cache_dit/cache_factory/cache_interface.py +128 -111
- cache_dit/cache_factory/params_modifier.py +87 -0
- cache_dit/metrics/__init__.py +3 -1
- cache_dit/utils.py +12 -21
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/METADATA +200 -434
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/RECORD +27 -31
- cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
- cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
- cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
- cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
- cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
- cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
- cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
- /cache_dit/cache_factory/cache_blocks/{utils.py → pattern_utils.py} +0 -0
- /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_contexts/cache_context.py

@@ -5,41 +5,120 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple

 import torch

-from cache_dit.cache_factory.cache_contexts.
+from cache_dit.cache_factory.cache_contexts.calibrators import (
+    Calibrator,
+    CalibratorBase,
+    CalibratorConfig,
+)
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)


 @dataclasses.dataclass
-class
-
-
-    Fn_compute_blocks: int
+class BasicCacheConfig:
+    # Dual Block Cache with Flexible FnBn configuration.
+
+    # Fn_compute_blocks: (`int`, *required*, defaults to 8):
+    #     Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
+    #     at time step t, enabling the calculation of a more stable L1 diff and delivering more
+    #     accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+    #     for more details of DBCache.
+    Fn_compute_blocks: int = 8
+    # Bn_compute_blocks: (`int`, *required*, defaults to 0):
+    #     Further fuses approximate information in the **last n** Transformer blocks to enhance
+    #     prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+    #     that use residual cache.
     Bn_compute_blocks: int = 0
-    #
-    #
-
-
-
-    #
-
+    # residual_diff_threshold (`float`, *required*, defaults to 0.08):
+    #     the value of residual diff threshold, a higher value leads to faster performance at the
+    #     cost of lower precision.
+    residual_diff_threshold: Union[torch.Tensor, float] = 0.08
+    # max_warmup_steps (`int`, *required*, defaults to 8):
+    #     DBCache does not apply the caching strategy when the number of running steps is less than
+    #     or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+    max_warmup_steps: int = 8  # DON'T Cache in warmup steps
+    # max_cached_steps (`int`, *required*, defaults to -1):
+    #     DBCache disables the caching strategy when the previous cached steps exceed this value to
+    #     prevent precision degradation.
+    max_cached_steps: int = -1  # for both CFG and non-CFG
+    # max_continuous_cached_steps (`int`, *required*, defaults to -1):
+    #     DBCache disables the caching strategy when the previous continous cached steps exceed this value to
+    #     prevent precision degradation.
+    max_continuous_cached_steps: int = -1  # the max continuous cached steps
+    # enable_separate_cfg (`bool`, *required*, defaults to None):
+    #     Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
+    #     and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
+    #     CogVideoX, HunyuanVideo, Mochi, etc.
+    enable_separate_cfg: Optional[bool] = None
+    # cfg_compute_first (`bool`, *required*, defaults to False):
+    #     Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
+    #     1, 3, 5, ... -> CFG step.
+    cfg_compute_first: bool = False
+    # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+    #     Compute separate diff values for CFG and non-CFG step, default True. If False, we will
+    #     use the computed diff from current non-CFG transformer step for current CFG step.
+    cfg_diff_compute_separate: bool = True
+
+    def update(self, **kwargs) -> "BasicCacheConfig":
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+        return self
+
+    def strify(self) -> str:
+        return (
+            f"DBCACHE_F{self.Fn_compute_blocks}"
+            f"B{self.Bn_compute_blocks}_"
+            f"W{self.max_warmup_steps}"
+            f"M{max(0, self.max_cached_steps)}"
+            f"MC{max(0, self.max_continuous_cached_steps)}_"
+            f"R{self.residual_diff_threshold}"
+        )
+
+
+@dataclasses.dataclass
+class ExtraCacheConfig:
+    # Some other not very important settings for Dual Block Cache.
+    # NOTE: These flags maybe deprecated in the future and users
+    # should never use these extra configurations in their cases.
+
+    # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
+    #     The hidden states diff threshold for DBCache if use hidden_states as
+    #     cache (not residual).
     l1_hidden_states_diff_threshold: float = None
+    # important_condition_threshold (`float`, *optional*, defaults to 0.0):
+    #     Only select the most important tokens while calculating the l1 diff.
     important_condition_threshold: float = 0.0
+    # downsample_factor (`int`, *optional*, defaults to 1):
+    #     Downsample factor for Fn buffer, in order the save GPU memory.
+    downsample_factor: int = 1
+    # num_inference_steps (`int`, *optional*, defaults to -1):
+    #     num_inference_steps for DiffusionPipeline, for future use.
+    num_inference_steps: int = -1

+
+@dataclasses.dataclass
+class CachedContext:
+    name: str = "default"
     # Buffer for storing the residuals and other tensors
     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
-
-
+    # Basic Dual Block Cache Config
+    cache_config: BasicCacheConfig = dataclasses.field(
+        default_factory=BasicCacheConfig,
     )
+    # NOTE: Users should never use these extra configurations.
+    extra_cache_config: ExtraCacheConfig = dataclasses.field(
+        default_factory=ExtraCacheConfig,
+    )
+    # Calibrator config for Dual Block Cache: TaylorSeer, FoCa, etc.
+    calibrator_config: Optional[CalibratorConfig] = None

-    #
-
-
-
-
-    max_cached_steps: int = -1  # for both CFG and non-CFG
-    max_continuous_cached_steps: int = -1  # the max continuous cached steps
+    # Calibrators for both CFG and non-CFG
+    calibrator: Optional[CalibratorBase] = None
+    encoder_calibrator: Optional[CalibratorBase] = None
+    cfg_calibrator: Optional[CalibratorBase] = None
+    cfg_encoder_calibrator: Optional[CalibratorBase] = None

     # Record the steps that have been cached, both cached and non-cache
     executed_steps: int = 0  # cache + non-cache steps pippeline
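Since `BasicCacheConfig` is a plain dataclass, the `update()` and `strify()` helpers added above compose directly. A minimal usage sketch, assuming cache-dit 0.3.3 is installed and that the class is importable from the module path shown in the file list:

```python
# Usage sketch for the new BasicCacheConfig (import path assumed from the
# file list above; the class itself is defined in this diff).
from cache_dit.cache_factory.cache_contexts.cache_context import BasicCacheConfig

cfg = BasicCacheConfig()  # defaults: F8B0, 8 warmup steps, threshold 0.08
cfg.update(Fn_compute_blocks=4, residual_diff_threshold=0.12)  # sets known fields, returns self
print(cfg.strify())
# -> DBCACHE_F4B0_W8M0MC0_R0.12
# max(0, ...) clamps the -1 defaults of max_cached_steps and
# max_continuous_cached_steps, so "unlimited" renders as M0/MC0.
```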
@@ -47,32 +126,6 @@ class CachedContext: # Internal CachedContext Impl class
     # be double of executed_steps.
     transformer_executed_steps: int = 0

-    # Support TaylorSeers in Dual Block Cache
-    # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
-    # Url: https://arxiv.org/pdf/2503.06923
-    enable_taylorseer: bool = False
-    enable_encoder_taylorseer: bool = False
-    taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
-    taylorseer_order: int = 1  # The order for TaylorSeer
-    taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
-    taylorseer: Optional[TaylorSeer] = None
-    encoder_tarlorseer: Optional[TaylorSeer] = None
-
-    # Support enable_separate_cfg, such as Wan 2.1,
-    # Qwen-Image. For model that fused CFG and non-CFG into single
-    # forward step, should set enable_separate_cfg as False.
-    # For example: CogVideoX, HunyuanVideo, Mochi.
-    enable_separate_cfg: bool = False
-    # Compute cfg forward first or not, default False, namely,
-    # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
-    cfg_compute_first: bool = False
-    # Compute separate diff values for CFG and non-CFG step,
-    # default True. If False, we will use the computed diff from
-    # current non-CFG transformer step for current CFG step.
-    cfg_diff_compute_separate: bool = True
-    cfg_taylorseer: Optional[TaylorSeer] = None
-    cfg_encoder_taylorseer: Optional[TaylorSeer] = None
-
     # CFG & non-CFG cached steps
     cached_steps: List[int] = dataclasses.field(default_factory=list)
     residual_diffs: DefaultDict[str, float] = dataclasses.field(
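The hunk above deletes the per-field TaylorSeer switches from the context; their role moves to the single optional `calibrator_config` added in the previous hunk. A hedged sketch of the resulting gating, using a hypothetical stand-in for `CalibratorConfig` (the real class lives in the new `calibrators/__init__.py`, which this diff only imports):

```python
import dataclasses
from typing import Optional

@dataclasses.dataclass
class StubCalibratorConfig:  # hypothetical stand-in, NOT the real CalibratorConfig API
    enable_calibrator: bool = True
    enable_encoder_calibrator: bool = False

@dataclasses.dataclass
class Ctx:  # mirrors the has_calibrators() gating added in the next hunk
    calibrator_config: Optional[StubCalibratorConfig] = None

    def has_calibrators(self) -> bool:
        if self.calibrator_config is not None:
            return (
                self.calibrator_config.enable_calibrator
                or self.calibrator_config.enable_encoder_calibrator
            )
        return False

assert Ctx().has_calibrators() is False  # no config -> no calibrators at all
assert Ctx(StubCalibratorConfig()).has_calibrators() is True
```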
@@ -87,42 +140,58 @@ class CachedContext: # Internal CachedContext Impl class

     def __post_init__(self):
         if logger.isEnabledFor(logging.DEBUG):
-            logger.info(f"Created
+            logger.info(f"Created CachedContext: {self.name}")
         # Some checks for settings
-        if self.enable_separate_cfg:
-            if self.cfg_diff_compute_separate:
-                assert self.cfg_compute_first is False, (
+        if self.cache_config.enable_separate_cfg:
+            if self.cache_config.cfg_diff_compute_separate:
+                assert self.cache_config.cfg_compute_first is False, (
                     "cfg_compute_first must set as False if "
                     "cfg_diff_compute_separate is enabled."
                 )

-        if
-
-
-
-
-
-
-
-
-
-
-            if self.enable_separate_cfg:
-                self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+        if self.calibrator_config is not None:
+            if self.calibrator_config.enable_calibrator:
+                self.calibrator = Calibrator(self.calibrator_config)
+                if self.cache_config.enable_separate_cfg:
+                    self.cfg_calibrator = Calibrator(self.calibrator_config)
+
+            if self.calibrator_config.enable_encoder_calibrator:
+                self.encoder_calibrator = Calibrator(self.calibrator_config)
+                if self.cache_config.enable_separate_cfg:
+                    self.cfg_encoder_calibrator = Calibrator(
+                        self.calibrator_config
+                    )

-
-
-
-
-
-
+    def enable_calibrator(self):
+        if self.calibrator_config is not None:
+            return self.calibrator_config.enable_calibrator
+        return False
+
+    def enable_encoder_calibrator(self):
+        if self.calibrator_config is not None:
+            return self.calibrator_config.enable_encoder_calibrator
+        return False
+
+    def calibrator_cache_type(self):
+        if self.calibrator_config is not None:
+            return self.calibrator_config.calibrator_cache_type
+        return "residual"
+
+    def has_calibrators(self) -> bool:
+        if self.calibrator_config is not None:
+            return (
+                self.calibrator_config.enable_calibrator
+                or self.calibrator_config.enable_encoder_calibrator
+            )
+        return False

     def get_residual_diff_threshold(self):
-        residual_diff_threshold = self.residual_diff_threshold
-        if self.l1_hidden_states_diff_threshold is not None:
+        residual_diff_threshold = self.cache_config.residual_diff_threshold
+        if self.extra_cache_config.l1_hidden_states_diff_threshold is not None:
             # Use the L1 hidden states diff threshold if set
-            residual_diff_threshold =
+            residual_diff_threshold = (
+                self.extra_cache_config.l1_hidden_states_diff_threshold
+            )
         if isinstance(residual_diff_threshold, torch.Tensor):
             residual_diff_threshold = residual_diff_threshold.item()
         return residual_diff_threshold
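Restated standalone, the threshold resolution in `get_residual_diff_threshold` above prefers the extra config's L1 hidden-states threshold when it is set, then unwraps tensor thresholds to Python floats (a sketch that mirrors the method body, not the packaged implementation):

```python
import torch

def resolve_threshold(residual_diff_threshold, l1_hidden_states_diff_threshold=None):
    # l1_hidden_states_diff_threshold, when set, overrides the basic
    # residual threshold (used when caching hidden_states, not residuals).
    threshold = residual_diff_threshold
    if l1_hidden_states_diff_threshold is not None:
        threshold = l1_hidden_states_diff_threshold
    if isinstance(threshold, torch.Tensor):
        threshold = threshold.item()  # unwrap to a plain float
    return threshold

assert resolve_threshold(0.08) == 0.08
assert resolve_threshold(torch.tensor(0.5)) == 0.5
assert resolve_threshold(0.08, 0.12) == 0.12  # override wins
```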
@@ -145,11 +214,11 @@ class CachedContext: # Internal CachedContext Impl class
         # incr step: prev 0 -> 1; prev 1 -> 2
         # current step: incr step - 1
         self.transformer_executed_steps += 1
-        if not self.enable_separate_cfg:
+        if not self.cache_config.enable_separate_cfg:
             self.executed_steps += 1
         else:
             # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
-            if not self.cfg_compute_first:
+            if not self.cache_config.cfg_compute_first:
                 if not self.is_separate_cfg_step():
                     # transformer step: 0,2,4,...
                     self.executed_steps += 1
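With `enable_separate_cfg` and the default `cfg_compute_first=False`, each pipeline step issues two transformer forwards and only the even (non-CFG) forwards advance `executed_steps`. A standalone trace of the counting above (sketch):

```python
transformer_executed_steps = 0  # every transformer forward
executed_steps = 0              # pipeline steps only
for _ in range(6):              # six forwards = three pipeline steps
    transformer_executed_steps += 1
    current = transformer_executed_steps - 1
    is_cfg_forward = current % 2 != 0  # CFG forwards: 1, 3, 5, ...
    if not is_cfg_forward:
        executed_steps += 1

print(transformer_executed_steps, executed_steps)  # 6 3
```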
@@ -165,52 +234,52 @@ class CachedContext: # Internal CachedContext Impl class
         self.residual_diffs.clear()
         self.cfg_cached_steps.clear()
         self.cfg_residual_diffs.clear()
-        # Reset the
-        # reset_cache will set the current step to -1 for
-        if self.
-
-            if
-
-            if
-
-
-                self.
+        # Reset the calibrators cache at the beginning of each inference.
+        # reset_cache will set the current step to -1 for calibrator,
+        if self.has_calibrators():
+            calibrator, encoder_calibrator = self.get_calibrators()
+            if calibrator is not None:
+                calibrator.reset_cache()
+            if encoder_calibrator is not None:
+                encoder_calibrator.reset_cache()
+            cfg_calibrator, cfg_encoder_calibrator = (
+                self.get_cfg_calibrators()
             )
-            if
-
-            if
-
-
-        # mark_step_begin of
-        if self.
-            if self.enable_separate_cfg:
+            if cfg_calibrator is not None:
+                cfg_calibrator.reset_cache()
+            if cfg_encoder_calibrator is not None:
+                cfg_encoder_calibrator.reset_cache()
+
+        # mark_step_begin of calibrator must be called after the cache is reset.
+        if self.has_calibrators():
+            if self.cache_config.enable_separate_cfg:
                 # Assume non-CFG steps: 0, 2, 4, 6, ...
                 if not self.is_separate_cfg_step():
-
-                    if
-
-                    if
-
+                    calibrator, encoder_calibrator = self.get_calibrators()
+                    if calibrator is not None:
+                        calibrator.mark_step_begin()
+                    if encoder_calibrator is not None:
+                        encoder_calibrator.mark_step_begin()
                 else:
-
-                    self.
+                    cfg_calibrator, cfg_encoder_calibrator = (
+                        self.get_cfg_calibrators()
                     )
-                    if
-
-                    if
-
+                    if cfg_calibrator is not None:
+                        cfg_calibrator.mark_step_begin()
+                    if cfg_encoder_calibrator is not None:
+                        cfg_encoder_calibrator.mark_step_begin()
             else:
-
-                if
-
-                if
-
+                calibrator, encoder_calibrator = self.get_calibrators()
+                if calibrator is not None:
+                    calibrator.mark_step_begin()
+                if encoder_calibrator is not None:
+                    encoder_calibrator.mark_step_begin()

-    def
-        return self.
+    def get_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
+        return self.calibrator, self.encoder_calibrator

-    def
-        return self.
+    def get_cfg_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
+        return self.cfg_calibrator, self.cfg_encoder_calibrator

     def add_residual_diff(self, diff):
         # step: executed_steps - 1, not transformer_steps - 1
@@ -269,13 +338,13 @@ class CachedContext: # Internal CachedContext Impl class
         return self.transformer_executed_steps - 1

     def is_separate_cfg_step(self):
-        if not self.enable_separate_cfg:
+        if not self.cache_config.enable_separate_cfg:
             return False
-        if self.cfg_compute_first:
+        if self.cache_config.cfg_compute_first:
             # CFG steps: 0, 2, 4, 6, ...
             return self.get_current_transformer_step() % 2 == 0
         # CFG steps: 1, 3, 5, 7, ...
         return self.get_current_transformer_step() % 2 != 0

     def is_in_warmup(self):
-        return self.get_current_step() < self.max_warmup_steps
+        return self.get_current_step() < self.cache_config.max_warmup_steps