cache-dit 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. cache_dit/__init__.py +1 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +3 -6
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +8 -64
  5. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +47 -14
  7. cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
  8. cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
  9. cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
  10. cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
  11. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
  12. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
  13. cache_dit/cache_factory/cache_interface.py +128 -111
  14. cache_dit/cache_factory/params_modifier.py +87 -0
  15. cache_dit/metrics/__init__.py +3 -1
  16. cache_dit/utils.py +12 -21
  17. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/METADATA +78 -64
  18. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/RECORD +23 -28
  19. cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
  20. cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
  21. cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
  22. cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
  23. cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
  24. cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
  25. cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
  26. /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
  27. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/WHEEL +0 -0
  28. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/entry_points.txt +0 -0
  29. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/licenses/LICENSE +0 -0
  30. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/top_level.txt +0 -0
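For readers tracking imports: the renames above promote the calibrators out of the v2 namespace (cache_factory/cache_contexts/{v2/calibrators → calibrators}), while the v2 adapter and context modules are deleted outright. A hedged before/after sketch of what this implies for import paths, assuming the calibrator names re-exported by the removed v2/__init__.py (shown at the end of this diff) carry over unchanged to the new calibrators/__init__.py; this is inferred from the file list above, not verified against the 0.3.2 sources:

# cache-dit 0.3.1 (module removed in 0.3.2):
# from cache_dit.cache_factory.cache_contexts.v2.calibrators import (
#     TaylorSeerCalibratorConfig,
# )

# cache-dit 0.3.2 (assumed path, per the {v2/calibrators -> calibrators} rename):
from cache_dit.cache_factory.cache_contexts.calibrators import (
    TaylorSeerCalibratorConfig,
)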
cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py (deleted)
@@ -1,524 +0,0 @@
- import torch
-
- import unittest
- import functools
-
- from contextlib import ExitStack
- from typing import Dict, List, Tuple, Any, Union, Callable
-
- from diffusers import DiffusionPipeline
-
- from cache_dit.cache_factory.cache_types import CacheType
- from cache_dit.cache_factory.block_adapters import BlockAdapter
- from cache_dit.cache_factory.block_adapters import ParamsModifier
- from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
- from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
- from cache_dit.cache_factory.cache_blocks import CachedBlocks
- from cache_dit.cache_factory.cache_blocks.utils import (
-     patch_cached_stats,
-     remove_cached_stats,
- )
- from cache_dit.logger import init_logger
-
- logger = init_logger(__name__)
-
-
- # Unified Cached Adapter
- class CachedAdapterV2:
-
-     def __call__(self, *args, **kwargs):
-         return self.apply(*args, **kwargs)
-
-     @classmethod
-     def apply(
-         cls,
-         pipe_or_adapter: Union[
-             DiffusionPipeline,
-             BlockAdapter,
-         ],
-         **cache_context_kwargs,
-     ) -> Union[
-         DiffusionPipeline,
-         BlockAdapter,
-     ]:
-         assert (
-             pipe_or_adapter is not None
-         ), "pipe or block_adapter can not both None!"
-
-         if isinstance(pipe_or_adapter, DiffusionPipeline):
-             if BlockAdapterRegistry.is_supported(pipe_or_adapter):
-                 logger.info(
-                     f"{pipe_or_adapter.__class__.__name__} is officially "
-                     "supported by cache-dit. Use it's pre-defined BlockAdapter "
-                     "directly!"
-                 )
-                 block_adapter = BlockAdapterRegistry.get_adapter(
-                     pipe_or_adapter
-                 )
-                 return cls.cachify(
-                     block_adapter,
-                     **cache_context_kwargs,
-                 ).pipe
-             else:
-                 raise ValueError(
-                     f"{pipe_or_adapter.__class__.__name__} is not officially supported "
-                     "by cache-dit, please set BlockAdapter instead!"
-                 )
-         else:
-             assert isinstance(pipe_or_adapter, BlockAdapter)
-             logger.info(
-                 "Adapting Cache Acceleration using custom BlockAdapter!"
-             )
-             return cls.cachify(
-                 pipe_or_adapter,
-                 **cache_context_kwargs,
-             )
-
-     @classmethod
-     def cachify(
-         cls,
-         block_adapter: BlockAdapter,
-         **cache_context_kwargs,
-     ) -> BlockAdapter:
-
-         if block_adapter.auto:
-             block_adapter = BlockAdapter.auto_block_adapter(
-                 block_adapter,
-             )
-
-         if BlockAdapter.check_block_adapter(block_adapter):
-
-             # 0. Must normalize block_adapter before apply cache
-             block_adapter = BlockAdapter.normalize(block_adapter)
-             if BlockAdapter.is_cached(block_adapter):
-                 return block_adapter
-
-             # 1. Apply cache on pipeline: wrap cache context, must
-             # call create_context before mock_blocks.
-             cls.create_context(
-                 block_adapter,
-                 **cache_context_kwargs,
-             )
-
-             # 2. Apply cache on transformer: mock cached blocks
-             cls.mock_blocks(
-                 block_adapter,
-             )
-
-         return block_adapter
-
-     @classmethod
-     def check_context_kwargs(
-         cls,
-         block_adapter: BlockAdapter,
-         **cache_context_kwargs,
-     ):
-         # Check cache_context_kwargs
-         if cache_context_kwargs["enable_separate_cfg"] is None:
-             # Check cfg for some specific case if users don't set it as True
-             if BlockAdapterRegistry.has_separate_cfg(block_adapter):
-                 cache_context_kwargs["enable_separate_cfg"] = True
-                 logger.info(
-                     f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
-                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
-                 )
-             else:
-                 cache_context_kwargs["enable_separate_cfg"] = (
-                     BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
-                 )
-                 logger.info(
-                     f"Use default 'enable_separate_cfg' from block adapter "
-                     f"register: {cache_context_kwargs['enable_separate_cfg']}, "
-                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
-                 )
-         else:
-             logger.info(
-                 f"Use custom 'enable_separate_cfg' from cache context "
-                 f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
-                 f"Pipeline: {block_adapter.pipe.__class__.__name__}."
-             )
-
-         if (
-             cache_type := cache_context_kwargs.pop("cache_type", None)
-         ) is not None:
-             assert (
-                 cache_type == CacheType.DBCache
-             ), "Custom cache setting only support for DBCache now!"
-
-         return cache_context_kwargs
-
-     @classmethod
-     def create_context(
-         cls,
-         block_adapter: BlockAdapter,
-         **cache_context_kwargs,
-     ) -> DiffusionPipeline:
-
-         BlockAdapter.assert_normalized(block_adapter)
-
-         if BlockAdapter.is_cached(block_adapter.pipe):
-             return block_adapter.pipe
-
-         # Check cache_context_kwargs
-         cache_context_kwargs = cls.check_context_kwargs(
-             block_adapter, **cache_context_kwargs
-         )
-         # Apply cache on pipeline: wrap cache context
-         pipe_cls_name = block_adapter.pipe.__class__.__name__
-
-         # Each Pipeline should have it's own context manager instance.
-         # Different transformers (Wan2.2, etc) should shared the same
-         # cache manager but with different cache context (according
-         # to their unique instance id).
-         cache_manager = CachedContextManagerV2(
-             name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
-         )
-         block_adapter.pipe._cache_manager = cache_manager # instance level
-
-         flatten_contexts, contexts_kwargs = cls.modify_context_params(
-             block_adapter, cache_manager, **cache_context_kwargs
-         )
-
-         original_call = block_adapter.pipe.__class__.__call__
-
-         @functools.wraps(original_call)
-         def new_call(self, *args, **kwargs):
-             with ExitStack() as stack:
-                 # cache context will be reset for each pipe inference
-                 for context_name, context_kwargs in zip(
-                     flatten_contexts, contexts_kwargs
-                 ):
-                     stack.enter_context(
-                         cache_manager.enter_context(
-                             cache_manager.reset_context(
-                                 context_name,
-                                 **context_kwargs,
-                             ),
-                         )
-                     )
-                 outputs = original_call(self, *args, **kwargs)
-                 cls.apply_stats_hooks(block_adapter)
-                 return outputs
-
-         block_adapter.pipe.__class__.__call__ = new_call
-         block_adapter.pipe.__class__._original_call = original_call
-         block_adapter.pipe.__class__._is_cached = True
-
-         cls.apply_params_hooks(block_adapter, contexts_kwargs)
-
-         return block_adapter.pipe
-
-     @classmethod
-     def modify_context_params(
-         cls,
-         block_adapter: BlockAdapter,
-         cache_manager: CachedContextManagerV2,
-         **cache_context_kwargs,
-     ) -> Tuple[List[str], List[Dict[str, Any]]]:
-
-         flatten_contexts = BlockAdapter.flatten(
-             block_adapter.unique_blocks_name
-         )
-         contexts_kwargs = [
-             cache_context_kwargs.copy()
-             for _ in range(
-                 len(flatten_contexts),
-             )
-         ]
-
-         for i in range(len(contexts_kwargs)):
-             contexts_kwargs[i]["name"] = flatten_contexts[i]
-
-         if block_adapter.params_modifiers is None:
-             return flatten_contexts, contexts_kwargs
-
-         flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
-             block_adapter.params_modifiers,
-         )
-
-         for i in range(
-             min(len(contexts_kwargs), len(flatten_modifiers)),
-         ):
-             contexts_kwargs[i].update(
-                 flatten_modifiers[i]._context_kwargs,
-             )
-             contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
-                 default_attrs={}, **contexts_kwargs[i]
-             )
-
-         return flatten_contexts, contexts_kwargs
-
-     @classmethod
-     def mock_blocks(
-         cls,
-         block_adapter: BlockAdapter,
-     ) -> List[torch.nn.Module]:
-
-         BlockAdapter.assert_normalized(block_adapter)
-
-         if BlockAdapter.is_cached(block_adapter.transformer):
-             return block_adapter.transformer
-
-         # Apply cache on transformer: mock cached transformer blocks
-         for (
-             cached_blocks,
-             transformer,
-             blocks_name,
-             unique_blocks_name,
-             dummy_blocks_names,
-         ) in zip(
-             cls.collect_cached_blocks(block_adapter),
-             block_adapter.transformer,
-             block_adapter.blocks_name,
-             block_adapter.unique_blocks_name,
-             block_adapter.dummy_blocks_names,
-         ):
-             cls.mock_transformer(
-                 cached_blocks,
-                 transformer,
-                 blocks_name,
-                 unique_blocks_name,
-                 dummy_blocks_names,
-             )
-
-         return block_adapter.transformer
-
-     @classmethod
-     def mock_transformer(
-         cls,
-         cached_blocks: Dict[str, torch.nn.ModuleList],
-         transformer: torch.nn.Module,
-         blocks_name: List[str],
-         unique_blocks_name: List[str],
-         dummy_blocks_names: List[str],
-     ) -> torch.nn.Module:
-         dummy_blocks = torch.nn.ModuleList()
-
-         original_forward = transformer.forward
-
-         assert isinstance(dummy_blocks_names, list)
-
-         @functools.wraps(original_forward)
-         def new_forward(self, *args, **kwargs):
-             with ExitStack() as stack:
-                 for name, context_name in zip(
-                     blocks_name,
-                     unique_blocks_name,
-                 ):
-                     stack.enter_context(
-                         unittest.mock.patch.object(
-                             self, name, cached_blocks[context_name]
-                         )
-                     )
-                 for dummy_name in dummy_blocks_names:
-                     stack.enter_context(
-                         unittest.mock.patch.object(
-                             self, dummy_name, dummy_blocks
-                         )
-                     )
-                 return original_forward(*args, **kwargs)
-
-         transformer.forward = new_forward.__get__(transformer)
-         transformer._original_forward = original_forward
-         transformer._is_cached = True
-
-         return transformer
-
-     @classmethod
-     def collect_cached_blocks(
-         cls,
-         block_adapter: BlockAdapter,
-     ) -> List[Dict[str, torch.nn.ModuleList]]:
-
-         BlockAdapter.assert_normalized(block_adapter)
-
-         total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
-         assert hasattr(block_adapter.pipe, "_cache_manager")
-         assert isinstance(
-             block_adapter.pipe._cache_manager, CachedContextManagerV2
-         )
-
-         for i in range(len(block_adapter.transformer)):
-
-             cached_blocks_bind_context = {}
-             for j in range(len(block_adapter.blocks[i])):
-                 cached_blocks_bind_context[
-                     block_adapter.unique_blocks_name[i][j]
-                 ] = torch.nn.ModuleList(
-                     [
-                         CachedBlocks(
-                             # 0. Transformer blocks configuration
-                             block_adapter.blocks[i][j],
-                             transformer=block_adapter.transformer[i],
-                             forward_pattern=block_adapter.forward_pattern[i][j],
-                             check_forward_pattern=block_adapter.check_forward_pattern,
-                             check_num_outputs=block_adapter.check_num_outputs,
-                             # 1. Cache context configuration
-                             cache_prefix=block_adapter.blocks_name[i][j],
-                             cache_context=block_adapter.unique_blocks_name[i][
-                                 j
-                             ],
-                             cache_manager=block_adapter.pipe._cache_manager,
-                         )
-                     ]
-                 )
-
-             total_cached_blocks.append(cached_blocks_bind_context)
-
-         return total_cached_blocks
-
-     @classmethod
-     def apply_params_hooks(
-         cls,
-         block_adapter: BlockAdapter,
-         contexts_kwargs: List[Dict],
-     ):
-         block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
-
-         params_shift = 0
-         for i in range(len(block_adapter.transformer)):
-
-             block_adapter.transformer[i]._forward_pattern = (
-                 block_adapter.forward_pattern
-             )
-             block_adapter.transformer[i]._has_separate_cfg = (
-                 block_adapter.has_separate_cfg
-             )
-             block_adapter.transformer[i]._cache_context_kwargs = (
-                 contexts_kwargs[params_shift]
-             )
-
-             blocks = block_adapter.blocks[i]
-             for j in range(len(blocks)):
-                 blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
-                 blocks[j]._cache_context_kwargs = contexts_kwargs[
-                     params_shift + j
-                 ]
-
-             params_shift += len(blocks)
-
-     @classmethod
-     def apply_stats_hooks(
-         cls,
-         block_adapter: BlockAdapter,
-     ):
-         cache_manager = block_adapter.pipe._cache_manager
-
-         for i in range(len(block_adapter.transformer)):
-             patch_cached_stats(
-                 block_adapter.transformer[i],
-                 cache_context=block_adapter.unique_blocks_name[i][-1],
-                 cache_manager=cache_manager,
-             )
-             for blocks, unique_name in zip(
-                 block_adapter.blocks[i],
-                 block_adapter.unique_blocks_name[i],
-             ):
-                 patch_cached_stats(
-                     blocks,
-                     cache_context=unique_name,
-                     cache_manager=cache_manager,
-                 )
-
-     @classmethod
-     def maybe_release_hooks(
-         cls,
-         pipe_or_adapter: Union[
-             DiffusionPipeline,
-             BlockAdapter,
-         ],
-     ):
-         # release model hooks
-         def _release_blocks_hooks(blocks):
-             return
-
-         def _release_transformer_hooks(transformer):
-             if hasattr(transformer, "_original_forward"):
-                 original_forward = transformer._original_forward
-                 transformer.forward = original_forward.__get__(transformer)
-                 del transformer._original_forward
-             if hasattr(transformer, "_is_cached"):
-                 del transformer._is_cached
-
-         def _release_pipeline_hooks(pipe):
-             if hasattr(pipe, "_original_call"):
-                 original_call = pipe.__class__._original_call
-                 pipe.__class__.__call__ = original_call
-                 del pipe.__class__._original_call
-             if hasattr(pipe, "_cache_manager"):
-                 cache_manager = pipe._cache_manager
-                 if isinstance(cache_manager, CachedContextManagerV2):
-                     cache_manager.clear_contexts()
-                 del pipe._cache_manager
-             if hasattr(pipe, "_is_cached"):
-                 del pipe.__class__._is_cached
-
-         cls.release_hooks(
-             pipe_or_adapter,
-             _release_blocks_hooks,
-             _release_transformer_hooks,
-             _release_pipeline_hooks,
-         )
-
-         # release params hooks
-         def _release_blocks_params(blocks):
-             if hasattr(blocks, "_forward_pattern"):
-                 del blocks._forward_pattern
-             if hasattr(blocks, "_cache_context_kwargs"):
-                 del blocks._cache_context_kwargs
-
-         def _release_transformer_params(transformer):
-             if hasattr(transformer, "_forward_pattern"):
-                 del transformer._forward_pattern
-             if hasattr(transformer, "_has_separate_cfg"):
-                 del transformer._has_separate_cfg
-             if hasattr(transformer, "_cache_context_kwargs"):
-                 del transformer._cache_context_kwargs
-             for blocks in BlockAdapter.find_blocks(transformer):
-                 _release_blocks_params(blocks)
-
-         def _release_pipeline_params(pipe):
-             if hasattr(pipe, "_cache_context_kwargs"):
-                 del pipe._cache_context_kwargs
-
-         cls.release_hooks(
-             pipe_or_adapter,
-             _release_blocks_params,
-             _release_transformer_params,
-             _release_pipeline_params,
-         )
-
-         # release stats hooks
-         cls.release_hooks(
-             pipe_or_adapter,
-             remove_cached_stats,
-             remove_cached_stats,
-             remove_cached_stats,
-         )
-
-     @classmethod
-     def release_hooks(
-         cls,
-         pipe_or_adapter: Union[
-             DiffusionPipeline,
-             BlockAdapter,
-         ],
-         _release_blocks: Callable,
-         _release_transformer: Callable,
-         _release_pipeline: Callable,
-     ):
-         if isinstance(pipe_or_adapter, DiffusionPipeline):
-             pipe = pipe_or_adapter
-             _release_pipeline(pipe)
-             if hasattr(pipe, "transformer"):
-                 _release_transformer(pipe.transformer)
-             if hasattr(pipe, "transformer_2"): # Wan 2.2
-                 _release_transformer(pipe.transformer_2)
-         elif isinstance(pipe_or_adapter, BlockAdapter):
-             adapter = pipe_or_adapter
-             BlockAdapter.assert_normalized(adapter)
-             _release_pipeline(adapter.pipe)
-             for transformer in BlockAdapter.flatten(adapter.transformer):
-                 _release_transformer(transformer)
-             for blocks in BlockAdapter.flatten(adapter.blocks):
-                 _release_blocks(blocks)
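A pattern worth noting in the deleted adapter above: both create_context and mock_transformer wrap the original callable with functools.wraps and patch attributes only for the duration of a single call, using unittest.mock.patch.object inside an ExitStack, so everything is restored when the call returns. A minimal, self-contained sketch of that pattern with a toy module (the names here are illustrative, not cache-dit classes):

import functools
import unittest.mock
from contextlib import ExitStack

import torch


class ToyTransformer(torch.nn.Module):
    # Stand-in for a diffusers transformer; not a cache-dit class.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


def swap_blocks_per_call(transformer, replacement_blocks):
    # Same shape as CachedAdapterV2.mock_transformer above: the real block
    # list is patched only while the wrapped forward() runs, and is restored
    # automatically when the ExitStack unwinds.
    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        with ExitStack() as stack:
            stack.enter_context(
                unittest.mock.patch.object(self, "blocks", replacement_blocks)
            )
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    transformer._original_forward = original_forward
    return transformer


model = swap_blocks_per_call(
    ToyTransformer(), torch.nn.ModuleList([torch.nn.Identity()])
)
out = model(torch.randn(2, 4))  # runs the replacement (identity) block
assert isinstance(model.blocks[0], torch.nn.Linear)  # original restored after the call

Because the patch only lives inside the wrapped forward, the real blocks are never permanently replaced; that is why _release_transformer_hooks above only needs to restore the saved _original_forward.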
cache_dit/cache_factory/cache_contexts/taylorseer.py (deleted)
@@ -1,102 +0,0 @@
- import math
- import torch
- from typing import List, Dict
-
-
- class TaylorSeer:
-     def __init__(
-         self,
-         n_derivatives=2,
-         max_warmup_steps=1,
-         skip_interval_steps=1,
-         compute_step_map=None,
-     ):
-         self.n_derivatives = n_derivatives
-         self.ORDER = n_derivatives + 1
-         self.max_warmup_steps = max_warmup_steps
-         self.skip_interval_steps = skip_interval_steps
-         self.compute_step_map = compute_step_map
-         self.reset_cache()
-
-     def reset_cache(self):
-         self.state: Dict[str, List[torch.Tensor]] = {
-             "dY_prev": [None] * self.ORDER,
-             "dY_current": [None] * self.ORDER,
-         }
-         self.current_step = -1
-         self.last_non_approximated_step = -1
-
-     def should_compute_full(self, step=None):
-         step = self.current_step if step is None else step
-         if self.compute_step_map is not None:
-             return self.compute_step_map[step]
-         if (
-             step < self.max_warmup_steps
-             or (step - self.max_warmup_steps + 1) % self.skip_interval_steps
-             == 0
-         ):
-             return True
-         return False
-
-     def approximate_derivative(self, Y: torch.Tensor) -> List[torch.Tensor]:
-         # n-th order Taylor expansion:
-         # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
-         # + ... + d^nY(0)/dt^n * t^n / n!
-         # TODO: Custom Triton/CUDA kernel for better performance,
-         # especially for large n_derivatives.
-         dY_current: List[torch.Tensor] = [None] * self.ORDER
-         dY_current[0] = Y
-         window = self.current_step - self.last_non_approximated_step
-         if self.state["dY_prev"][0] is not None:
-             if dY_current[0].shape != self.state["dY_prev"][0].shape:
-                 self.reset_cache()
-
-         for i in range(self.n_derivatives):
-             if self.state["dY_prev"][i] is not None and self.current_step > 1:
-                 dY_current[i + 1] = (
-                     dY_current[i] - self.state["dY_prev"][i]
-                 ) / window
-             else:
-                 break
-         return dY_current
-
-     def approximate_value(self) -> torch.Tensor:
-         # TODO: Custom Triton/CUDA kernel for better performance,
-         # especially for large n_derivatives.
-         elapsed = self.current_step - self.last_non_approximated_step
-         output = 0
-         for i, derivative in enumerate(self.state["dY_current"]):
-             if derivative is not None:
-                 output += (1 / math.factorial(i)) * derivative * (elapsed**i)
-             else:
-                 break
-         return output
-
-     def mark_step_begin(self):
-         self.current_step += 1
-
-     def update(self, Y: torch.Tensor):
-         # Directly call this method will ingnore the warmup
-         # policy and force full computation.
-         # Assume warmup steps is 3, and n_derivatives is 3.
-         # step 0: dY_prev    = [None, None, None, None ]
-         #         dY_current = [Y0, None, None, None ]
-         # step 1: dY_prev    = [Y0, None, None, None ]
-         #         dY_current = [Y1, dY1, None, None ]
-         # step 2: dY_prev    = [Y1, dY1, None, None ]
-         #         dY_current = [Y2, dY2/Y1, dY2/dY1, None ]
-         # step 3: dY_prev    = [Y2, dY2/Y1, dY2/dY1, None ],
-         #         dY_current = [Y3, dY3/Y2, dY3/dY2, dY3/dY1]
-         # step 4: dY_prev    = [Y3, dY3/Y2, dY3/dY2, dY3/dY1]
-         #         dY_current = [Y4, dY4/Y3, dY4/dY3, dY4/dY2]
-         self.state["dY_prev"] = self.state["dY_current"]
-         self.state["dY_current"] = self.approximate_derivative(Y)
-         self.last_non_approximated_step = self.current_step
-
-     def step(self, Y: torch.Tensor):
-         self.mark_step_begin()
-         if self.should_compute_full():
-             self.update(Y)
-             return Y
-         else:
-             return self.approximate_value()
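The removed TaylorSeer caches the last fully computed output together with finite-difference derivative estimates, then extrapolates skipped steps as Y(t0) + dY(t0)*(t - t0) + d2Y(t0)*(t - t0)^2/2! + ..., as the comment in approximate_derivative notes. A toy driver, assuming the class above is in scope (in 0.3.2 an updated taylorseer.py lives under cache_factory/cache_contexts/calibrators/, per the rename list); the tensors are illustrative:

import torch

# Assumes the TaylorSeer class shown above has been imported or pasted into scope.
# Full compute on the single warmup step and every 2nd step afterwards
# (skip_interval_steps=2); first-order extrapolation on the steps in between.
seer = TaylorSeer(n_derivatives=1, max_warmup_steps=1, skip_interval_steps=2)

for t in range(6):
    block_output = torch.full((2, 2), float(t))  # stand-in for a block's hidden states
    y = seer.step(block_output)  # full steps return block_output; skipped steps return the Taylor estimate
    print(t, y.mean().item())    # 0.0, 0.0, 2.0, 3.0, 4.0, 5.0 for this linear toy sequence

At t=1 only the zeroth-order term is cached, so the estimate is just Y(0); at t=3 the approximation returns Y(2) + (Y(2) - Y(0))/2 * 1 = 3, matching the true value of this linear sequence. Real block outputs are only approximately polynomial in the timestep, which is why full computes are re-run every skip_interval_steps.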
cache_dit/cache_factory/cache_contexts/v2/__init__.py (deleted)
@@ -1,13 +0,0 @@
- from cache_dit.cache_factory.cache_contexts.v2.calibrators import (
-     Calibrator,
-     CalibratorBase,
-     CalibratorConfig,
-     TaylorSeerCalibratorConfig,
-     FoCaCalibratorConfig,
- )
- from cache_dit.cache_factory.cache_contexts.v2.cache_context_v2 import (
-     CachedContextV2,
- )
- from cache_dit.cache_factory.cache_contexts.v2.cache_manager_v2 import (
-     CachedContextManagerV2,
- )