cache-dit 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

Files changed (34)
  1. cache_dit/__init__.py +1 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +3 -6
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +21 -64
  5. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +82 -21
  7. cache_dit/cache_factory/cache_blocks/__init__.py +4 -0
  8. cache_dit/cache_factory/cache_blocks/offload_utils.py +115 -0
  9. cache_dit/cache_factory/cache_blocks/pattern_base.py +3 -0
  10. cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
  11. cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
  12. cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
  13. cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
  14. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
  15. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
  16. cache_dit/cache_factory/cache_interface.py +128 -111
  17. cache_dit/cache_factory/params_modifier.py +87 -0
  18. cache_dit/metrics/__init__.py +3 -1
  19. cache_dit/utils.py +12 -21
  20. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/METADATA +200 -434
  21. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/RECORD +27 -31
  22. cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
  23. cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
  24. cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
  25. cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
  26. cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
  27. cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
  28. cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
  29. /cache_dit/cache_factory/cache_blocks/{utils.py → pattern_utils.py} +0 -0
  30. /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
  31. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/WHEEL +0 -0
  32. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/entry_points.txt +0 -0
  33. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/licenses/LICENSE +0 -0
  34. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/top_level.txt +0 -0
@@ -1,36 +1,32 @@
1
- cache_dit/__init__.py,sha256=Nd4a609z8PLFMSO8J0sUe2xRaFDIYK8778ff8yBU7uQ,1457
2
- cache_dit/_version.py,sha256=gGLpQUQx-ty9SEy9PYw9OgJWWzJLBnCpfJOfzL7SjlI,704
1
+ cache_dit/__init__.py,sha256=sHRg0swXZZiw6lvSQ53fcVtN9JRayx0az2lXAz5OOGI,1510
2
+ cache_dit/_version.py,sha256=lemL_4Kl75FgrO6lVuFrrtw6-Dcf9wtXBalKkXuzkO4,704
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
- cache_dit/utils.py,sha256=bERXpCaCpOPThXB8Rkk52yAjjLrvxbt12ntpzpWdfUQ,11131
4
+ cache_dit/utils.py,sha256=AyYRwi5XBxYBH4GaXxOxv9-X24Te_IYOYwh54t_1d3A,10674
5
5
  cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
6
- cache_dit/cache_factory/__init__.py,sha256=Jj_Op6ACV35XilFPax3HEEsf_hOomjmogmNyWWteq_4,1539
7
- cache_dit/cache_factory/cache_interface.py,sha256=xpC-CWZDBfMb5BfnXnVW25xJhV8cYMRns-LKcPDksPU,9846
6
+ cache_dit/cache_factory/__init__.py,sha256=vy9I6Ofkj9jWeUoOvh-cY5a9QlDDKfj2FVPlVTf7BeA,1390
7
+ cache_dit/cache_factory/cache_interface.py,sha256=A_8bBsLfGOE5wM3_rniQKPJ223_-fSpNIq65uv00sF0,10620
8
8
  cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
9
9
  cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
10
+ cache_dit/cache_factory/params_modifier.py,sha256=zYJJsInTYCaYHBZ7mZJOP-PZnkSg3iN1WPewNOayXos,3628
10
11
  cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
11
12
  cache_dit/cache_factory/block_adapters/__init__.py,sha256=33geXMz56TxFWMp0c-H4__MY5SGRzKMKj3TXnUYOMlc,17512
12
- cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=HlmStNIny0rZiRBYw-xdYYViVk9AEt0XlquoacEGr1U,24203
13
+ cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=2TVK_KqiYXC7AKZ2s07fzdOzUoeUBc9P1SzQtLVzhf4,22249
13
14
  cache_dit/cache_factory/block_adapters/block_registers.py,sha256=2L7QeM4ygnaKQpC9PoJod0QRYyxidUKU2AYpysDCUwE,2572
14
- cache_dit/cache_factory/cache_adapters/__init__.py,sha256=qB4bu1m3LgotOeNKluIkbQIf72PXpZWQMaSn1MOFEmY,149
15
- cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=6WArUrTmtkZg147_Qef5jfzMVRg2hfYwvSB9Cvpf_HA,18297
16
- cache_dit/cache_factory/cache_adapters/v2/__init__.py,sha256=9PAH5YwpG_m0feE5eFQ7d2450nQR_Ctq8cd9Xu1Ldtk,96
17
- cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py,sha256=ove_pDh2QC3vjXWIYtrb8anc-NOmPIrDZN7hu16fjwU,18309
18
- cache_dit/cache_factory/cache_blocks/__init__.py,sha256=08Ox7kD05lkRKCOsVTdEZeKAWBheqpxfrAT1Nz7eclI,2916
15
+ cache_dit/cache_factory/cache_adapters/__init__.py,sha256=py71WGD3JztQ1uk6qdLVbzYcQ1rvqFidNNaQYo7tqTo,79
16
+ cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=PuNFO0t9510MhOOJy93cz0uiG8PeWKsjgUWshNj76LQ,20906
17
+ cache_dit/cache_factory/cache_blocks/__init__.py,sha256=mivvm8YOfqT7YHs8y_MzGOGztPw8LxAqKGXuSRXxCv0,3032
18
+ cache_dit/cache_factory/cache_blocks/offload_utils.py,sha256=wusgcqaCrwEjvv7Guy-6VXhNOgPPUrBV2sSVuRmGuvo,3513
19
19
  cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
20
20
  cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=Bv56qETXhsREvCrNvnZpSqDIIHsi6Ze3FJW4Yk2x3uI,8597
21
- cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=d4H9kEB0AgnVMT8aF0Y54SUMUQUxw5HQ8gRkoCuTQ_A,14577
22
- cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
23
- cache_dit/cache_factory/cache_contexts/__init__.py,sha256=MQRxis-5gMhdJ6ZXIVN2nZEGPZoRLy59gSLniTYrWGY,437
24
- cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=FWdgInClWY8VZBsZIevtYk--rX-RL8c3QfNOJtqR8a4,11855
25
- cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=Ig5VKoQ46iG3lKmsaMulYxd2vCm__2rY8NBvERwexwM,32719
26
- cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=4nxgSEZvDy-w-7XuJYzsyzdtF1_uFrDwlF06XBDFVKQ,3922
27
- cache_dit/cache_factory/cache_contexts/v2/__init__.py,sha256=GVafOd9BUa-Tyv7FZbTSkd4bGJPpMonb1AZv78qLeHU,385
28
- cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py,sha256=JkMJSm-zme9ayonSFq6Y6esCb6RMuGLvhVINM-LFj2Y,11776
29
- cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py,sha256=ZRTl0M7jIPTIBS9lXoSh_pY6-hNu3JJ94WShv2CPWkk,32788
30
- cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py,sha256=BLCV0EtOcu30iytErL_IK6J9ZwmpE6P9ffNt4OL-IaU,2343
31
- cache_dit/cache_factory/cache_contexts/v2/calibrators/base.py,sha256=mn6ZBkChGpGwN5csrHTUGMoX6BBPvqHXSLbIExiW-EU,748
32
- cache_dit/cache_factory/cache_contexts/v2/calibrators/foca.py,sha256=jrEkoiLgDR2fiX_scIpaLIDT0pTMc9stg6L9HBkgsZw,894
33
- cache_dit/cache_factory/cache_contexts/v2/calibrators/taylorseer.py,sha256=q5xBmT4EmpF_b3KPAjMIangTBvovE_c8ZfFjIN_E9tg,3834
21
+ cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=wdh0bbcpKO08AW2FTsj9X_tTbFCLkDmBjrstMxTf7MQ,14668
22
+ cache_dit/cache_factory/cache_blocks/pattern_utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
23
+ cache_dit/cache_factory/cache_contexts/__init__.py,sha256=T6Vak3x7Rs0Oy15Tou49p-rPQRA2jiuYtJBsbv1lBBU,388
24
+ cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=3EhaMCz3VUQ_NF81VgYwWoSEGIvhScPxPYhjL1OcgxE,15240
25
+ cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=hSKAeP1CxmO3RFUxjFjAK1xdvVvTmeayh5jEHMaQXNE,30225
26
+ cache_dit/cache_factory/cache_contexts/calibrators/__init__.py,sha256=mzYXO8tbytGpJJ9rpPu20kMoj1Iu_7Ym9tjfzV8rA98,5574
27
+ cache_dit/cache_factory/cache_contexts/calibrators/base.py,sha256=mn6ZBkChGpGwN5csrHTUGMoX6BBPvqHXSLbIExiW-EU,748
28
+ cache_dit/cache_factory/cache_contexts/calibrators/foca.py,sha256=nhHGs_hxwW1M942BQDMJb9-9IuHdnOxp774Jrna1bJI,891
29
+ cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py,sha256=aGxr9SpytYznTepDWGPAxWDnuVMSuNyn6uNXnLh2acQ,4001
34
30
  cache_dit/cache_factory/patch_functors/__init__.py,sha256=oI6F3N9ezahRHaFUOZ1GfrAw1qFdKrxFXXmlwwehHj4,530
35
31
  cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
36
32
  cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=xD0Q96VArp1vYBLQ0pcjRIyFB1i_Y7muZ2q07Hz8Oqs,13430
@@ -42,7 +38,7 @@ cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0
42
38
  cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
43
39
  cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
40
  cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
45
- cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
41
+ cache_dit/metrics/__init__.py,sha256=UjPJ69DyyjZDfERTpKAjZKOxOTx58aWnkze7VfH3en8,673
46
42
  cache_dit/metrics/clip_score.py,sha256=ERNCFQFJKzJdbIX9OAg-1LiSPuXUVHLOFxbf2gcENpc,3938
47
43
  cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
48
44
  cache_dit/metrics/fid.py,sha256=ZM_FM0XERtpnkMUfphmw2aOdljrh1uba-pnYItu0q6M,18219
@@ -53,9 +49,9 @@ cache_dit/metrics/metrics.py,sha256=7UV-H2NRbhfr6dvrXEzU97Zy-BSQ5zEfm9CKtaK4ldg,
53
49
  cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
54
50
  cache_dit/quantize/quantize_ao.py,sha256=Fx1KW4l3gdEkdrcAYtPoDW7WKBJWrs3glOHiEwW_TgE,6160
55
51
  cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
56
- cache_dit-0.3.1.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
57
- cache_dit-0.3.1.dist-info/METADATA,sha256=I3gHe9m40_Ja0VurS7CDBYx_x_4rpra8zN245gBKv-A,46536
58
- cache_dit-0.3.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
59
- cache_dit-0.3.1.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
60
- cache_dit-0.3.1.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
61
- cache_dit-0.3.1.dist-info/RECORD,,
52
+ cache_dit-0.3.3.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
53
+ cache_dit-0.3.3.dist-info/METADATA,sha256=2kUqLHOXsbb25iz6uO8Y3pzOVMSaRHs-st6o3imjX_o,34752
54
+ cache_dit-0.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
55
+ cache_dit-0.3.3.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
56
+ cache_dit-0.3.3.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
57
+ cache_dit-0.3.3.dist-info/RECORD,,
@@ -1,3 +0,0 @@
1
- from cache_dit.cache_factory.cache_adapters.v2.cache_adapter_v2 import (
2
- CachedAdapterV2,
3
- )
@@ -1,524 +0,0 @@
1
- import torch
2
-
3
- import unittest
4
- import functools
5
-
6
- from contextlib import ExitStack
7
- from typing import Dict, List, Tuple, Any, Union, Callable
8
-
9
- from diffusers import DiffusionPipeline
10
-
11
- from cache_dit.cache_factory.cache_types import CacheType
12
- from cache_dit.cache_factory.block_adapters import BlockAdapter
13
- from cache_dit.cache_factory.block_adapters import ParamsModifier
14
- from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
15
- from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
16
- from cache_dit.cache_factory.cache_blocks import CachedBlocks
17
- from cache_dit.cache_factory.cache_blocks.utils import (
18
- patch_cached_stats,
19
- remove_cached_stats,
20
- )
21
- from cache_dit.logger import init_logger
22
-
23
- logger = init_logger(__name__)
24
-
25
-
26
- # Unified Cached Adapter
27
- class CachedAdapterV2:
28
-
29
- def __call__(self, *args, **kwargs):
30
- return self.apply(*args, **kwargs)
31
-
32
- @classmethod
33
- def apply(
34
- cls,
35
- pipe_or_adapter: Union[
36
- DiffusionPipeline,
37
- BlockAdapter,
38
- ],
39
- **cache_context_kwargs,
40
- ) -> Union[
41
- DiffusionPipeline,
42
- BlockAdapter,
43
- ]:
44
- assert (
45
- pipe_or_adapter is not None
46
- ), "pipe or block_adapter can not both None!"
47
-
48
- if isinstance(pipe_or_adapter, DiffusionPipeline):
49
- if BlockAdapterRegistry.is_supported(pipe_or_adapter):
50
- logger.info(
51
- f"{pipe_or_adapter.__class__.__name__} is officially "
52
- "supported by cache-dit. Use it's pre-defined BlockAdapter "
53
- "directly!"
54
- )
55
- block_adapter = BlockAdapterRegistry.get_adapter(
56
- pipe_or_adapter
57
- )
58
- return cls.cachify(
59
- block_adapter,
60
- **cache_context_kwargs,
61
- ).pipe
62
- else:
63
- raise ValueError(
64
- f"{pipe_or_adapter.__class__.__name__} is not officially supported "
65
- "by cache-dit, please set BlockAdapter instead!"
66
- )
67
- else:
68
- assert isinstance(pipe_or_adapter, BlockAdapter)
69
- logger.info(
70
- "Adapting Cache Acceleration using custom BlockAdapter!"
71
- )
72
- return cls.cachify(
73
- pipe_or_adapter,
74
- **cache_context_kwargs,
75
- )
76
-
77
- @classmethod
78
- def cachify(
79
- cls,
80
- block_adapter: BlockAdapter,
81
- **cache_context_kwargs,
82
- ) -> BlockAdapter:
83
-
84
- if block_adapter.auto:
85
- block_adapter = BlockAdapter.auto_block_adapter(
86
- block_adapter,
87
- )
88
-
89
- if BlockAdapter.check_block_adapter(block_adapter):
90
-
91
- # 0. Must normalize block_adapter before apply cache
92
- block_adapter = BlockAdapter.normalize(block_adapter)
93
- if BlockAdapter.is_cached(block_adapter):
94
- return block_adapter
95
-
96
- # 1. Apply cache on pipeline: wrap cache context, must
97
- # call create_context before mock_blocks.
98
- cls.create_context(
99
- block_adapter,
100
- **cache_context_kwargs,
101
- )
102
-
103
- # 2. Apply cache on transformer: mock cached blocks
104
- cls.mock_blocks(
105
- block_adapter,
106
- )
107
-
108
- return block_adapter
109
-
110
- @classmethod
111
- def check_context_kwargs(
112
- cls,
113
- block_adapter: BlockAdapter,
114
- **cache_context_kwargs,
115
- ):
116
- # Check cache_context_kwargs
117
- if cache_context_kwargs["enable_separate_cfg"] is None:
118
- # Check cfg for some specific case if users don't set it as True
119
- if BlockAdapterRegistry.has_separate_cfg(block_adapter):
120
- cache_context_kwargs["enable_separate_cfg"] = True
121
- logger.info(
122
- f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
123
- f"Pipeline: {block_adapter.pipe.__class__.__name__}."
124
- )
125
- else:
126
- cache_context_kwargs["enable_separate_cfg"] = (
127
- BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
128
- )
129
- logger.info(
130
- f"Use default 'enable_separate_cfg' from block adapter "
131
- f"register: {cache_context_kwargs['enable_separate_cfg']}, "
132
- f"Pipeline: {block_adapter.pipe.__class__.__name__}."
133
- )
134
- else:
135
- logger.info(
136
- f"Use custom 'enable_separate_cfg' from cache context "
137
- f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
138
- f"Pipeline: {block_adapter.pipe.__class__.__name__}."
139
- )
140
-
141
- if (
142
- cache_type := cache_context_kwargs.pop("cache_type", None)
143
- ) is not None:
144
- assert (
145
- cache_type == CacheType.DBCache
146
- ), "Custom cache setting only support for DBCache now!"
147
-
148
- return cache_context_kwargs
149
-
150
- @classmethod
151
- def create_context(
152
- cls,
153
- block_adapter: BlockAdapter,
154
- **cache_context_kwargs,
155
- ) -> DiffusionPipeline:
156
-
157
- BlockAdapter.assert_normalized(block_adapter)
158
-
159
- if BlockAdapter.is_cached(block_adapter.pipe):
160
- return block_adapter.pipe
161
-
162
- # Check cache_context_kwargs
163
- cache_context_kwargs = cls.check_context_kwargs(
164
- block_adapter, **cache_context_kwargs
165
- )
166
- # Apply cache on pipeline: wrap cache context
167
- pipe_cls_name = block_adapter.pipe.__class__.__name__
168
-
169
- # Each Pipeline should have it's own context manager instance.
170
- # Different transformers (Wan2.2, etc) should shared the same
171
- # cache manager but with different cache context (according
172
- # to their unique instance id).
173
- cache_manager = CachedContextManagerV2(
174
- name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
175
- )
176
- block_adapter.pipe._cache_manager = cache_manager # instance level
177
-
178
- flatten_contexts, contexts_kwargs = cls.modify_context_params(
179
- block_adapter, cache_manager, **cache_context_kwargs
180
- )
181
-
182
- original_call = block_adapter.pipe.__class__.__call__
183
-
184
- @functools.wraps(original_call)
185
- def new_call(self, *args, **kwargs):
186
- with ExitStack() as stack:
187
- # cache context will be reset for each pipe inference
188
- for context_name, context_kwargs in zip(
189
- flatten_contexts, contexts_kwargs
190
- ):
191
- stack.enter_context(
192
- cache_manager.enter_context(
193
- cache_manager.reset_context(
194
- context_name,
195
- **context_kwargs,
196
- ),
197
- )
198
- )
199
- outputs = original_call(self, *args, **kwargs)
200
- cls.apply_stats_hooks(block_adapter)
201
- return outputs
202
-
203
- block_adapter.pipe.__class__.__call__ = new_call
204
- block_adapter.pipe.__class__._original_call = original_call
205
- block_adapter.pipe.__class__._is_cached = True
206
-
207
- cls.apply_params_hooks(block_adapter, contexts_kwargs)
208
-
209
- return block_adapter.pipe
210
-
211
- @classmethod
212
- def modify_context_params(
213
- cls,
214
- block_adapter: BlockAdapter,
215
- cache_manager: CachedContextManagerV2,
216
- **cache_context_kwargs,
217
- ) -> Tuple[List[str], List[Dict[str, Any]]]:
218
-
219
- flatten_contexts = BlockAdapter.flatten(
220
- block_adapter.unique_blocks_name
221
- )
222
- contexts_kwargs = [
223
- cache_context_kwargs.copy()
224
- for _ in range(
225
- len(flatten_contexts),
226
- )
227
- ]
228
-
229
- for i in range(len(contexts_kwargs)):
230
- contexts_kwargs[i]["name"] = flatten_contexts[i]
231
-
232
- if block_adapter.params_modifiers is None:
233
- return flatten_contexts, contexts_kwargs
234
-
235
- flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
236
- block_adapter.params_modifiers,
237
- )
238
-
239
- for i in range(
240
- min(len(contexts_kwargs), len(flatten_modifiers)),
241
- ):
242
- contexts_kwargs[i].update(
243
- flatten_modifiers[i]._context_kwargs,
244
- )
245
- contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
246
- default_attrs={}, **contexts_kwargs[i]
247
- )
248
-
249
- return flatten_contexts, contexts_kwargs
250
-
251
- @classmethod
252
- def mock_blocks(
253
- cls,
254
- block_adapter: BlockAdapter,
255
- ) -> List[torch.nn.Module]:
256
-
257
- BlockAdapter.assert_normalized(block_adapter)
258
-
259
- if BlockAdapter.is_cached(block_adapter.transformer):
260
- return block_adapter.transformer
261
-
262
- # Apply cache on transformer: mock cached transformer blocks
263
- for (
264
- cached_blocks,
265
- transformer,
266
- blocks_name,
267
- unique_blocks_name,
268
- dummy_blocks_names,
269
- ) in zip(
270
- cls.collect_cached_blocks(block_adapter),
271
- block_adapter.transformer,
272
- block_adapter.blocks_name,
273
- block_adapter.unique_blocks_name,
274
- block_adapter.dummy_blocks_names,
275
- ):
276
- cls.mock_transformer(
277
- cached_blocks,
278
- transformer,
279
- blocks_name,
280
- unique_blocks_name,
281
- dummy_blocks_names,
282
- )
283
-
284
- return block_adapter.transformer
285
-
286
- @classmethod
287
- def mock_transformer(
288
- cls,
289
- cached_blocks: Dict[str, torch.nn.ModuleList],
290
- transformer: torch.nn.Module,
291
- blocks_name: List[str],
292
- unique_blocks_name: List[str],
293
- dummy_blocks_names: List[str],
294
- ) -> torch.nn.Module:
295
- dummy_blocks = torch.nn.ModuleList()
296
-
297
- original_forward = transformer.forward
298
-
299
- assert isinstance(dummy_blocks_names, list)
300
-
301
- @functools.wraps(original_forward)
302
- def new_forward(self, *args, **kwargs):
303
- with ExitStack() as stack:
304
- for name, context_name in zip(
305
- blocks_name,
306
- unique_blocks_name,
307
- ):
308
- stack.enter_context(
309
- unittest.mock.patch.object(
310
- self, name, cached_blocks[context_name]
311
- )
312
- )
313
- for dummy_name in dummy_blocks_names:
314
- stack.enter_context(
315
- unittest.mock.patch.object(
316
- self, dummy_name, dummy_blocks
317
- )
318
- )
319
- return original_forward(*args, **kwargs)
320
-
321
- transformer.forward = new_forward.__get__(transformer)
322
- transformer._original_forward = original_forward
323
- transformer._is_cached = True
324
-
325
- return transformer
326
-
327
- @classmethod
328
- def collect_cached_blocks(
329
- cls,
330
- block_adapter: BlockAdapter,
331
- ) -> List[Dict[str, torch.nn.ModuleList]]:
332
-
333
- BlockAdapter.assert_normalized(block_adapter)
334
-
335
- total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
336
- assert hasattr(block_adapter.pipe, "_cache_manager")
337
- assert isinstance(
338
- block_adapter.pipe._cache_manager, CachedContextManagerV2
339
- )
340
-
341
- for i in range(len(block_adapter.transformer)):
342
-
343
- cached_blocks_bind_context = {}
344
- for j in range(len(block_adapter.blocks[i])):
345
- cached_blocks_bind_context[
346
- block_adapter.unique_blocks_name[i][j]
347
- ] = torch.nn.ModuleList(
348
- [
349
- CachedBlocks(
350
- # 0. Transformer blocks configuration
351
- block_adapter.blocks[i][j],
352
- transformer=block_adapter.transformer[i],
353
- forward_pattern=block_adapter.forward_pattern[i][j],
354
- check_forward_pattern=block_adapter.check_forward_pattern,
355
- check_num_outputs=block_adapter.check_num_outputs,
356
- # 1. Cache context configuration
357
- cache_prefix=block_adapter.blocks_name[i][j],
358
- cache_context=block_adapter.unique_blocks_name[i][
359
- j
360
- ],
361
- cache_manager=block_adapter.pipe._cache_manager,
362
- )
363
- ]
364
- )
365
-
366
- total_cached_blocks.append(cached_blocks_bind_context)
367
-
368
- return total_cached_blocks
369
-
370
- @classmethod
371
- def apply_params_hooks(
372
- cls,
373
- block_adapter: BlockAdapter,
374
- contexts_kwargs: List[Dict],
375
- ):
376
- block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
377
-
378
- params_shift = 0
379
- for i in range(len(block_adapter.transformer)):
380
-
381
- block_adapter.transformer[i]._forward_pattern = (
382
- block_adapter.forward_pattern
383
- )
384
- block_adapter.transformer[i]._has_separate_cfg = (
385
- block_adapter.has_separate_cfg
386
- )
387
- block_adapter.transformer[i]._cache_context_kwargs = (
388
- contexts_kwargs[params_shift]
389
- )
390
-
391
- blocks = block_adapter.blocks[i]
392
- for j in range(len(blocks)):
393
- blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
394
- blocks[j]._cache_context_kwargs = contexts_kwargs[
395
- params_shift + j
396
- ]
397
-
398
- params_shift += len(blocks)
399
-
400
- @classmethod
401
- def apply_stats_hooks(
402
- cls,
403
- block_adapter: BlockAdapter,
404
- ):
405
- cache_manager = block_adapter.pipe._cache_manager
406
-
407
- for i in range(len(block_adapter.transformer)):
408
- patch_cached_stats(
409
- block_adapter.transformer[i],
410
- cache_context=block_adapter.unique_blocks_name[i][-1],
411
- cache_manager=cache_manager,
412
- )
413
- for blocks, unique_name in zip(
414
- block_adapter.blocks[i],
415
- block_adapter.unique_blocks_name[i],
416
- ):
417
- patch_cached_stats(
418
- blocks,
419
- cache_context=unique_name,
420
- cache_manager=cache_manager,
421
- )
422
-
423
- @classmethod
424
- def maybe_release_hooks(
425
- cls,
426
- pipe_or_adapter: Union[
427
- DiffusionPipeline,
428
- BlockAdapter,
429
- ],
430
- ):
431
- # release model hooks
432
- def _release_blocks_hooks(blocks):
433
- return
434
-
435
- def _release_transformer_hooks(transformer):
436
- if hasattr(transformer, "_original_forward"):
437
- original_forward = transformer._original_forward
438
- transformer.forward = original_forward.__get__(transformer)
439
- del transformer._original_forward
440
- if hasattr(transformer, "_is_cached"):
441
- del transformer._is_cached
442
-
443
- def _release_pipeline_hooks(pipe):
444
- if hasattr(pipe, "_original_call"):
445
- original_call = pipe.__class__._original_call
446
- pipe.__class__.__call__ = original_call
447
- del pipe.__class__._original_call
448
- if hasattr(pipe, "_cache_manager"):
449
- cache_manager = pipe._cache_manager
450
- if isinstance(cache_manager, CachedContextManagerV2):
451
- cache_manager.clear_contexts()
452
- del pipe._cache_manager
453
- if hasattr(pipe, "_is_cached"):
454
- del pipe.__class__._is_cached
455
-
456
- cls.release_hooks(
457
- pipe_or_adapter,
458
- _release_blocks_hooks,
459
- _release_transformer_hooks,
460
- _release_pipeline_hooks,
461
- )
462
-
463
- # release params hooks
464
- def _release_blocks_params(blocks):
465
- if hasattr(blocks, "_forward_pattern"):
466
- del blocks._forward_pattern
467
- if hasattr(blocks, "_cache_context_kwargs"):
468
- del blocks._cache_context_kwargs
469
-
470
- def _release_transformer_params(transformer):
471
- if hasattr(transformer, "_forward_pattern"):
472
- del transformer._forward_pattern
473
- if hasattr(transformer, "_has_separate_cfg"):
474
- del transformer._has_separate_cfg
475
- if hasattr(transformer, "_cache_context_kwargs"):
476
- del transformer._cache_context_kwargs
477
- for blocks in BlockAdapter.find_blocks(transformer):
478
- _release_blocks_params(blocks)
479
-
480
- def _release_pipeline_params(pipe):
481
- if hasattr(pipe, "_cache_context_kwargs"):
482
- del pipe._cache_context_kwargs
483
-
484
- cls.release_hooks(
485
- pipe_or_adapter,
486
- _release_blocks_params,
487
- _release_transformer_params,
488
- _release_pipeline_params,
489
- )
490
-
491
- # release stats hooks
492
- cls.release_hooks(
493
- pipe_or_adapter,
494
- remove_cached_stats,
495
- remove_cached_stats,
496
- remove_cached_stats,
497
- )
498
-
499
- @classmethod
500
- def release_hooks(
501
- cls,
502
- pipe_or_adapter: Union[
503
- DiffusionPipeline,
504
- BlockAdapter,
505
- ],
506
- _release_blocks: Callable,
507
- _release_transformer: Callable,
508
- _release_pipeline: Callable,
509
- ):
510
- if isinstance(pipe_or_adapter, DiffusionPipeline):
511
- pipe = pipe_or_adapter
512
- _release_pipeline(pipe)
513
- if hasattr(pipe, "transformer"):
514
- _release_transformer(pipe.transformer)
515
- if hasattr(pipe, "transformer_2"): # Wan 2.2
516
- _release_transformer(pipe.transformer_2)
517
- elif isinstance(pipe_or_adapter, BlockAdapter):
518
- adapter = pipe_or_adapter
519
- BlockAdapter.assert_normalized(adapter)
520
- _release_pipeline(adapter.pipe)
521
- for transformer in BlockAdapter.flatten(adapter.transformer):
522
- _release_transformer(transformer)
523
- for blocks in BlockAdapter.flatten(adapter.blocks):
524
- _release_blocks(blocks)