cache-dit 0.3.1-py3-none-any.whl → 0.3.2-py3-none-any.whl
This diff shows the content changes between publicly available package versions as released to a supported public registry, and is provided for informational purposes only.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -6
- cache_dit/cache_factory/block_adapters/block_adapters.py +8 -64
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +47 -14
- cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
- cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
- cache_dit/cache_factory/cache_interface.py +128 -111
- cache_dit/cache_factory/params_modifier.py +87 -0
- cache_dit/metrics/__init__.py +3 -1
- cache_dit/utils.py +12 -21
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/METADATA +78 -64
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/RECORD +23 -28
- cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
- cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
- cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
- cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
- cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
- cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
- cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
- /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/top_level.txt +0 -0
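Taken together, the file list suggests that 0.3.2 folds the former `v2` cache implementation into the main `cache_factory` modules: the `*_v2.py` adapter, context, and manager files are deleted, a new top-level `calibrators` package appears (+132 lines), and `taylorseer.py`/`foca.py` move out of the `v2` namespace. A hedged before/after import sketch of what that move likely means for downstream code follows; whether the new package re-exports exactly these names is an assumption, since only the file moves are visible in this diff.

# Assumed import-path change based on the file moves above; the exact
# re-exports of the new calibrators package are not shown in this diff.

# cache-dit 0.3.1 (module removed in 0.3.2):
# from cache_dit.cache_factory.cache_contexts.v2.calibrators import (
#     TaylorSeerCalibratorConfig,
# )

# cache-dit 0.3.2 (calibrators promoted out of the v2 namespace):
from cache_dit.cache_factory.cache_contexts.calibrators import (
    TaylorSeerCalibratorConfig,
)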
cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py (deleted in 0.3.2)
@@ -1,524 +0,0 @@
-import torch
-
-import unittest
-import functools
-
-from contextlib import ExitStack
-from typing import Dict, List, Tuple, Any, Union, Callable
-
-from diffusers import DiffusionPipeline
-
-from cache_dit.cache_factory.cache_types import CacheType
-from cache_dit.cache_factory.block_adapters import BlockAdapter
-from cache_dit.cache_factory.block_adapters import ParamsModifier
-from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
-from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
-from cache_dit.cache_factory.cache_blocks import CachedBlocks
-from cache_dit.cache_factory.cache_blocks.utils import (
-    patch_cached_stats,
-    remove_cached_stats,
-)
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-# Unified Cached Adapter
-class CachedAdapterV2:
-
-    def __call__(self, *args, **kwargs):
-        return self.apply(*args, **kwargs)
-
-    @classmethod
-    def apply(
-        cls,
-        pipe_or_adapter: Union[
-            DiffusionPipeline,
-            BlockAdapter,
-        ],
-        **cache_context_kwargs,
-    ) -> Union[
-        DiffusionPipeline,
-        BlockAdapter,
-    ]:
-        assert (
-            pipe_or_adapter is not None
-        ), "pipe or block_adapter can not both None!"
-
-        if isinstance(pipe_or_adapter, DiffusionPipeline):
-            if BlockAdapterRegistry.is_supported(pipe_or_adapter):
-                logger.info(
-                    f"{pipe_or_adapter.__class__.__name__} is officially "
-                    "supported by cache-dit. Use it's pre-defined BlockAdapter "
-                    "directly!"
-                )
-                block_adapter = BlockAdapterRegistry.get_adapter(
-                    pipe_or_adapter
-                )
-                return cls.cachify(
-                    block_adapter,
-                    **cache_context_kwargs,
-                ).pipe
-            else:
-                raise ValueError(
-                    f"{pipe_or_adapter.__class__.__name__} is not officially supported "
-                    "by cache-dit, please set BlockAdapter instead!"
-                )
-        else:
-            assert isinstance(pipe_or_adapter, BlockAdapter)
-            logger.info(
-                "Adapting Cache Acceleration using custom BlockAdapter!"
-            )
-            return cls.cachify(
-                pipe_or_adapter,
-                **cache_context_kwargs,
-            )
-
-    @classmethod
-    def cachify(
-        cls,
-        block_adapter: BlockAdapter,
-        **cache_context_kwargs,
-    ) -> BlockAdapter:
-
-        if block_adapter.auto:
-            block_adapter = BlockAdapter.auto_block_adapter(
-                block_adapter,
-            )
-
-        if BlockAdapter.check_block_adapter(block_adapter):
-
-            # 0. Must normalize block_adapter before apply cache
-            block_adapter = BlockAdapter.normalize(block_adapter)
-            if BlockAdapter.is_cached(block_adapter):
-                return block_adapter
-
-            # 1. Apply cache on pipeline: wrap cache context, must
-            # call create_context before mock_blocks.
-            cls.create_context(
-                block_adapter,
-                **cache_context_kwargs,
-            )
-
-            # 2. Apply cache on transformer: mock cached blocks
-            cls.mock_blocks(
-                block_adapter,
-            )
-
-        return block_adapter
-
-    @classmethod
-    def check_context_kwargs(
-        cls,
-        block_adapter: BlockAdapter,
-        **cache_context_kwargs,
-    ):
-        # Check cache_context_kwargs
-        if cache_context_kwargs["enable_separate_cfg"] is None:
-            # Check cfg for some specific case if users don't set it as True
-            if BlockAdapterRegistry.has_separate_cfg(block_adapter):
-                cache_context_kwargs["enable_separate_cfg"] = True
-                logger.info(
-                    f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
-                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
-                )
-            else:
-                cache_context_kwargs["enable_separate_cfg"] = (
-                    BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
-                )
-                logger.info(
-                    f"Use default 'enable_separate_cfg' from block adapter "
-                    f"register: {cache_context_kwargs['enable_separate_cfg']}, "
-                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
-                )
-        else:
-            logger.info(
-                f"Use custom 'enable_separate_cfg' from cache context "
-                f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
-                f"Pipeline: {block_adapter.pipe.__class__.__name__}."
-            )
-
-        if (
-            cache_type := cache_context_kwargs.pop("cache_type", None)
-        ) is not None:
-            assert (
-                cache_type == CacheType.DBCache
-            ), "Custom cache setting only support for DBCache now!"
-
-        return cache_context_kwargs
-
-    @classmethod
-    def create_context(
-        cls,
-        block_adapter: BlockAdapter,
-        **cache_context_kwargs,
-    ) -> DiffusionPipeline:
-
-        BlockAdapter.assert_normalized(block_adapter)
-
-        if BlockAdapter.is_cached(block_adapter.pipe):
-            return block_adapter.pipe
-
-        # Check cache_context_kwargs
-        cache_context_kwargs = cls.check_context_kwargs(
-            block_adapter, **cache_context_kwargs
-        )
-        # Apply cache on pipeline: wrap cache context
-        pipe_cls_name = block_adapter.pipe.__class__.__name__
-
-        # Each Pipeline should have it's own context manager instance.
-        # Different transformers (Wan2.2, etc) should shared the same
-        # cache manager but with different cache context (according
-        # to their unique instance id).
-        cache_manager = CachedContextManagerV2(
-            name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
-        )
-        block_adapter.pipe._cache_manager = cache_manager  # instance level
-
-        flatten_contexts, contexts_kwargs = cls.modify_context_params(
-            block_adapter, cache_manager, **cache_context_kwargs
-        )
-
-        original_call = block_adapter.pipe.__class__.__call__
-
-        @functools.wraps(original_call)
-        def new_call(self, *args, **kwargs):
-            with ExitStack() as stack:
-                # cache context will be reset for each pipe inference
-                for context_name, context_kwargs in zip(
-                    flatten_contexts, contexts_kwargs
-                ):
-                    stack.enter_context(
-                        cache_manager.enter_context(
-                            cache_manager.reset_context(
-                                context_name,
-                                **context_kwargs,
-                            ),
-                        )
-                    )
-                outputs = original_call(self, *args, **kwargs)
-                cls.apply_stats_hooks(block_adapter)
-                return outputs
-
-        block_adapter.pipe.__class__.__call__ = new_call
-        block_adapter.pipe.__class__._original_call = original_call
-        block_adapter.pipe.__class__._is_cached = True
-
-        cls.apply_params_hooks(block_adapter, contexts_kwargs)
-
-        return block_adapter.pipe
-
-    @classmethod
-    def modify_context_params(
-        cls,
-        block_adapter: BlockAdapter,
-        cache_manager: CachedContextManagerV2,
-        **cache_context_kwargs,
-    ) -> Tuple[List[str], List[Dict[str, Any]]]:
-
-        flatten_contexts = BlockAdapter.flatten(
-            block_adapter.unique_blocks_name
-        )
-        contexts_kwargs = [
-            cache_context_kwargs.copy()
-            for _ in range(
-                len(flatten_contexts),
-            )
-        ]
-
-        for i in range(len(contexts_kwargs)):
-            contexts_kwargs[i]["name"] = flatten_contexts[i]
-
-        if block_adapter.params_modifiers is None:
-            return flatten_contexts, contexts_kwargs
-
-        flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
-            block_adapter.params_modifiers,
-        )
-
-        for i in range(
-            min(len(contexts_kwargs), len(flatten_modifiers)),
-        ):
-            contexts_kwargs[i].update(
-                flatten_modifiers[i]._context_kwargs,
-            )
-            contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
-                default_attrs={}, **contexts_kwargs[i]
-            )
-
-        return flatten_contexts, contexts_kwargs
-
-    @classmethod
-    def mock_blocks(
-        cls,
-        block_adapter: BlockAdapter,
-    ) -> List[torch.nn.Module]:
-
-        BlockAdapter.assert_normalized(block_adapter)
-
-        if BlockAdapter.is_cached(block_adapter.transformer):
-            return block_adapter.transformer
-
-        # Apply cache on transformer: mock cached transformer blocks
-        for (
-            cached_blocks,
-            transformer,
-            blocks_name,
-            unique_blocks_name,
-            dummy_blocks_names,
-        ) in zip(
-            cls.collect_cached_blocks(block_adapter),
-            block_adapter.transformer,
-            block_adapter.blocks_name,
-            block_adapter.unique_blocks_name,
-            block_adapter.dummy_blocks_names,
-        ):
-            cls.mock_transformer(
-                cached_blocks,
-                transformer,
-                blocks_name,
-                unique_blocks_name,
-                dummy_blocks_names,
-            )
-
-        return block_adapter.transformer
-
-    @classmethod
-    def mock_transformer(
-        cls,
-        cached_blocks: Dict[str, torch.nn.ModuleList],
-        transformer: torch.nn.Module,
-        blocks_name: List[str],
-        unique_blocks_name: List[str],
-        dummy_blocks_names: List[str],
-    ) -> torch.nn.Module:
-        dummy_blocks = torch.nn.ModuleList()
-
-        original_forward = transformer.forward
-
-        assert isinstance(dummy_blocks_names, list)
-
-        @functools.wraps(original_forward)
-        def new_forward(self, *args, **kwargs):
-            with ExitStack() as stack:
-                for name, context_name in zip(
-                    blocks_name,
-                    unique_blocks_name,
-                ):
-                    stack.enter_context(
-                        unittest.mock.patch.object(
-                            self, name, cached_blocks[context_name]
-                        )
-                    )
-                for dummy_name in dummy_blocks_names:
-                    stack.enter_context(
-                        unittest.mock.patch.object(
-                            self, dummy_name, dummy_blocks
-                        )
-                    )
-                return original_forward(*args, **kwargs)
-
-        transformer.forward = new_forward.__get__(transformer)
-        transformer._original_forward = original_forward
-        transformer._is_cached = True
-
-        return transformer
-
-    @classmethod
-    def collect_cached_blocks(
-        cls,
-        block_adapter: BlockAdapter,
-    ) -> List[Dict[str, torch.nn.ModuleList]]:
-
-        BlockAdapter.assert_normalized(block_adapter)
-
-        total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
-        assert hasattr(block_adapter.pipe, "_cache_manager")
-        assert isinstance(
-            block_adapter.pipe._cache_manager, CachedContextManagerV2
-        )
-
-        for i in range(len(block_adapter.transformer)):
-
-            cached_blocks_bind_context = {}
-            for j in range(len(block_adapter.blocks[i])):
-                cached_blocks_bind_context[
-                    block_adapter.unique_blocks_name[i][j]
-                ] = torch.nn.ModuleList(
-                    [
-                        CachedBlocks(
-                            # 0. Transformer blocks configuration
-                            block_adapter.blocks[i][j],
-                            transformer=block_adapter.transformer[i],
-                            forward_pattern=block_adapter.forward_pattern[i][j],
-                            check_forward_pattern=block_adapter.check_forward_pattern,
-                            check_num_outputs=block_adapter.check_num_outputs,
-                            # 1. Cache context configuration
-                            cache_prefix=block_adapter.blocks_name[i][j],
-                            cache_context=block_adapter.unique_blocks_name[i][
-                                j
-                            ],
-                            cache_manager=block_adapter.pipe._cache_manager,
-                        )
-                    ]
-                )
-
-            total_cached_blocks.append(cached_blocks_bind_context)
-
-        return total_cached_blocks
-
-    @classmethod
-    def apply_params_hooks(
-        cls,
-        block_adapter: BlockAdapter,
-        contexts_kwargs: List[Dict],
-    ):
-        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
-
-        params_shift = 0
-        for i in range(len(block_adapter.transformer)):
-
-            block_adapter.transformer[i]._forward_pattern = (
-                block_adapter.forward_pattern
-            )
-            block_adapter.transformer[i]._has_separate_cfg = (
-                block_adapter.has_separate_cfg
-            )
-            block_adapter.transformer[i]._cache_context_kwargs = (
-                contexts_kwargs[params_shift]
-            )
-
-            blocks = block_adapter.blocks[i]
-            for j in range(len(blocks)):
-                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
-                blocks[j]._cache_context_kwargs = contexts_kwargs[
-                    params_shift + j
-                ]
-
-            params_shift += len(blocks)
-
-    @classmethod
-    def apply_stats_hooks(
-        cls,
-        block_adapter: BlockAdapter,
-    ):
-        cache_manager = block_adapter.pipe._cache_manager
-
-        for i in range(len(block_adapter.transformer)):
-            patch_cached_stats(
-                block_adapter.transformer[i],
-                cache_context=block_adapter.unique_blocks_name[i][-1],
-                cache_manager=cache_manager,
-            )
-            for blocks, unique_name in zip(
-                block_adapter.blocks[i],
-                block_adapter.unique_blocks_name[i],
-            ):
-                patch_cached_stats(
-                    blocks,
-                    cache_context=unique_name,
-                    cache_manager=cache_manager,
-                )
-
-    @classmethod
-    def maybe_release_hooks(
-        cls,
-        pipe_or_adapter: Union[
-            DiffusionPipeline,
-            BlockAdapter,
-        ],
-    ):
-        # release model hooks
-        def _release_blocks_hooks(blocks):
-            return
-
-        def _release_transformer_hooks(transformer):
-            if hasattr(transformer, "_original_forward"):
-                original_forward = transformer._original_forward
-                transformer.forward = original_forward.__get__(transformer)
-                del transformer._original_forward
-            if hasattr(transformer, "_is_cached"):
-                del transformer._is_cached
-
-        def _release_pipeline_hooks(pipe):
-            if hasattr(pipe, "_original_call"):
-                original_call = pipe.__class__._original_call
-                pipe.__class__.__call__ = original_call
-                del pipe.__class__._original_call
-            if hasattr(pipe, "_cache_manager"):
-                cache_manager = pipe._cache_manager
-                if isinstance(cache_manager, CachedContextManagerV2):
-                    cache_manager.clear_contexts()
-                del pipe._cache_manager
-            if hasattr(pipe, "_is_cached"):
-                del pipe.__class__._is_cached
-
-        cls.release_hooks(
-            pipe_or_adapter,
-            _release_blocks_hooks,
-            _release_transformer_hooks,
-            _release_pipeline_hooks,
-        )
-
-        # release params hooks
-        def _release_blocks_params(blocks):
-            if hasattr(blocks, "_forward_pattern"):
-                del blocks._forward_pattern
-            if hasattr(blocks, "_cache_context_kwargs"):
-                del blocks._cache_context_kwargs
-
-        def _release_transformer_params(transformer):
-            if hasattr(transformer, "_forward_pattern"):
-                del transformer._forward_pattern
-            if hasattr(transformer, "_has_separate_cfg"):
-                del transformer._has_separate_cfg
-            if hasattr(transformer, "_cache_context_kwargs"):
-                del transformer._cache_context_kwargs
-            for blocks in BlockAdapter.find_blocks(transformer):
-                _release_blocks_params(blocks)
-
-        def _release_pipeline_params(pipe):
-            if hasattr(pipe, "_cache_context_kwargs"):
-                del pipe._cache_context_kwargs
-
-        cls.release_hooks(
-            pipe_or_adapter,
-            _release_blocks_params,
-            _release_transformer_params,
-            _release_pipeline_params,
-        )
-
-        # release stats hooks
-        cls.release_hooks(
-            pipe_or_adapter,
-            remove_cached_stats,
-            remove_cached_stats,
-            remove_cached_stats,
-        )
-
-    @classmethod
-    def release_hooks(
-        cls,
-        pipe_or_adapter: Union[
-            DiffusionPipeline,
-            BlockAdapter,
-        ],
-        _release_blocks: Callable,
-        _release_transformer: Callable,
-        _release_pipeline: Callable,
-    ):
-        if isinstance(pipe_or_adapter, DiffusionPipeline):
-            pipe = pipe_or_adapter
-            _release_pipeline(pipe)
-            if hasattr(pipe, "transformer"):
-                _release_transformer(pipe.transformer)
-            if hasattr(pipe, "transformer_2"):  # Wan 2.2
-                _release_transformer(pipe.transformer_2)
-        elif isinstance(pipe_or_adapter, BlockAdapter):
-            adapter = pipe_or_adapter
-            BlockAdapter.assert_normalized(adapter)
-            _release_pipeline(adapter.pipe)
-            for transformer in BlockAdapter.flatten(adapter.transformer):
-                _release_transformer(transformer)
-            for blocks in BlockAdapter.flatten(adapter.blocks):
-                _release_blocks(blocks)
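The deleted CachedAdapterV2 above was the 0.3.1 entry point that wrapped a pipeline's __call__ with a cache context and mocked its transformer blocks; in 0.3.2 that role presumably falls to the consolidated cache_adapter.py (+47/-14). Below is a minimal usage sketch reconstructed from the removed code; the checkpoint name is illustrative only, and callers would typically go through the higher-level cache_interface.py rather than this class directly.

# Hypothetical direct use of the removed 0.3.1 adapter, inferred from the
# source above; the checkpoint name is illustrative only.
import torch
from diffusers import DiffusionPipeline

from cache_dit.cache_factory.cache_adapters.v2.cache_adapter_v2 import (
    CachedAdapterV2,
)

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# For registry-supported pipelines, apply() resolves the pre-defined
# BlockAdapter, wraps pipe.__class__.__call__ with the cache context, and
# returns the (now cached) pipeline. enable_separate_cfg is passed explicitly
# because check_context_kwargs() indexes that key directly.
pipe = CachedAdapterV2.apply(pipe, enable_separate_cfg=None)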
cache_dit/cache_factory/cache_contexts/taylorseer.py (deleted in 0.3.2)
@@ -1,102 +0,0 @@
-import math
-import torch
-from typing import List, Dict
-
-
-class TaylorSeer:
-    def __init__(
-        self,
-        n_derivatives=2,
-        max_warmup_steps=1,
-        skip_interval_steps=1,
-        compute_step_map=None,
-    ):
-        self.n_derivatives = n_derivatives
-        self.ORDER = n_derivatives + 1
-        self.max_warmup_steps = max_warmup_steps
-        self.skip_interval_steps = skip_interval_steps
-        self.compute_step_map = compute_step_map
-        self.reset_cache()
-
-    def reset_cache(self):
-        self.state: Dict[str, List[torch.Tensor]] = {
-            "dY_prev": [None] * self.ORDER,
-            "dY_current": [None] * self.ORDER,
-        }
-        self.current_step = -1
-        self.last_non_approximated_step = -1
-
-    def should_compute_full(self, step=None):
-        step = self.current_step if step is None else step
-        if self.compute_step_map is not None:
-            return self.compute_step_map[step]
-        if (
-            step < self.max_warmup_steps
-            or (step - self.max_warmup_steps + 1) % self.skip_interval_steps
-            == 0
-        ):
-            return True
-        return False
-
-    def approximate_derivative(self, Y: torch.Tensor) -> List[torch.Tensor]:
-        # n-th order Taylor expansion:
-        # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
-        #        + ... + d^nY(0)/dt^n * t^n / n!
-        # TODO: Custom Triton/CUDA kernel for better performance,
-        # especially for large n_derivatives.
-        dY_current: List[torch.Tensor] = [None] * self.ORDER
-        dY_current[0] = Y
-        window = self.current_step - self.last_non_approximated_step
-        if self.state["dY_prev"][0] is not None:
-            if dY_current[0].shape != self.state["dY_prev"][0].shape:
-                self.reset_cache()
-
-        for i in range(self.n_derivatives):
-            if self.state["dY_prev"][i] is not None and self.current_step > 1:
-                dY_current[i + 1] = (
-                    dY_current[i] - self.state["dY_prev"][i]
-                ) / window
-            else:
-                break
-        return dY_current
-
-    def approximate_value(self) -> torch.Tensor:
-        # TODO: Custom Triton/CUDA kernel for better performance,
-        # especially for large n_derivatives.
-        elapsed = self.current_step - self.last_non_approximated_step
-        output = 0
-        for i, derivative in enumerate(self.state["dY_current"]):
-            if derivative is not None:
-                output += (1 / math.factorial(i)) * derivative * (elapsed**i)
-            else:
-                break
-        return output
-
-    def mark_step_begin(self):
-        self.current_step += 1
-
-    def update(self, Y: torch.Tensor):
-        # Directly call this method will ingnore the warmup
-        # policy and force full computation.
-        # Assume warmup steps is 3, and n_derivatives is 3.
-        # step 0: dY_prev = [None, None, None, None ]
-        #         dY_current = [Y0, None, None, None ]
-        # step 1: dY_prev = [Y0, None, None, None ]
-        #         dY_current = [Y1, dY1, None, None ]
-        # step 2: dY_prev = [Y1, dY1, None, None ]
-        #         dY_current = [Y2, dY2/Y1, dY2/dY1, None ]
-        # step 3: dY_prev = [Y2, dY2/Y1, dY2/dY1, None ],
-        #         dY_current = [Y3, dY3/Y2, dY3/dY2, dY3/dY1]
-        # step 4: dY_prev = [Y3, dY3/Y2, dY3/dY2, dY3/dY1]
-        #         dY_current = [Y4, dY4/Y3, dY4/dY3, dY4/dY2]
-        self.state["dY_prev"] = self.state["dY_current"]
-        self.state["dY_current"] = self.approximate_derivative(Y)
-        self.last_non_approximated_step = self.current_step
-
-    def step(self, Y: torch.Tensor):
-        self.mark_step_begin()
-        if self.should_compute_full():
-            self.update(Y)
-            return Y
-        else:
-            return self.approximate_value()
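The deleted standalone TaylorSeer above stores a value and its finite-difference derivatives on "full" steps and extrapolates the skipped steps with a Taylor series; a calibrator variant remains under cache_contexts/calibrators/taylorseer.py per the file list. A small illustrative driver follows, using the pre-0.3.2 module path and an arbitrary tensor shape.

# Illustrative driver for the removed 0.3.1 TaylorSeer shown above; the tensor
# shape is arbitrary and the import path is the pre-0.3.2 location.
import torch

from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer

seer = TaylorSeer(n_derivatives=2, max_warmup_steps=3, skip_interval_steps=4)

for step in range(12):
    hidden_states = torch.randn(2, 16)  # stand-in for a block's output
    out = seer.step(hidden_states)
    # Steps 0-2 (warmup) and every skip_interval_steps-th step after that
    # return the real value and refresh the derivative state; the others
    # return a Taylor-series extrapolation of the last stored value.
    kind = "full" if seer.should_compute_full() else "approx"
    print(f"step {step:2d}: {kind}, shape={tuple(out.shape)}")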
cache_dit/cache_factory/cache_contexts/v2/__init__.py (deleted in 0.3.2)
@@ -1,13 +0,0 @@
-from cache_dit.cache_factory.cache_contexts.v2.calibrators import (
-    Calibrator,
-    CalibratorBase,
-    CalibratorConfig,
-    TaylorSeerCalibratorConfig,
-    FoCaCalibratorConfig,
-)
-from cache_dit.cache_factory.cache_contexts.v2.cache_context_v2 import (
-    CachedContextV2,
-)
-from cache_dit.cache_factory.cache_contexts.v2.cache_manager_v2 import (
-    CachedContextManagerV2,
-)