cache-dit 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -6
- cache_dit/cache_factory/block_adapters/block_adapters.py +21 -64
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +82 -21
- cache_dit/cache_factory/cache_blocks/__init__.py +4 -0
- cache_dit/cache_factory/cache_blocks/offload_utils.py +115 -0
- cache_dit/cache_factory/cache_blocks/pattern_base.py +3 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
- cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
- cache_dit/cache_factory/cache_interface.py +128 -111
- cache_dit/cache_factory/params_modifier.py +87 -0
- cache_dit/metrics/__init__.py +3 -1
- cache_dit/utils.py +12 -21
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/METADATA +200 -434
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/RECORD +27 -31
- cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
- cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
- cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
- cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
- cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
- cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
- cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
- /cache_dit/cache_factory/cache_blocks/{utils.py → pattern_utils.py} +0 -0
- /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/top_level.txt +0 -0
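
The rename entries above move the 0.3.1 `v2` calibrator modules up to `cache_dit/cache_factory/cache_contexts/calibrators/`. As a minimal sketch of what this means for downstream imports (the exported names come from the new `calibrators/__init__.py` shown later in this diff; the commented 0.3.1 path is inferred from the {v2/calibrators → calibrators} renames and should be treated as an assumption):

# 0.3.1 (assumed old location, per the rename entries above):
# from cache_dit.cache_factory.cache_contexts.v2.calibrators import TaylorSeerCalibrator

# 0.3.3:
from cache_dit.cache_factory.cache_contexts.calibrators import (
    CalibratorBase,
    TaylorSeerCalibrator,
    TaylorSeerCalibratorConfig,
)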
cache_dit/cache_factory/cache_contexts/cache_manager.py (several removed 0.3.1 lines are truncated in the source diff view and reproduced as-is)

@@ -1,13 +1,14 @@
 import logging
 import contextlib
-import
-from typing import Any, Dict, Optional, Tuple, Union, List
+from typing import Dict, Optional, Tuple, Union, List

 import torch
 import torch.distributed as dist

-from cache_dit.cache_factory.cache_contexts.
-from cache_dit.cache_factory.cache_contexts.cache_context import
+from cache_dit.cache_factory.cache_contexts.calibrators import CalibratorBase
+from cache_dit.cache_factory.cache_contexts.cache_context import (
+    CachedContext,
+)
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)
@@ -89,51 +90,6 @@ class CachedContextManager:
         finally:
             self._current_context = old_cached_context

-    @staticmethod
-    def collect_cache_kwargs(
-        default_attrs: dict, **kwargs
-    ) -> Tuple[Dict, Dict]:
-        # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
-        # default_attrs: specific settings for different pipelines
-        cache_attrs = dataclasses.fields(CachedContext)
-        cache_attrs = [
-            attr
-            for attr in cache_attrs
-            if hasattr(
-                CachedContext,
-                attr.name,
-            )
-        ]
-        cache_kwargs = {
-            attr.name: kwargs.pop(
-                attr.name,
-                getattr(CachedContext, attr.name),
-            )
-            for attr in cache_attrs
-        }
-
-        def _safe_set_sequence_field(
-            field_name: str,
-            default_value: Any = None,
-        ):
-            if field_name not in cache_kwargs:
-                cache_kwargs[field_name] = kwargs.pop(
-                    field_name,
-                    default_value,
-                )
-
-        # Manually set sequence fields
-        _safe_set_sequence_field("taylorseer_kwargs", {})
-
-        for attr in cache_attrs:
-            if attr.name in default_attrs: # can be empty {}
-                cache_kwargs[attr.name] = default_attrs[attr.name]
-
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(f"Collected Cache kwargs: {cache_kwargs}")
-
-        return cache_kwargs, kwargs
-
     @torch.compiler.disable
     def get_residual_diff_threshold(self) -> float:
         cached_context = self.get_context()
@@ -212,13 +168,13 @@ class CachedContextManager:
     def get_max_cached_steps(self) -> int:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.max_cached_steps
+        return cached_context.cache_config.max_cached_steps

     @torch.compiler.disable
     def get_max_continuous_cached_steps(self) -> int:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.max_continuous_cached_steps
+        return cached_context.cache_config.max_continuous_cached_steps

     @torch.compiler.disable
     def get_continuous_cached_steps(self) -> int:
@@ -257,45 +213,45 @@ class CachedContextManager:
         return cached_context.get_cfg_residual_diffs()

     @torch.compiler.disable
-    def
+    def is_calibrator_enabled(self) -> bool:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.
+        return cached_context.enable_calibrator()

     @torch.compiler.disable
-    def
+    def is_encoder_calibrator_enabled(self) -> bool:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.
+        return cached_context.enable_encoder_calibrator()

-    def
+    def get_calibrator(self) -> Tuple[CalibratorBase, CalibratorBase]:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.
+        return cached_context.get_calibrators()

-    def
+    def get_cfg_calibrator(self) -> Tuple[CalibratorBase, CalibratorBase]:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.
+        return cached_context.get_cfg_calibrators()

     @torch.compiler.disable
-    def
+    def is_calibrator_cache_residual(self) -> bool:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.
+        return cached_context.calibrator_cache_type() == "residual"

     @torch.compiler.disable
     def is_cache_residual(self) -> bool:
-        if self.
+        if self.is_calibrator_enabled():
             # residual or hidden_states
-            return self.
+            return self.is_calibrator_cache_residual()
         return True

     @torch.compiler.disable
     def is_encoder_cache_residual(self) -> bool:
-        if self.
+        if self.is_encoder_calibrator_enabled():
             # residual or hidden_states
-            return self.
+            return self.is_calibrator_cache_residual()
         return True

     @torch.compiler.disable
@@ -309,65 +265,41 @@ class CachedContextManager:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
         return (
-            cached_context.l1_hidden_states_diff_threshold
-
+            cached_context.extra_cache_config.l1_hidden_states_diff_threshold
+            is not None
+            and cached_context.extra_cache_config.l1_hidden_states_diff_threshold
+            > 0.0
         )

     @torch.compiler.disable
     def get_important_condition_threshold(self) -> float:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.important_condition_threshold
-
-    @torch.compiler.disable
-    def non_compute_blocks_diff_threshold(self) -> float:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.non_compute_blocks_diff_threshold
+        return cached_context.extra_cache_config.important_condition_threshold

     @torch.compiler.disable
     def Fn_compute_blocks(self) -> int:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
         assert (
-            cached_context.Fn_compute_blocks >= 1
+            cached_context.cache_config.Fn_compute_blocks >= 1
         ), "Fn_compute_blocks must be >= 1"
-
-        # NOTE: Fn_compute_blocks can be 1, which means FB Cache
-        # but it must be less than or equal to max_Fn_compute_blocks
-        assert (
-            cached_context.Fn_compute_blocks
-            <= cached_context.max_Fn_compute_blocks
-        ), (
-            f"Fn_compute_blocks must be <= {cached_context.max_Fn_compute_blocks}, "
-            f"but got {cached_context.Fn_compute_blocks}"
-        )
-        return cached_context.Fn_compute_blocks
+        return cached_context.cache_config.Fn_compute_blocks

     @torch.compiler.disable
     def Bn_compute_blocks(self) -> int:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
         assert (
-            cached_context.Bn_compute_blocks >= 0
+            cached_context.cache_config.Bn_compute_blocks >= 0
         ), "Bn_compute_blocks must be >= 0"
-
-        # NOTE: Bn_compute_blocks can be 0, which means FB Cache
-        # but it must be less than or equal to max_Bn_compute_blocks
-        assert (
-            cached_context.Bn_compute_blocks
-            <= cached_context.max_Bn_compute_blocks
-        ), (
-            f"Bn_compute_blocks must be <= {cached_context.max_Bn_compute_blocks}, "
-            f"but got {cached_context.Bn_compute_blocks}"
-        )
-        return cached_context.Bn_compute_blocks
+        return cached_context.cache_config.Bn_compute_blocks

     @torch.compiler.disable
     def enable_separate_cfg(self) -> bool:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.enable_separate_cfg
+        return cached_context.cache_config.enable_separate_cfg

     @torch.compiler.disable
     def is_separate_cfg_step(self) -> bool:
@@ -379,7 +311,7 @@ class CachedContextManager:
     def cfg_diff_compute_separate(self) -> bool:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.cfg_diff_compute_separate
+        return cached_context.cache_config.cfg_diff_compute_separate

     @torch.compiler.disable
     def similarity(
@@ -534,20 +466,20 @@ class CachedContextManager:
             return
         # Set hidden_states or residual for Bn blocks.
         # This buffer is use for hidden states approximation.
-        if self.
-            #
+        if self.is_calibrator_enabled():
+            # calibrator, encoder_calibrator
             if self.is_separate_cfg_step():
-
+                calibrator, _ = self.get_cfg_calibrator()
             else:
-
+                calibrator, _ = self.get_calibrator()

-            if
-            # Use
-
+            if calibrator is not None:
+                # Use calibrator to update the buffer
+                calibrator.update(buffer)
             else:
                 if logger.isEnabledFor(logging.DEBUG):
                     logger.debug(
-                        "
+                        "calibrator is enabled but not set in the cache context. "
                         "Falling back to default buffer retrieval."
                     )
                 if self.is_separate_cfg_step():
@@ -566,19 +498,19 @@ class CachedContextManager:

     @torch.compiler.disable
     def get_Bn_buffer(self, prefix: str = "Bn") -> torch.Tensor:
-        if self.
-            #
+        if self.is_calibrator_enabled():
+            # calibrator, encoder_calibrator
             if self.is_separate_cfg_step():
-
+                calibrator, _ = self.get_cfg_calibrator()
             else:
-
+                calibrator, _ = self.get_calibrator()

-            if
-            return
+            if calibrator is not None:
+                return calibrator.approximate()
             else:
                 if logger.isEnabledFor(logging.DEBUG):
                     logger.debug(
-                        "
+                        "calibrator is enabled but not set in the cache context. "
                         "Falling back to default buffer retrieval."
                     )
                 # Fallback to default buffer retrieval
@@ -603,20 +535,20 @@ class CachedContextManager:
             return

         # This buffer is use for encoder hidden states approximation.
-        if self.
-            #
+        if self.is_encoder_calibrator_enabled():
+            # calibrator, encoder_calibrator
             if self.is_separate_cfg_step():
-                _,
+                _, encoder_calibrator = self.get_cfg_calibrator()
             else:
-                _,
+                _, encoder_calibrator = self.get_calibrator()

-            if
-            # Use
-
+            if encoder_calibrator is not None:
+                # Use CalibratorBase to update the buffer
+                encoder_calibrator.update(buffer)
             else:
                 if logger.isEnabledFor(logging.DEBUG):
                     logger.debug(
-                        "
+                        "CalibratorBase is enabled but not set in the cache context. "
                         "Falling back to default buffer retrieval."
                     )
                 if self.is_separate_cfg_step():
@@ -635,19 +567,19 @@ class CachedContextManager:

     @torch.compiler.disable
     def get_Bn_encoder_buffer(self, prefix: str = "Bn") -> torch.Tensor:
-        if self.
+        if self.is_encoder_calibrator_enabled():
             if self.is_separate_cfg_step():
-                _,
+                _, encoder_calibrator = self.get_cfg_calibrator()
             else:
-                _,
+                _, encoder_calibrator = self.get_calibrator()

-            if
-            # Use
-            return
+            if encoder_calibrator is not None:
+                # Use calibrator to approximate the value
+                return encoder_calibrator.approximate()
             else:
                 if logger.isEnabledFor(logging.DEBUG):
                     logger.debug(
-                        "
+                        "calibrator is enabled but not set in the cache context. "
                         "Falling back to default buffer retrieval."
                     )
                 # Fallback to default buffer retrieval
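
The four buffer hunks above share one pattern: on fully computed steps the manager refreshes the calibrator from the new Bn buffer via `update()`, while on cached steps it returns `calibrator.approximate()` instead of a stale stored buffer, falling back to plain buffer retrieval when no calibrator instance is set. A condensed sketch of that flow, with hypothetical method names (`_store_bn`, `_load_bn`) standing in for the manager's set/get buffer methods:

def _store_bn(self, buffer):
    # computed step: feed the fresh hidden states (or residual) to the calibrator
    calibrator, _ = self.get_calibrator()
    if calibrator is not None:
        calibrator.update(buffer)

def _load_bn(self):
    # cached step: return the calibrator's forecast instead of the raw cached buffer
    calibrator, _ = self.get_calibrator()
    if calibrator is not None:
        return calibrator.approximate()
    return None  # caller falls back to default buffer retrieval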
@@ -717,7 +649,7 @@ class CachedContextManager:
     def get_downsample_factor(self) -> float:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.downsample_factor
+        return cached_context.extra_cache_config.downsample_factor

     @torch.compiler.disable
     def can_cache(
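
Across the cache_manager.py hunks, flat attributes on `CachedContext` are replaced by two nested config objects, `cache_config` and `extra_cache_config`. Reading a setting through the manager now looks like the sketch below (only attribute names that appear in the hunks above are used; `manager` is assumed to be a `CachedContextManager` with a context already set):

ctx = manager.get_context()
# 0.3.1: ctx.Fn_compute_blocks, ctx.max_cached_steps, ctx.downsample_factor
fn_blocks = ctx.cache_config.Fn_compute_blocks
max_steps = ctx.cache_config.max_cached_steps
downsample = ctx.extra_cache_config.downsample_factor
condition_threshold = ctx.extra_cache_config.important_condition_threshold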
cache_dit/cache_factory/cache_contexts/calibrators/__init__.py (new file)

@@ -0,0 +1,132 @@
+from cache_dit.cache_factory.cache_contexts.calibrators.base import (
+    CalibratorBase,
+)
+from cache_dit.cache_factory.cache_contexts.calibrators.taylorseer import (
+    TaylorSeerCalibrator,
+)
+from cache_dit.cache_factory.cache_contexts.calibrators.foca import (
+    FoCaCalibrator,
+)
+
+import dataclasses
+from typing import Any, Dict
+
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class CalibratorConfig:
+    # enable_calibrator (`bool`, *required*, defaults to False):
+    # Whether to enable calibrator, if True. means that user want to use DBCache
+    # with specific calibrator for hidden_states (or hidden_states redisual),
+    # such as taylorseer, foca, and so on.
+    enable_calibrator: bool = False
+    # enable_encoder_calibrator (`bool`, *required*, defaults to False):
+    # Whether to enable calibrator, if True. means that user want to use DBCache
+    # with specific calibrator for encoder_hidden_states (or encoder_hidden_states
+    # redisual), such as taylorseer, foca, and so on.
+    enable_encoder_calibrator: bool = False
+    # calibrator_type (`str`, *required*, defaults to 'taylorseer'):
+    # The specific type for calibrator, taylorseer or foca, etc.
+    calibrator_type: str = "taylorseer"
+    # calibrator_cache_type (`str`, *required*, defaults to 'residual'):
+    # The specific cache type for calibrator, residual or hidden_states.
+    calibrator_cache_type: str = "residual"
+    # calibrator_kwargs (`dict`, *optional*, defaults to {}):
+    # Init kwargs for specific calibrator, taylorseer or foca, etc.
+    calibrator_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+    def strify(self, **kwargs) -> str:
+        return "CalibratorBase"
+
+    def to_kwargs(self) -> Dict:
+        return self.calibrator_kwargs.copy()
+
+
+@dataclasses.dataclass
+class TaylorSeerCalibratorConfig(CalibratorConfig):
+    # TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
+    # link: https://arxiv.org/pdf/2503.06923
+
+    # enable_calibrator (`bool`, *required*, defaults to True):
+    # Whether to enable calibrator, if True. means that user want to use DBCache
+    # with specific calibrator for hidden_states (or hidden_states redisual),
+    # such as taylorseer, foca, and so on.
+    enable_calibrator: bool = True
+    # enable_encoder_calibrator (`bool`, *required*, defaults to True):
+    # Whether to enable calibrator, if True. means that user want to use DBCache
+    # with specific calibrator for encoder_hidden_states (or encoder_hidden_states
+    # redisual), such as taylorseer, foca, and so on.
+    enable_encoder_calibrator: bool = True
+    # calibrator_type (`str`, *required*, defaults to 'taylorseer'):
+    # The specific type for calibrator, taylorseer or foca, etc.
+    calibrator_type: str = "taylorseer"
+    # taylorseer_order (`int`, *required*, defaults to 1):
+    # The order of taylorseer, higher values of n_derivatives will lead to longer computation time,
+    # the recommended value is 1 or 2. Please check [TaylorSeers: From Reusing to Forecasting:
+    # Accelerating Diffusion Models with TaylorSeers](https://arxiv.org/pdf/2503.06923) for
+    # more details.
+    taylorseer_order: int = 1
+
+    def strify(self, **kwargs) -> str:
+        if kwargs.get("details", False):
+            if self.taylorseer_order:
+                return f"TaylorSeer_O({self.taylorseer_order})"
+            return "NONE"
+
+        if self.taylorseer_order:
+            return f"T1O{self.taylorseer_order}"
+        return "NONE"
+
+    def to_kwargs(self) -> Dict:
+        kwargs = self.calibrator_kwargs.copy()
+        kwargs["n_derivatives"] = self.taylorseer_order
+        return kwargs
+
+
+@dataclasses.dataclass
+class FoCaCalibratorConfig(CalibratorConfig):
+    # FoCa: Forecast then Calibrate: Feature Caching as ODE for Efficient Diffusion Transformers
+    # link: https://arxiv.org/pdf/2508.16211
+
+    # enable_calibrator (`bool`, *required*, defaults to True):
+    # Whether to enable calibrator, if True. means that user want to use DBCache
+    # with specific calibrator for hidden_states (or hidden_states redisual),
+    # such as taylorseer, foca, and so on.
+    enable_calibrator: bool = True
+    # enable_encoder_calibrator (`bool`, *required*, defaults to True):
+    # Whether to enable calibrator, if True. means that user want to use DBCache
+    # with specific calibrator for encoder_hidden_states (or encoder_hidden_states
+    # redisual), such as taylorseer, foca, and so on.
+    enable_encoder_calibrator: bool = True
+    # calibrator_type (`str`, *required*, defaults to 'taylorseer'):
+    # The specific type for calibrator, taylorseer or foca, etc.
+    calibrator_type: str = "foca"
+
+    def strify(self, **kwargs) -> str:
+        return "FoCa"
+
+
+class Calibrator:
+    _supported_calibrators = [
+        "taylorseer",
+        # TODO: FoCa
+    ]
+
+    def __new__(
+        cls,
+        calibrator_config: CalibratorConfig,
+    ) -> CalibratorBase:
+        assert (
+            calibrator_config.calibrator_type in cls._supported_calibrators
+        ), f"Calibrator {calibrator_config.calibrator_type} is not supported now!"
+
+        if calibrator_config.calibrator_type.lower() == "taylorseer":
+            return TaylorSeerCalibrator(**calibrator_config.to_kwargs())
+        else:
+            raise ValueError(
+                f"Calibrator {calibrator_config.calibrator_type} is not supported now!"
+            )
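
The new module gives the cache context a small factory API: a dataclass config selects and parameterizes the calibrator, and `Calibrator(...)` dispatches on `calibrator_type` to build the concrete implementation. A hedged usage sketch based only on the definitions above (it assumes `TaylorSeerCalibrator`'s remaining constructor arguments, e.g. `max_warmup_steps`, have defaults, since the factory only passes `to_kwargs()`):

from cache_dit.cache_factory.cache_contexts.calibrators import (
    Calibrator,
    TaylorSeerCalibratorConfig,
)

config = TaylorSeerCalibratorConfig(taylorseer_order=2)
print(config.strify())     # "T1O2"
print(config.to_kwargs())  # {'n_derivatives': 2}

# Calibrator.__new__ returns a TaylorSeerCalibrator built from config.to_kwargs();
# "foca" is declared but still rejected (see the TODO in _supported_calibrators).
calibrator = Calibrator(config)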
cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py

@@ -1,7 +1,7 @@
 import math
 import torch
 from typing import List, Dict
-from cache_dit.cache_factory.cache_contexts.
+from cache_dit.cache_factory.cache_contexts.calibrators.base import (
     CalibratorBase,
 )

@@ -22,8 +22,13 @@ class TaylorSeerCalibrator(CalibratorBase):
         self.order = n_derivatives + 1
         self.max_warmup_steps = max_warmup_steps
         self.skip_interval_steps = skip_interval_steps
+        self.current_step = -1
+        self.last_non_approximated_step = -1
+        self.state: Dict[str, List[torch.Tensor]] = {
+            "dY_prev": [None] * self.order,
+            "dY_current": [None] * self.order,
+        }
         self.reset_cache()
-        logger.info(f"Created {self.__repr__()}_{id(self)}")

     def reset_cache(self): # NEED
         self.state: Dict[str, List[torch.Tensor]] = {
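
The `dY_prev` / `dY_current` lists initialized above hold finite-difference derivative estimates of the cached feature, one slot per Taylor order. For orientation only, a generic sketch of the TaylorSeers forecasting idea from the paper (not cache-dit's implementation; `taylor_forecast`, `derivatives`, and `elapsed` are illustrative names):

import math
from typing import List, Optional

import torch


def taylor_forecast(
    derivatives: List[Optional[torch.Tensor]], elapsed: int
) -> torch.Tensor:
    # derivatives[k] approximates the k-th derivative of the feature Y at the last
    # fully computed step; skipped steps are predicted with a Taylor expansion:
    #   Y(t + elapsed) ~= sum_k derivatives[k] * elapsed**k / k!
    out = torch.zeros_like(derivatives[0])
    for k, d_k in enumerate(derivatives):
        if d_k is None:
            break
        out = out + d_k * (elapsed**k) / math.factorial(k)
    return out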
|