cache-dit 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

This version of cache-dit has been flagged as potentially problematic.
Files changed (34)
  1. cache_dit/__init__.py +1 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +3 -6
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +21 -64
  5. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +82 -21
  7. cache_dit/cache_factory/cache_blocks/__init__.py +4 -0
  8. cache_dit/cache_factory/cache_blocks/offload_utils.py +115 -0
  9. cache_dit/cache_factory/cache_blocks/pattern_base.py +3 -0
  10. cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
  11. cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
  12. cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
  13. cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
  14. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
  15. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
  16. cache_dit/cache_factory/cache_interface.py +128 -111
  17. cache_dit/cache_factory/params_modifier.py +87 -0
  18. cache_dit/metrics/__init__.py +3 -1
  19. cache_dit/utils.py +12 -21
  20. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/METADATA +200 -434
  21. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/RECORD +27 -31
  22. cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
  23. cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
  24. cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
  25. cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
  26. cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
  27. cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
  28. cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
  29. /cache_dit/cache_factory/cache_blocks/{utils.py → pattern_utils.py} +0 -0
  30. /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
  31. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/WHEEL +0 -0
  32. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/entry_points.txt +0 -0
  33. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/licenses/LICENSE +0 -0
  34. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,14 @@
  import logging
  import contextlib
- import dataclasses
- from typing import Any, Dict, Optional, Tuple, Union, List
+ from typing import Dict, Optional, Tuple, Union, List

  import torch
  import torch.distributed as dist

- from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer
- from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+ from cache_dit.cache_factory.cache_contexts.calibrators import CalibratorBase
+ from cache_dit.cache_factory.cache_contexts.cache_context import (
+     CachedContext,
+ )
  from cache_dit.logger import init_logger

  logger = init_logger(__name__)
@@ -89,51 +90,6 @@ class CachedContextManager:
          finally:
              self._current_context = old_cached_context

-     @staticmethod
-     def collect_cache_kwargs(
-         default_attrs: dict, **kwargs
-     ) -> Tuple[Dict, Dict]:
-         # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
-         # default_attrs: specific settings for different pipelines
-         cache_attrs = dataclasses.fields(CachedContext)
-         cache_attrs = [
-             attr
-             for attr in cache_attrs
-             if hasattr(
-                 CachedContext,
-                 attr.name,
-             )
-         ]
-         cache_kwargs = {
-             attr.name: kwargs.pop(
-                 attr.name,
-                 getattr(CachedContext, attr.name),
-             )
-             for attr in cache_attrs
-         }
-
-         def _safe_set_sequence_field(
-             field_name: str,
-             default_value: Any = None,
-         ):
-             if field_name not in cache_kwargs:
-                 cache_kwargs[field_name] = kwargs.pop(
-                     field_name,
-                     default_value,
-                 )
-
-         # Manually set sequence fields
-         _safe_set_sequence_field("taylorseer_kwargs", {})
-
-         for attr in cache_attrs:
-             if attr.name in default_attrs: # can be empty {}
-                 cache_kwargs[attr.name] = default_attrs[attr.name]
-
-         if logger.isEnabledFor(logging.DEBUG):
-             logger.debug(f"Collected Cache kwargs: {cache_kwargs}")
-
-         return cache_kwargs, kwargs
-
      @torch.compiler.disable
      def get_residual_diff_threshold(self) -> float:
          cached_context = self.get_context()
@@ -212,13 +168,13 @@ class CachedContextManager:
      def get_max_cached_steps(self) -> int:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.max_cached_steps
+         return cached_context.cache_config.max_cached_steps

      @torch.compiler.disable
      def get_max_continuous_cached_steps(self) -> int:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.max_continuous_cached_steps
+         return cached_context.cache_config.max_continuous_cached_steps

      @torch.compiler.disable
      def get_continuous_cached_steps(self) -> int:
@@ -257,45 +213,45 @@ class CachedContextManager:
          return cached_context.get_cfg_residual_diffs()

      @torch.compiler.disable
-     def is_taylorseer_enabled(self) -> bool:
+     def is_calibrator_enabled(self) -> bool:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.enable_taylorseer
+         return cached_context.enable_calibrator()

      @torch.compiler.disable
-     def is_encoder_taylorseer_enabled(self) -> bool:
+     def is_encoder_calibrator_enabled(self) -> bool:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.enable_encoder_taylorseer
+         return cached_context.enable_encoder_calibrator()

-     def get_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
+     def get_calibrator(self) -> Tuple[CalibratorBase, CalibratorBase]:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.get_taylorseers()
+         return cached_context.get_calibrators()

-     def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
+     def get_cfg_calibrator(self) -> Tuple[CalibratorBase, CalibratorBase]:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.get_cfg_taylorseers()
+         return cached_context.get_cfg_calibrators()

      @torch.compiler.disable
-     def is_taylorseer_cache_residual(self) -> bool:
+     def is_calibrator_cache_residual(self) -> bool:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.taylorseer_cache_type == "residual"
+         return cached_context.calibrator_cache_type() == "residual"

      @torch.compiler.disable
      def is_cache_residual(self) -> bool:
-         if self.is_taylorseer_enabled():
+         if self.is_calibrator_enabled():
              # residual or hidden_states
-             return self.is_taylorseer_cache_residual()
+             return self.is_calibrator_cache_residual()
          return True

      @torch.compiler.disable
      def is_encoder_cache_residual(self) -> bool:
-         if self.is_encoder_taylorseer_enabled():
+         if self.is_encoder_calibrator_enabled():
              # residual or hidden_states
-             return self.is_taylorseer_cache_residual()
+             return self.is_calibrator_cache_residual()
          return True

      @torch.compiler.disable
@@ -309,65 +265,41 @@ class CachedContextManager:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
          return (
-             cached_context.l1_hidden_states_diff_threshold is not None
-             and cached_context.l1_hidden_states_diff_threshold > 0.0
+             cached_context.extra_cache_config.l1_hidden_states_diff_threshold
+             is not None
+             and cached_context.extra_cache_config.l1_hidden_states_diff_threshold
+             > 0.0
          )

      @torch.compiler.disable
      def get_important_condition_threshold(self) -> float:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.important_condition_threshold
-
-     @torch.compiler.disable
-     def non_compute_blocks_diff_threshold(self) -> float:
-         cached_context = self.get_context()
-         assert cached_context is not None, "cached_context must be set before"
-         return cached_context.non_compute_blocks_diff_threshold
+         return cached_context.extra_cache_config.important_condition_threshold

      @torch.compiler.disable
      def Fn_compute_blocks(self) -> int:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
          assert (
-             cached_context.Fn_compute_blocks >= 1
+             cached_context.cache_config.Fn_compute_blocks >= 1
          ), "Fn_compute_blocks must be >= 1"
-         if cached_context.max_Fn_compute_blocks > 0:
-             # NOTE: Fn_compute_blocks can be 1, which means FB Cache
-             # but it must be less than or equal to max_Fn_compute_blocks
-             assert (
-                 cached_context.Fn_compute_blocks
-                 <= cached_context.max_Fn_compute_blocks
-             ), (
-                 f"Fn_compute_blocks must be <= {cached_context.max_Fn_compute_blocks}, "
-                 f"but got {cached_context.Fn_compute_blocks}"
-             )
-         return cached_context.Fn_compute_blocks
+         return cached_context.cache_config.Fn_compute_blocks

      @torch.compiler.disable
      def Bn_compute_blocks(self) -> int:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
          assert (
-             cached_context.Bn_compute_blocks >= 0
+             cached_context.cache_config.Bn_compute_blocks >= 0
          ), "Bn_compute_blocks must be >= 0"
-         if cached_context.max_Bn_compute_blocks > 0:
-             # NOTE: Bn_compute_blocks can be 0, which means FB Cache
-             # but it must be less than or equal to max_Bn_compute_blocks
-             assert (
-                 cached_context.Bn_compute_blocks
-                 <= cached_context.max_Bn_compute_blocks
-             ), (
-                 f"Bn_compute_blocks must be <= {cached_context.max_Bn_compute_blocks}, "
-                 f"but got {cached_context.Bn_compute_blocks}"
-             )
-         return cached_context.Bn_compute_blocks
+         return cached_context.cache_config.Bn_compute_blocks

      @torch.compiler.disable
      def enable_separate_cfg(self) -> bool:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.enable_separate_cfg
+         return cached_context.cache_config.enable_separate_cfg

      @torch.compiler.disable
      def is_separate_cfg_step(self) -> bool:
@@ -379,7 +311,7 @@ class CachedContextManager:
      def cfg_diff_compute_separate(self) -> bool:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.cfg_diff_compute_separate
+         return cached_context.cache_config.cfg_diff_compute_separate

      @torch.compiler.disable
      def similarity(
@@ -534,20 +466,20 @@ class CachedContextManager:
              return
          # Set hidden_states or residual for Bn blocks.
          # This buffer is use for hidden states approximation.
-         if self.is_taylorseer_enabled():
-             # taylorseer, encoder_taylorseer
+         if self.is_calibrator_enabled():
+             # calibrator, encoder_calibrator
              if self.is_separate_cfg_step():
-                 taylorseer, _ = self.get_cfg_taylorseers()
+                 calibrator, _ = self.get_cfg_calibrator()
              else:
-                 taylorseer, _ = self.get_taylorseers()
+                 calibrator, _ = self.get_calibrator()

-             if taylorseer is not None:
-                 # Use TaylorSeer to update the buffer
-                 taylorseer.update(buffer)
+             if calibrator is not None:
+                 # Use calibrator to update the buffer
+                 calibrator.update(buffer)
              else:
                  if logger.isEnabledFor(logging.DEBUG):
                      logger.debug(
-                         "TaylorSeer is enabled but not set in the cache context. "
+                         "calibrator is enabled but not set in the cache context. "
                          "Falling back to default buffer retrieval."
                      )
                  if self.is_separate_cfg_step():
@@ -566,19 +498,19 @@ class CachedContextManager:

      @torch.compiler.disable
      def get_Bn_buffer(self, prefix: str = "Bn") -> torch.Tensor:
-         if self.is_taylorseer_enabled():
-             # taylorseer, encoder_taylorseer
+         if self.is_calibrator_enabled():
+             # calibrator, encoder_calibrator
              if self.is_separate_cfg_step():
-                 taylorseer, _ = self.get_cfg_taylorseers()
+                 calibrator, _ = self.get_cfg_calibrator()
              else:
-                 taylorseer, _ = self.get_taylorseers()
+                 calibrator, _ = self.get_calibrator()

-             if taylorseer is not None:
-                 return taylorseer.approximate_value()
+             if calibrator is not None:
+                 return calibrator.approximate()
              else:
                  if logger.isEnabledFor(logging.DEBUG):
                      logger.debug(
-                         "TaylorSeer is enabled but not set in the cache context. "
+                         "calibrator is enabled but not set in the cache context. "
                          "Falling back to default buffer retrieval."
                      )
                  # Fallback to default buffer retrieval
@@ -603,20 +535,20 @@ class CachedContextManager:
              return

          # This buffer is use for encoder hidden states approximation.
-         if self.is_encoder_taylorseer_enabled():
-             # taylorseer, encoder_taylorseer
+         if self.is_encoder_calibrator_enabled():
+             # calibrator, encoder_calibrator
              if self.is_separate_cfg_step():
-                 _, encoder_taylorseer = self.get_cfg_taylorseers()
+                 _, encoder_calibrator = self.get_cfg_calibrator()
              else:
-                 _, encoder_taylorseer = self.get_taylorseers()
+                 _, encoder_calibrator = self.get_calibrator()

-             if encoder_taylorseer is not None:
-                 # Use TaylorSeer to update the buffer
-                 encoder_taylorseer.update(buffer)
+             if encoder_calibrator is not None:
+                 # Use CalibratorBase to update the buffer
+                 encoder_calibrator.update(buffer)
              else:
                  if logger.isEnabledFor(logging.DEBUG):
                      logger.debug(
-                         "TaylorSeer is enabled but not set in the cache context. "
+                         "CalibratorBase is enabled but not set in the cache context. "
                          "Falling back to default buffer retrieval."
                      )
                  if self.is_separate_cfg_step():
@@ -635,19 +567,19 @@ class CachedContextManager:

      @torch.compiler.disable
      def get_Bn_encoder_buffer(self, prefix: str = "Bn") -> torch.Tensor:
-         if self.is_encoder_taylorseer_enabled():
+         if self.is_encoder_calibrator_enabled():
              if self.is_separate_cfg_step():
-                 _, encoder_taylorseer = self.get_cfg_taylorseers()
+                 _, encoder_calibrator = self.get_cfg_calibrator()
              else:
-                 _, encoder_taylorseer = self.get_taylorseers()
+                 _, encoder_calibrator = self.get_calibrator()

-             if encoder_taylorseer is not None:
-                 # Use TaylorSeer to approximate the value
-                 return encoder_taylorseer.approximate_value()
+             if encoder_calibrator is not None:
+                 # Use calibrator to approximate the value
+                 return encoder_calibrator.approximate()
              else:
                  if logger.isEnabledFor(logging.DEBUG):
                      logger.debug(
-                         "TaylorSeer is enabled but not set in the cache context. "
+                         "calibrator is enabled but not set in the cache context. "
                          "Falling back to default buffer retrieval."
                      )
                  # Fallback to default buffer retrieval
@@ -717,7 +649,7 @@ class CachedContextManager:
      def get_downsample_factor(self) -> float:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.downsample_factor
+         return cached_context.extra_cache_config.downsample_factor

      @torch.compiler.disable
      def can_cache(
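Taken together, the cache_manager.py hunks above make two changes in 0.3.3: the TaylorSeer-specific manager API is generalized to a calibrator API, and per-context options move from flat CachedContext attributes into nested cache_config / extra_cache_config objects. A minimal sketch of the access-pattern change, assuming a hypothetical manager and context instance; only the attribute paths visible in this diff are assumed to exist:

# Hypothetical instances for illustration; `manager` is a CachedContextManager.
ctx = manager.get_context()

# cache-dit 0.3.1: options were flat attributes on CachedContext.
fn_blocks = ctx.Fn_compute_blocks
threshold = ctx.important_condition_threshold

# cache-dit 0.3.3: the same options are grouped under nested config objects.
fn_blocks = ctx.cache_config.Fn_compute_blocks
threshold = ctx.extra_cache_config.important_condition_threshold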
@@ -0,0 +1,132 @@
+ from cache_dit.cache_factory.cache_contexts.calibrators.base import (
+     CalibratorBase,
+ )
+ from cache_dit.cache_factory.cache_contexts.calibrators.taylorseer import (
+     TaylorSeerCalibrator,
+ )
+ from cache_dit.cache_factory.cache_contexts.calibrators.foca import (
+     FoCaCalibrator,
+ )
+
+ import dataclasses
+ from typing import Any, Dict
+
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @dataclasses.dataclass
+ class CalibratorConfig:
+     # enable_calibrator (`bool`, *required*, defaults to False):
+     #     Whether to enable the calibrator. If True, the user wants to use DBCache
+     #     with a specific calibrator for hidden_states (or the hidden_states residual),
+     #     such as taylorseer, foca, and so on.
+     enable_calibrator: bool = False
+     # enable_encoder_calibrator (`bool`, *required*, defaults to False):
+     #     Whether to enable the calibrator. If True, the user wants to use DBCache
+     #     with a specific calibrator for encoder_hidden_states (or the
+     #     encoder_hidden_states residual), such as taylorseer, foca, and so on.
+     enable_encoder_calibrator: bool = False
+     # calibrator_type (`str`, *required*, defaults to 'taylorseer'):
+     #     The specific type of calibrator, taylorseer or foca, etc.
+     calibrator_type: str = "taylorseer"
+     # calibrator_cache_type (`str`, *required*, defaults to 'residual'):
+     #     The specific cache type for the calibrator, residual or hidden_states.
+     calibrator_cache_type: str = "residual"
+     # calibrator_kwargs (`dict`, *optional*, defaults to {}):
+     #     Init kwargs for the specific calibrator, taylorseer or foca, etc.
+     calibrator_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+     def strify(self, **kwargs) -> str:
+         return "CalibratorBase"
+
+     def to_kwargs(self) -> Dict:
+         return self.calibrator_kwargs.copy()
+
+
+ @dataclasses.dataclass
+ class TaylorSeerCalibratorConfig(CalibratorConfig):
+     # TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
+     # link: https://arxiv.org/pdf/2503.06923
+
+     # enable_calibrator (`bool`, *required*, defaults to True):
+     #     Whether to enable the calibrator. If True, the user wants to use DBCache
+     #     with a specific calibrator for hidden_states (or the hidden_states residual),
+     #     such as taylorseer, foca, and so on.
+     enable_calibrator: bool = True
+     # enable_encoder_calibrator (`bool`, *required*, defaults to True):
+     #     Whether to enable the calibrator. If True, the user wants to use DBCache
+     #     with a specific calibrator for encoder_hidden_states (or the
+     #     encoder_hidden_states residual), such as taylorseer, foca, and so on.
+     enable_encoder_calibrator: bool = True
+     # calibrator_type (`str`, *required*, defaults to 'taylorseer'):
+     #     The specific type of calibrator, taylorseer or foca, etc.
+     calibrator_type: str = "taylorseer"
+     # taylorseer_order (`int`, *required*, defaults to 1):
+     #     The order of taylorseer; higher values of n_derivatives will lead to longer
+     #     computation time, the recommended value is 1 or 2. Please check
+     #     [TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models
+     #     with TaylorSeers](https://arxiv.org/pdf/2503.06923) for more details.
+     taylorseer_order: int = 1
+
+     def strify(self, **kwargs) -> str:
+         if kwargs.get("details", False):
+             if self.taylorseer_order:
+                 return f"TaylorSeer_O({self.taylorseer_order})"
+             return "NONE"
+
+         if self.taylorseer_order:
+             return f"T1O{self.taylorseer_order}"
+         return "NONE"
+
+     def to_kwargs(self) -> Dict:
+         kwargs = self.calibrator_kwargs.copy()
+         kwargs["n_derivatives"] = self.taylorseer_order
+         return kwargs
+
+
+ @dataclasses.dataclass
+ class FoCaCalibratorConfig(CalibratorConfig):
+     # FoCa: Forecast then Calibrate: Feature Caching as ODE for Efficient Diffusion Transformers
+     # link: https://arxiv.org/pdf/2508.16211
+
+     # enable_calibrator (`bool`, *required*, defaults to True):
+     #     Whether to enable the calibrator. If True, the user wants to use DBCache
+     #     with a specific calibrator for hidden_states (or the hidden_states residual),
+     #     such as taylorseer, foca, and so on.
+     enable_calibrator: bool = True
+     # enable_encoder_calibrator (`bool`, *required*, defaults to True):
+     #     Whether to enable the calibrator. If True, the user wants to use DBCache
+     #     with a specific calibrator for encoder_hidden_states (or the
+     #     encoder_hidden_states residual), such as taylorseer, foca, and so on.
+     enable_encoder_calibrator: bool = True
+     # calibrator_type (`str`, *required*, defaults to 'taylorseer'):
+     #     The specific type of calibrator, taylorseer or foca, etc.
+     calibrator_type: str = "foca"
+
+     def strify(self, **kwargs) -> str:
+         return "FoCa"
+
+
+ class Calibrator:
+     _supported_calibrators = [
+         "taylorseer",
+         # TODO: FoCa
+     ]
+
+     def __new__(
+         cls,
+         calibrator_config: CalibratorConfig,
+     ) -> CalibratorBase:
+         assert (
+             calibrator_config.calibrator_type in cls._supported_calibrators
+         ), f"Calibrator {calibrator_config.calibrator_type} is not supported now!"
+
+         if calibrator_config.calibrator_type.lower() == "taylorseer":
+             return TaylorSeerCalibrator(**calibrator_config.to_kwargs())
+         else:
+             raise ValueError(
+                 f"Calibrator {calibrator_config.calibrator_type} is not supported now!"
+             )
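The new calibrators package wires config objects to implementations through the Calibrator factory above. A minimal usage sketch, assuming the TaylorSeerCalibrator constructor provides defaults for every argument other than n_derivatives, since the factory only forwards config.to_kwargs():

from cache_dit.cache_factory.cache_contexts.calibrators import (
    Calibrator,
    TaylorSeerCalibratorConfig,
)

# Second-order TaylorSeer calibrator for hidden_states residuals.
config = TaylorSeerCalibratorConfig(taylorseer_order=2)
print(config.strify(details=True))  # -> "TaylorSeer_O(2)"
print(config.to_kwargs())           # -> {"n_derivatives": 2}

# Calibrator.__new__ dispatches on calibrator_type and returns a
# TaylorSeerCalibrator built from config.to_kwargs().
calibrator = Calibrator(config)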
@@ -1,4 +1,4 @@
- from cache_dit.cache_factory.cache_contexts.v2.calibrators.base import (
+ from cache_dit.cache_factory.cache_contexts.calibrators.base import (
      CalibratorBase,
  )

@@ -1,7 +1,7 @@
  import math
  import torch
  from typing import List, Dict
- from cache_dit.cache_factory.cache_contexts.v2.calibrators.base import (
+ from cache_dit.cache_factory.cache_contexts.calibrators.base import (
      CalibratorBase,
  )

@@ -22,8 +22,13 @@ class TaylorSeerCalibrator(CalibratorBase):
          self.order = n_derivatives + 1
          self.max_warmup_steps = max_warmup_steps
          self.skip_interval_steps = skip_interval_steps
+         self.current_step = -1
+         self.last_non_approximated_step = -1
+         self.state: Dict[str, List[torch.Tensor]] = {
+             "dY_prev": [None] * self.order,
+             "dY_current": [None] * self.order,
+         }
          self.reset_cache()
-         logger.info(f"Created {self.__repr__()}_{id(self)}")

      def reset_cache(self): # NEED
          self.state: Dict[str, List[torch.Tensor]] = {
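For reference, the CachedContextManager hunks earlier in this diff interact with a calibrator through just two calls: update(buffer) when a fresh Bn buffer is computed, and approximate() when a cached step retrieves it. A hypothetical, duck-typed toy calibrator that satisfies only that usage might look like the sketch below; it is not part of cache-dit, and a real implementation would subclass CalibratorBase, whose full interface is not shown in this diff:

from typing import Optional

import torch


class LastValueCalibrator:
    # Hypothetical example: replays the most recently stored buffer
    # instead of forecasting it.

    def __init__(self):
        self._last: Optional[torch.Tensor] = None

    def reset_cache(self):
        # Mirrors TaylorSeerCalibrator.reset_cache() above: drop stored state.
        self._last = None

    def update(self, buffer: torch.Tensor) -> None:
        # Called by the manager when a freshly computed buffer is stored.
        self._last = buffer

    def approximate(self) -> torch.Tensor:
        # Called by the manager on cached steps in place of recomputation.
        assert self._last is not None, "update() must be called before approximate()"
        return self._last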