cache-dit 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Note: this version of cache-dit has been flagged as a potentially problematic release.
- cache_dit/__init__.py +9 -4
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +16 -3
- cache_dit/cache_factory/block_adapters/__init__.py +538 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +121 -563
- cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
- cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
- cache_dit/cache_factory/cache_blocks/utils.py +23 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
- cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
- cache_dit/cache_factory/cache_interface.py +24 -16
- cache_dit/cache_factory/forward_pattern.py +45 -24
- cache_dit/cache_factory/patch_functors/__init__.py +5 -0
- cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
- cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
- cache_dit/quantize/quantize_ao.py +19 -4
- cache_dit/quantize/quantize_interface.py +2 -2
- cache_dit/utils.py +19 -15
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
- cache_dit-0.2.27.dist-info/RECORD +47 -0
- cache_dit-0.2.25.dist-info/RECORD +0 -36
- /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/block_adapters/block_adapters.py (new file)
@@ -0,0 +1,333 @@
+import torch
+
+import inspect
+import dataclasses
+
+from typing import Any, Tuple, List, Optional
+
+from diffusers import DiffusionPipeline
+from cache_dit.cache_factory.forward_pattern import ForwardPattern
+from cache_dit.cache_factory.patch_functors import PatchFunctor
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class BlockAdapter:
+    # Transformer configurations.
+    pipe: DiffusionPipeline | Any = None
+    transformer: torch.nn.Module = None
+
+    # ------------ Block Level Flags ------------
+    blocks: torch.nn.ModuleList | List[torch.nn.ModuleList] = None
+    # transformer_blocks, blocks, etc.
+    blocks_name: str | List[str] = None
+    dummy_blocks_names: List[str] = dataclasses.field(default_factory=list)
+    forward_pattern: ForwardPattern | List[ForwardPattern] = None
+    check_num_outputs: bool = True
+
+    # Flags to control auto block adapter
+    auto: bool = False
+    allow_prefixes: List[str] = dataclasses.field(
+        default_factory=lambda: [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+            "single_stream_blocks",
+            "double_stream_blocks",
+        ]
+    )
+    check_prefixes: bool = True
+    allow_suffixes: List[str] = dataclasses.field(
+        default_factory=lambda: ["TransformerBlock"]
+    )
+    check_suffixes: bool = False
+    blocks_policy: str = dataclasses.field(
+        default="max", metadata={"allowed_values": ["max", "min"]}
+    )
+
+    # NOTE: Other flags.
+    disable_patch: bool = False
+
+    # ------------ Pipeline Level Flags ------------
+    # Patch Functor: Flux, etc.
+    patch_functor: Optional[PatchFunctor] = None
+    # Flags for separate cfg
+    has_separate_cfg: bool = False
+
+    def __post_init__(self):
+        assert any((self.pipe is not None, self.transformer is not None))
+        self.patchify()
+
+    def patchify(self, *args, **kwargs):
+        # Process some special cases, specifically for transformers
+        # that have different forward patterns between single_transformer_blocks
+        # and transformer_blocks, such as Flux (diffusers < 0.35.0).
+        if self.patch_functor is not None and not self.disable_patch:
+            if self.transformer is not None:
+                self.patch_functor.apply(self.transformer, *args, **kwargs)
+            else:
+                assert hasattr(self.pipe, "transformer")
+                self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
+
+    @staticmethod
+    def auto_block_adapter(
+        adapter: "BlockAdapter",
+    ) -> "BlockAdapter":
+        assert adapter.auto, (
+            "Please manually set `auto` to True, or, manually "
+            "set all the transformer blocks configuration."
+        )
+        assert adapter.pipe is not None, "adapter.pipe can not be None."
+        assert (
+            adapter.forward_pattern is not None
+        ), "adapter.forward_pattern can not be None."
+        pipe = adapter.pipe
+
+        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
+
+        transformer = pipe.transformer
+
+        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
+        blocks, blocks_name = BlockAdapter.find_blocks(
+            transformer=transformer,
+            allow_prefixes=adapter.allow_prefixes,
+            allow_suffixes=adapter.allow_suffixes,
+            check_prefixes=adapter.check_prefixes,
+            check_suffixes=adapter.check_suffixes,
+            blocks_policy=adapter.blocks_policy,
+            forward_pattern=adapter.forward_pattern,
+            check_num_outputs=adapter.check_num_outputs,
+        )
+
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=transformer,
+            blocks=blocks,
+            blocks_name=blocks_name,
+            forward_pattern=adapter.forward_pattern,
+        )
+
+    @staticmethod
+    def check_block_adapter(
+        adapter: "BlockAdapter",
+    ) -> bool:
+        def _check_warning(attr: str):
+            if getattr(adapter, attr, None) is None:
+                logger.warning(f"{attr} is None!")
+                return False
+            return True
+
+        if not _check_warning("pipe"):
+            return False
+
+        if not _check_warning("transformer"):
+            return False
+
+        if not _check_warning("blocks"):
+            return False
+
+        if not _check_warning("blocks_name"):
+            return False
+
+        if not _check_warning("forward_pattern"):
+            return False
+
+        if isinstance(adapter.blocks, list):
+            for i, blocks in enumerate(adapter.blocks):
+                if not isinstance(blocks, torch.nn.ModuleList):
+                    logger.warning(f"blocks[{i}] is not ModuleList.")
+                    return False
+        else:
+            if not isinstance(adapter.blocks, torch.nn.ModuleList):
+                logger.warning("blocks is not ModuleList.")
+                return False
+
+        return True
+
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+        allow_prefixes: List[str] = [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+            "single_stream_blocks",
+            "double_stream_blocks",
+        ],
+        allow_suffixes: List[str] = [
+            "TransformerBlock",
+        ],
+        check_prefixes: bool = True,
+        check_suffixes: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.nn.ModuleList, str]:
+        # Check prefixes
+        if check_prefixes:
+            blocks_names = []
+            for attr_name in dir(transformer):
+                for prefix in allow_prefixes:
+                    if attr_name.startswith(prefix):
+                        blocks_names.append(attr_name)
+        else:
+            blocks_names = dir(transformer)
+
+        # Check ModuleList
+        valid_names = []
+        valid_count = []
+        forward_pattern = kwargs.pop("forward_pattern", None)
+        for blocks_name in blocks_names:
+            if blocks := getattr(transformer, blocks_name, None):
+                if isinstance(blocks, torch.nn.ModuleList):
+                    block = blocks[0]
+                    block_cls_name = block.__class__.__name__
+                    # Check suffixes
+                    if isinstance(block, torch.nn.Module) and (
+                        any(
+                            (
+                                block_cls_name.endswith(allow_suffix)
+                                for allow_suffix in allow_suffixes
+                            )
+                        )
+                        or (not check_suffixes)
+                    ):
+                        # May check forward pattern
+                        if forward_pattern is not None:
+                            if BlockAdapter.match_blocks_pattern(
+                                blocks,
+                                forward_pattern,
+                                logging=False,
+                                **kwargs,
+                            ):
+                                valid_names.append(blocks_name)
+                                valid_count.append(len(blocks))
+                        else:
+                            valid_names.append(blocks_name)
+                            valid_count.append(len(blocks))
+
+        if not valid_names:
+            raise ValueError(
+                "Auto selected transformer blocks failed, please set it manually."
+            )
+
+        final_name = valid_names[0]
+        final_count = valid_count[0]
+        block_policy = kwargs.get("blocks_policy", "max")
+
+        for blocks_name, count in zip(valid_names, valid_count):
+            blocks = getattr(transformer, blocks_name)
+            logger.info(
+                f"Auto selected transformer blocks: {blocks_name}, "
+                f"class: {blocks[0].__class__.__name__}, "
+                f"num blocks: {count}"
+            )
+            if block_policy == "max":
+                if final_count < count:
+                    final_count = count
+                    final_name = blocks_name
+            else:
+                if final_count > count:
+                    final_count = count
+                    final_name = blocks_name
+
+        final_blocks = getattr(transformer, final_name)
+
+        logger.info(
+            f"Final selected transformer blocks: {final_name}, "
+            f"class: {final_blocks[0].__class__.__name__}, "
+            f"num blocks: {final_count}, block_policy: {block_policy}."
+        )
+
+        return final_blocks, final_name
+
+    @staticmethod
+    def match_block_pattern(
+        block: torch.nn.Module,
+        forward_pattern: ForwardPattern,
+        **kwargs,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not supported now!"
+
+        forward_parameters = set(
+            inspect.signature(block.forward).parameters.keys()
+        )
+
+        in_matched = True
+        out_matched = True
+
+        if kwargs.get("check_num_outputs", True):
+            num_outputs = str(
+                inspect.signature(block.forward).return_annotation
+            ).count("torch.Tensor")
+
+            if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
+                # output pattern not match
+                out_matched = False
+
+        for required_param in forward_pattern.In:
+            if required_param not in forward_parameters:
+                in_matched = False
+
+        return in_matched and out_matched
+
+    @staticmethod
+    def match_blocks_pattern(
+        transformer_blocks: torch.nn.ModuleList,
+        forward_pattern: ForwardPattern,
+        logging: bool = True,
+        **kwargs,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not supported now!"
+
+        assert isinstance(transformer_blocks, torch.nn.ModuleList)
+
+        pattern_matched_states = []
+        for block in transformer_blocks:
+            pattern_matched_states.append(
+                BlockAdapter.match_block_pattern(
+                    block,
+                    forward_pattern,
+                    **kwargs,
+                )
+            )
+
+        pattern_matched = all(pattern_matched_states)  # all block match
+        if pattern_matched and logging:
+            block_cls_names = [
+                block.__class__.__name__ for block in transformer_blocks
+            ]
+            block_cls_names = list(set(block_cls_names))
+            if len(block_cls_names) == 1:
+                block_cls_names = block_cls_names[0]
+            logger.info(
+                f"Match Block Forward Pattern: {block_cls_names}, {forward_pattern}"
+                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
+            )
+
+        return pattern_matched
+
+    @staticmethod
+    def normalize(
+        adapter: "BlockAdapter",
+    ) -> "BlockAdapter":
+        if not isinstance(adapter.blocks, list):
+            adapter.blocks = [adapter.blocks]
+        if not isinstance(adapter.blocks_name, list):
+            adapter.blocks_name = [adapter.blocks_name]
+        if not isinstance(adapter.forward_pattern, list):
+            adapter.forward_pattern = [adapter.forward_pattern]
+
+        assert len(adapter.blocks) == len(adapter.blocks_name)
+        assert len(adapter.blocks) == len(adapter.forward_pattern)
+
+        return adapter
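For orientation, the sketch below shows how the new BlockAdapter might be wired up for a diffusers pipeline, either manually or through the auto-discovery path (auto_block_adapter / find_blocks). It is illustrative only and not part of the diff: the FLUX checkpoint id and the ForwardPattern.Pattern_1 member name are assumptions (the pattern names are inferred from the new cache_blocks/pattern_0_1_2.py and pattern_3_4_5.py modules in this release).

# Illustrative usage sketch (not part of this release's diff).
# Assumptions: the checkpoint id and ForwardPattern.Pattern_1 are hypothetical.
import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
from cache_dit.cache_factory.forward_pattern import ForwardPattern

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint
    torch_dtype=torch.bfloat16,
)

# Manual configuration: point the adapter at a concrete ModuleList.
adapter = BlockAdapter(
    pipe=pipe,
    transformer=pipe.transformer,
    blocks=pipe.transformer.transformer_blocks,
    blocks_name="transformer_blocks",
    forward_pattern=ForwardPattern.Pattern_1,  # assumed member name
)

# Or let find_blocks() pick the largest matching ModuleList automatically.
auto_adapter = BlockAdapter.auto_block_adapter(
    BlockAdapter(pipe=pipe, auto=True, forward_pattern=ForwardPattern.Pattern_1)
)

# normalize() wraps blocks/blocks_name/forward_pattern into parallel lists.
adapter = BlockAdapter.normalize(adapter)
assert BlockAdapter.check_block_adapter(adapter)

The prefix/suffix filters and blocks_policy used by auto discovery mirror the defaults declared on the dataclass, so the search can be tuned per pipeline without subclassing.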
cache_dit/cache_factory/block_adapters/block_registers.py (new file)
@@ -0,0 +1,77 @@
+from typing import Any, Tuple, List, Dict
+
+from diffusers import DiffusionPipeline
+from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class BlockAdapterRegistry:
+    _adapters: Dict[str, BlockAdapter] = {}
+    _predefined_adapters_has_spearate_cfg: List[str] = {
+        "QwenImage",
+        "Wan",
+        "CogView4",
+        "Cosmos",
+        "SkyReelsV2",
+        "Chroma",
+    }
+
+    @classmethod
+    def register(cls, name):
+        def decorator(func):
+            cls._adapters[name] = func
+            return func
+
+        return decorator
+
+    @classmethod
+    def get_adapter(
+        cls,
+        pipe: DiffusionPipeline | str | Any,
+        **kwargs,
+    ) -> BlockAdapter:
+        if not isinstance(pipe, str):
+            pipe_cls_name: str = pipe.__class__.__name__
+        else:
+            pipe_cls_name = pipe
+
+        for name in cls._adapters:
+            if pipe_cls_name.startswith(name):
+                return cls._adapters[name](pipe, **kwargs)
+
+        return BlockAdapter()
+
+    @classmethod
+    def has_separate_cfg(
+        cls,
+        pipe: DiffusionPipeline | str | Any,
+    ) -> bool:
+        if cls.get_adapter(
+            pipe,
+            disable_patch=True,
+        ).has_separate_cfg:
+            return True
+
+        pipe_cls_name = pipe.__class__.__name__
+        for name in cls._predefined_adapters_has_spearate_cfg:
+            if pipe_cls_name.startswith(name):
+                return True
+
+        return False
+
+    @classmethod
+    def is_supported(cls, pipe) -> bool:
+        pipe_cls_name: str = pipe.__class__.__name__
+
+        for name in cls._adapters:
+            if pipe_cls_name.startswith(name):
+                return True
+        return False
+
+    @classmethod
+    def supported_pipelines(cls, **kwargs) -> Tuple[int, List[str]]:
+        val_pipelines = cls._adapters.keys()
+        return len(val_pipelines), [p + "*" for p in val_pipelines]
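To illustrate the registry's intended flow, here is a minimal sketch of registering a custom adapter factory and then resolving it by pipeline class name. The "MyVideo" prefix, the myvideo_adapter factory, and the ForwardPattern.Pattern_0 member are hypothetical and not part of this diff.

# Illustrative registration sketch (not part of this release's diff).
# Assumptions: the "MyVideo" prefix and ForwardPattern.Pattern_0 are hypothetical.
from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
from cache_dit.cache_factory.block_adapters.block_registers import BlockAdapterRegistry
from cache_dit.cache_factory.forward_pattern import ForwardPattern


@BlockAdapterRegistry.register("MyVideo")  # matches e.g. MyVideoPipeline
def myvideo_adapter(pipe, **kwargs) -> BlockAdapter:
    # Factory invoked by get_adapter() when a pipeline class name starts with "MyVideo".
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.blocks,
        blocks_name="blocks",
        forward_pattern=ForwardPattern.Pattern_0,  # assumed member name
        **kwargs,
    )


# Callers then resolve adapters by pipeline class-name prefix:
#   BlockAdapterRegistry.get_adapter(pipe)      -> myvideo_adapter(pipe)
#   BlockAdapterRegistry.is_supported(pipe)     -> True for MyVideoPipeline
#   BlockAdapterRegistry.supported_pipelines()  -> (count, ["MyVideo*", ...])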