cache-dit 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -6
- cache_dit/cache_factory/block_adapters/block_adapters.py +21 -64
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +82 -21
- cache_dit/cache_factory/cache_blocks/__init__.py +4 -0
- cache_dit/cache_factory/cache_blocks/offload_utils.py +115 -0
- cache_dit/cache_factory/cache_blocks/pattern_base.py +3 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
- cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
- cache_dit/cache_factory/cache_interface.py +128 -111
- cache_dit/cache_factory/params_modifier.py +87 -0
- cache_dit/metrics/__init__.py +3 -1
- cache_dit/utils.py +12 -21
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/METADATA +200 -434
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/RECORD +27 -31
- cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
- cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
- cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
- cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
- cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
- cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
- cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
- /cache_dit/cache_factory/cache_blocks/{utils.py → pattern_utils.py} +0 -0
- /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/top_level.txt +0 -0
|
@@ -1,102 +0,0 @@
|
|
|
1
|
-
import math
|
|
2
|
-
import torch
|
|
3
|
-
from typing import List, Dict
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class TaylorSeer:
    """Forecast a per-step tensor with a finite-difference Taylor series.

    On "full" steps the true value ``Y`` is recorded and finite-difference
    estimates of its derivatives are refreshed; on every other step the
    value is extrapolated from the last full step with a truncated Taylor
    expansion instead of being computed.

    Args:
        n_derivatives: Highest derivative order that is tracked.
        max_warmup_steps: Steps ``0 .. max_warmup_steps - 1`` are always full.
        skip_interval_steps: After warmup, a step is full whenever
            ``(step - max_warmup_steps + 1)`` is a multiple of this value.
        compute_step_map: Optional indexable of booleans; when given, it
            alone decides (``compute_step_map[step]``) whether a step is full.
    """

    def __init__(
        self,
        n_derivatives=2,
        max_warmup_steps=1,
        skip_interval_steps=1,
        compute_step_map=None,
    ):
        self.n_derivatives = n_derivatives
        # Stored terms per snapshot: the value itself plus each derivative.
        self.ORDER = n_derivatives + 1
        self.max_warmup_steps = max_warmup_steps
        self.skip_interval_steps = skip_interval_steps
        self.compute_step_map = compute_step_map
        self.reset_cache()

    def reset_cache(self):
        """Drop every stored term and rewind both step counters to -1."""
        blank = [None] * self.ORDER
        self.state: Dict[str, List[torch.Tensor]] = {
            "dY_prev": list(blank),
            "dY_current": list(blank),
        }
        self.current_step = -1
        self.last_non_approximated_step = -1

    def should_compute_full(self, step=None):
        """Return True when ``step`` (default: the current step) is a full step."""
        if step is None:
            step = self.current_step
        if self.compute_step_map is not None:
            return self.compute_step_map[step]
        # Short-circuit: the interval test is only evaluated after warmup.
        return (
            step < self.max_warmup_steps
            or (step - self.max_warmup_steps + 1) % self.skip_interval_steps == 0
        )

    def approximate_derivative(self, Y: torch.Tensor) -> List[torch.Tensor]:
        """Build the term list ``[Y, dY, d2Y, ...]`` for a freshly computed ``Y``.

        Term ``k + 1`` is the finite difference of term ``k`` against the
        previous snapshot, divided by the number of steps since the last
        full step. Derivatives are only formed once ``current_step > 1`` and
        stop at the first missing previous term. If ``Y`` changes shape, the
        whole cache is reset (including the step counters) and only ``Y`` is
        kept. The Taylor expansion being approximated is
        ``Y(t) = Y(0) + dY(0) t + d2Y(0) t^2 / 2! + ... + dnY(0) t^n / n!``.
        """
        # TODO: Custom Triton/CUDA kernel for better performance,
        # especially for large n_derivatives.
        terms: List[torch.Tensor] = [None] * self.ORDER
        terms[0] = Y
        window = self.current_step - self.last_non_approximated_step

        anchor = self.state["dY_prev"][0]
        if anchor is not None and Y.shape != anchor.shape:
            self.reset_cache()

        for order in range(self.n_derivatives):
            # Re-read state each time: reset_cache() above replaces it.
            previous = self.state["dY_prev"][order]
            if previous is None or self.current_step <= 1:
                break
            terms[order + 1] = (terms[order] - previous) / window
        return terms

    def approximate_value(self) -> torch.Tensor:
        """Extrapolate the value at the current step from the stored terms."""
        # TODO: Custom Triton/CUDA kernel for better performance,
        # especially for large n_derivatives.
        elapsed = self.current_step - self.last_non_approximated_step
        output = 0
        for order, term in enumerate(self.state["dY_current"]):
            if term is None:
                break
            output += (1 / math.factorial(order)) * term * (elapsed**order)
        return output

    def mark_step_begin(self):
        """Advance the internal step counter by one."""
        self.current_step += 1

    def update(self, Y: torch.Tensor):
        """Record ``Y`` as a full step, bypassing the warmup/skip policy.

        The current snapshot becomes the previous one, and a new snapshot is
        derived from ``Y``. With every step full, ``dY_current`` holds
        ``[Y]`` at steps 0 and 1. It gains one derivative order per step
        from step 2 onward, up to ``n_derivatives``.
        """
        self.state["dY_prev"] = self.state["dY_current"]
        self.state["dY_current"] = self.approximate_derivative(Y)
        self.last_non_approximated_step = self.current_step

    def step(self, Y: torch.Tensor):
        """Advance one step, returning ``Y`` on full steps, else a forecast."""
        self.mark_step_begin()
        if not self.should_compute_full():
            return self.approximate_value()
        self.update(Y)
        return Y
|
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
from cache_dit.cache_factory.cache_contexts.v2.calibrators import (
|
|
2
|
-
Calibrator,
|
|
3
|
-
CalibratorBase,
|
|
4
|
-
CalibratorConfig,
|
|
5
|
-
TaylorSeerCalibratorConfig,
|
|
6
|
-
FoCaCalibratorConfig,
|
|
7
|
-
)
|
|
8
|
-
from cache_dit.cache_factory.cache_contexts.v2.cache_context_v2 import (
|
|
9
|
-
CachedContextV2,
|
|
10
|
-
)
|
|
11
|
-
from cache_dit.cache_factory.cache_contexts.v2.cache_manager_v2 import (
|
|
12
|
-
CachedContextManagerV2,
|
|
13
|
-
)
|
|
@@ -1,288 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
import dataclasses
|
|
3
|
-
from collections import defaultdict
|
|
4
|
-
from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
|
|
5
|
-
|
|
6
|
-
import torch
|
|
7
|
-
|
|
8
|
-
from cache_dit.cache_factory.cache_contexts.v2.calibrators import (
|
|
9
|
-
Calibrator,
|
|
10
|
-
CalibratorBase,
|
|
11
|
-
CalibratorConfig,
|
|
12
|
-
)
|
|
13
|
-
from cache_dit.logger import init_logger
|
|
14
|
-
|
|
15
|
-
# Module-level logger, created via cache_dit.logger.init_logger.
logger = init_logger(__name__)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
@dataclasses.dataclass
class CachedContextV2:  # Internal CachedContext Impl class
    """Mutable state for one Dual Block Cache (FnBn) context.

    Holds the caching thresholds and limits, a named buffer store, step
    counters, per-step bookkeeping (cached steps, residual diffs) and
    optional calibrators (e.g. TaylorSeer, FoCa) built from
    ``calibrator_config``. When ``enable_separate_cfg`` is set, CFG and
    non-CFG transformer calls are tracked separately and each gets its
    own calibrators. ``calibrator_config`` may be left as None, in which
    case no calibrators are created or driven.
    """

    name: str = "default"
    # Dual Block Cache with flexible FnBn configuration.
    Fn_compute_blocks: int = 1
    Bn_compute_blocks: int = 0
    # Diff threshold for the non-compute blocks: they are NOT skipped
    # when the diff >= threshold.
    non_compute_blocks_diff_threshold: float = 0.08
    max_Fn_compute_blocks: int = -1
    max_Bn_compute_blocks: int = -1
    # L1 hidden states or residual diff threshold for Fn. When
    # l1_hidden_states_diff_threshold is set it takes precedence
    # (see get_residual_diff_threshold).
    residual_diff_threshold: Union[torch.Tensor, float] = 0.05
    l1_hidden_states_diff_threshold: Optional[float] = None
    important_condition_threshold: float = 0.0

    # Buffer for storing the residuals and other tensors
    buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int),
    )

    # Other settings
    downsample_factor: int = 1
    num_inference_steps: int = -1  # for future use
    max_warmup_steps: int = 0  # DON'T Cache in warmup steps
    # DON'T Cache if the number of cached steps >= max_cached_steps
    max_cached_steps: int = -1  # for both CFG and non-CFG
    max_continuous_cached_steps: int = -1  # the max continuous cached steps

    # Step counters: executed_steps counts pipeline steps (cached and
    # non-cached); transformer_executed_steps counts transformer calls,
    # which is double executed_steps when CFG runs as a separate call.
    executed_steps: int = 0
    transformer_executed_steps: int = 0

    # Support calibrators in Dual Block Cache: TaylorSeer, FoCa, etc.
    calibrator_config: Optional[CalibratorConfig] = None
    calibrator: Optional[CalibratorBase] = None
    encoder_calibrator: Optional[CalibratorBase] = None

    # Support enable_separate_cfg, such as Wan 2.1, Qwen-Image. For models
    # that fuse CFG and non-CFG into a single forward step, set
    # enable_separate_cfg to False, e.g. CogVideoX, HunyuanVideo, Mochi.
    enable_separate_cfg: bool = False
    # Compute the CFG forward first or not. Default False, namely
    # transformer steps 0, 2, 4, ... are non-CFG and 1, 3, 5, ... are CFG.
    cfg_compute_first: bool = False
    # Compute separate diff values for CFG and non-CFG steps (default True).
    # If False, the diff computed for the current non-CFG transformer step
    # is used for the current CFG step.
    cfg_diff_compute_separate: bool = True
    cfg_calibrator: Optional[CalibratorBase] = None
    cfg_encoder_calibrator: Optional[CalibratorBase] = None

    # CFG & non-CFG cached steps
    cached_steps: List[int] = dataclasses.field(default_factory=list)
    residual_diffs: DefaultDict[str, float] = dataclasses.field(
        default_factory=lambda: defaultdict(float),
    )
    continuous_cached_steps: int = 0
    cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
    cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
        default_factory=lambda: defaultdict(float),
    )
    cfg_continuous_cached_steps: int = 0

    def __post_init__(self):
        """Validate CFG settings and build the configured calibrators."""
        if logger.isEnabledFor(logging.DEBUG):
            logger.info(f"Created _CacheContextV2: {self.name}")
        # Some checks for settings
        if self.enable_separate_cfg and self.cfg_diff_compute_separate:
            assert self.cfg_compute_first is False, (
                "cfg_compute_first must set as False if "
                "cfg_diff_compute_separate is enabled."
            )

        # Use the None-safe helpers: calibrator_config defaults to None,
        # and dereferencing it directly would raise AttributeError.
        if self.enable_calibrator():
            self.calibrator = Calibrator(self.calibrator_config)
            if self.enable_separate_cfg:
                self.cfg_calibrator = Calibrator(self.calibrator_config)

        if self.enable_encoder_calibrator():
            self.encoder_calibrator = Calibrator(self.calibrator_config)
            if self.enable_separate_cfg:
                self.cfg_encoder_calibrator = Calibrator(
                    self.calibrator_config
                )

    def enable_calibrator(self):
        """Return whether the (hidden states) calibrator is enabled."""
        if self.calibrator_config is not None:
            return self.calibrator_config.enable_calibrator
        return False

    def enable_encoder_calibrator(self):
        """Return whether the encoder calibrator is enabled."""
        if self.calibrator_config is not None:
            return self.calibrator_config.enable_encoder_calibrator
        return False

    def calibrator_cache_type(self):
        """Return the configured calibrator cache type ("residual" if unset)."""
        if self.calibrator_config is not None:
            return self.calibrator_config.calibrator_cache_type
        return "residual"

    def get_residual_diff_threshold(self):
        """Return the active diff threshold as a Python number.

        ``l1_hidden_states_diff_threshold`` wins over
        ``residual_diff_threshold`` when set; tensor thresholds are
        converted with ``.item()``.
        """
        residual_diff_threshold = self.residual_diff_threshold
        if self.l1_hidden_states_diff_threshold is not None:
            # Use the L1 hidden states diff threshold if set
            residual_diff_threshold = self.l1_hidden_states_diff_threshold
        if isinstance(residual_diff_threshold, torch.Tensor):
            residual_diff_threshold = residual_diff_threshold.item()
        return residual_diff_threshold

    def get_buffer(self, name):
        """Return the buffer stored under ``name``, or None."""
        return self.buffers.get(name)

    def set_buffer(self, name, buffer):
        """Store ``buffer`` under ``name``, replacing any previous value."""
        self.buffers[name] = buffer

    def remove_buffer(self, name):
        """Delete the buffer stored under ``name`` if present."""
        if name in self.buffers:
            del self.buffers[name]

    def clear_buffers(self):
        """Delete all stored buffers."""
        self.buffers.clear()

    def mark_step_begin(self):
        """Advance the step counters for a new transformer call.

        ``transformer_executed_steps`` always increments. ``executed_steps``
        increments once per pipeline step: on every call without separate
        CFG, otherwise only on the first call of each CFG/non-CFG pair. At
        transformer step 0 the per-inference bookkeeping and all
        calibrators are reset; afterwards the calibrators that belong to
        this call (CFG or non-CFG) get their ``mark_step_begin``.
        """
        # incr step: prev 0 -> 1; prev 1 -> 2; current step: incr step - 1
        self.transformer_executed_steps += 1
        if not self.enable_separate_cfg:
            self.executed_steps += 1
        elif not self.cfg_compute_first:
            # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...; non-CFG call opens the pair.
            if not self.is_separate_cfg_step():
                self.executed_steps += 1
        elif self.is_separate_cfg_step():
            # CFG call opens the pair when cfg_compute_first is set.
            self.executed_steps += 1

        calibrators_enabled = (
            self.enable_calibrator() or self.enable_encoder_calibrator()
        )

        # Reset the cached steps and residual diffs at the beginning
        # of each inference.
        if self.get_current_transformer_step() == 0:
            self.cached_steps.clear()
            self.residual_diffs.clear()
            self.cfg_cached_steps.clear()
            self.cfg_residual_diffs.clear()
            if calibrators_enabled:
                self._reset_calibrators()

        # mark_step_begin of calibrator must be called after the cache is reset.
        if calibrators_enabled:
            self._mark_calibrators_step_begin()

    def _reset_calibrators(self):
        """Reset every existing calibrator: main, encoder, then CFG twins."""
        for calibrator in (*self.get_calibrators(), *self.get_cfg_calibrators()):
            if calibrator is not None:
                calibrator.reset_cache()

    def _mark_calibrators_step_begin(self):
        """Call mark_step_begin on the calibrators owning this call."""
        # is_separate_cfg_step() is always False without separate CFG,
        # so this picks the non-CFG calibrators in that case.
        if self.is_separate_cfg_step():
            calibrators = self.get_cfg_calibrators()
        else:
            calibrators = self.get_calibrators()
        for calibrator in calibrators:
            if calibrator is not None:
                calibrator.mark_step_begin()

    def get_calibrators(
        self,
    ) -> Tuple[Optional[CalibratorBase], Optional[CalibratorBase]]:
        """Return ``(calibrator, encoder_calibrator)``; either may be None."""
        return self.calibrator, self.encoder_calibrator

    def get_cfg_calibrators(
        self,
    ) -> Tuple[Optional[CalibratorBase], Optional[CalibratorBase]]:
        """Return ``(cfg_calibrator, cfg_encoder_calibrator)``; may be None."""
        return self.cfg_calibrator, self.cfg_encoder_calibrator

    def add_residual_diff(self, diff):
        """Record ``diff`` for the current step, keeping the first value seen.

        The key is ``str(executed_steps - 1)`` (pipeline step, not transformer
        step); CFG calls go to ``cfg_residual_diffs``.
        """
        step = str(self.get_current_step())
        if not self.is_separate_cfg_step():
            self.residual_diffs.setdefault(step, diff)
        else:
            self.cfg_residual_diffs.setdefault(step, diff)

    def get_residual_diffs(self):
        """Return a copy of the non-CFG residual diffs."""
        return self.residual_diffs.copy()

    def get_cfg_residual_diffs(self):
        """Return a copy of the CFG residual diffs."""
        return self.cfg_residual_diffs.copy()

    @staticmethod
    def _next_continuous_count(cached_steps, curr_step, continuous):
        """Return the updated continuous-cached-step count for ``curr_step``.

        The first cached step ever counts as 1. A step adjacent to the
        previous cached step adds 1, or 2 when the count is currently 0
        (counting the previous step too). A non-adjacent step leaves the
        count unchanged.
        """
        if not cached_steps:
            return continuous + 1
        if curr_step - cached_steps[-1] == 1:
            return continuous + (2 if continuous == 0 else 1)
        return continuous

    def add_cached_step(self):
        """Record the current step as cached and update the continuous count."""
        curr_cached_step = self.get_current_step()
        if not self.is_separate_cfg_step():
            self.continuous_cached_steps = self._next_continuous_count(
                self.cached_steps,
                curr_cached_step,
                self.continuous_cached_steps,
            )
            self.cached_steps.append(curr_cached_step)
        else:
            self.cfg_continuous_cached_steps = self._next_continuous_count(
                self.cfg_cached_steps,
                curr_cached_step,
                self.cfg_continuous_cached_steps,
            )
            self.cfg_cached_steps.append(curr_cached_step)

    def get_cached_steps(self):
        """Return a copy of the non-CFG cached steps."""
        return self.cached_steps.copy()

    def get_cfg_cached_steps(self):
        """Return a copy of the CFG cached steps."""
        return self.cfg_cached_steps.copy()

    def get_current_step(self):
        """Return the 0-based pipeline step index."""
        return self.executed_steps - 1

    def get_current_transformer_step(self):
        """Return the 0-based transformer call index."""
        return self.transformer_executed_steps - 1

    def is_separate_cfg_step(self):
        """Return True when the current transformer call is the CFG one."""
        if not self.enable_separate_cfg:
            return False
        if self.cfg_compute_first:
            # CFG steps: 0, 2, 4, 6, ...
            return self.get_current_transformer_step() % 2 == 0
        # CFG steps: 1, 3, 5, 7, ...
        return self.get_current_transformer_step() % 2 != 0

    def is_in_warmup(self):
        """Return True while the current pipeline step is a warmup step."""
        return self.get_current_step() < self.max_warmup_steps
|