cache-dit 0.2.26__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic.

Files changed (28)
  1. cache_dit/__init__.py +7 -6
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +15 -4
  4. cache_dit/cache_factory/block_adapters/__init__.py +538 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +120 -911
  8. cache_dit/cache_factory/cache_blocks/__init__.py +7 -9
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +46 -41
  11. cache_dit/cache_factory/cache_blocks/pattern_base.py +98 -79
  12. cache_dit/cache_factory/cache_blocks/utils.py +13 -9
  13. cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
  14. cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +89 -55
  15. cache_dit/cache_factory/cache_contexts/cache_manager.py +0 -0
  16. cache_dit/cache_factory/cache_interface.py +21 -18
  17. cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
  18. cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
  19. cache_dit/quantize/quantize_ao.py +1 -0
  20. cache_dit/utils.py +19 -16
  21. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/METADATA +42 -12
  22. cache_dit-0.2.27.dist-info/RECORD +47 -0
  23. cache_dit-0.2.26.dist-info/RECORD +0 -42
  24. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  25. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
  26. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
  27. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
  28. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/block_adapters/block_adapters.py (new file)
@@ -0,0 +1,333 @@
+import torch
+
+import inspect
+import dataclasses
+
+from typing import Any, Tuple, List, Optional
+
+from diffusers import DiffusionPipeline
+from cache_dit.cache_factory.forward_pattern import ForwardPattern
+from cache_dit.cache_factory.patch_functors import PatchFunctor
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class BlockAdapter:
+    # Transformer configurations.
+    pipe: DiffusionPipeline | Any = None
+    transformer: torch.nn.Module = None
+
+    # ------------ Block Level Flags ------------
+    blocks: torch.nn.ModuleList | List[torch.nn.ModuleList] = None
+    # transformer_blocks, blocks, etc.
+    blocks_name: str | List[str] = None
+    dummy_blocks_names: List[str] = dataclasses.field(default_factory=list)
+    forward_pattern: ForwardPattern | List[ForwardPattern] = None
+    check_num_outputs: bool = True
+
+    # Flags to control auto block adapter
+    auto: bool = False
+    allow_prefixes: List[str] = dataclasses.field(
+        default_factory=lambda: [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+            "single_stream_blocks",
+            "double_stream_blocks",
+        ]
+    )
+    check_prefixes: bool = True
+    allow_suffixes: List[str] = dataclasses.field(
+        default_factory=lambda: ["TransformerBlock"]
+    )
+    check_suffixes: bool = False
+    blocks_policy: str = dataclasses.field(
+        default="max", metadata={"allowed_values": ["max", "min"]}
+    )
+
+    # NOTE: Other flags.
+    disable_patch: bool = False
+
+    # ------------ Pipeline Level Flags ------------
+    # Patch Functor: Flux, etc.
+    patch_functor: Optional[PatchFunctor] = None
+    # Flags for separate cfg
+    has_separate_cfg: bool = False
+
+    def __post_init__(self):
+        assert any((self.pipe is not None, self.transformer is not None))
+        self.patchify()
+
+    def patchify(self, *args, **kwargs):
+        # Process some specificial cases, specific for transformers
+        # that has different forward patterns between single_transformer_blocks
+        # and transformer_blocks , such as Flux (diffusers < 0.35.0).
+        if self.patch_functor is not None and not self.disable_patch:
+            if self.transformer is not None:
+                self.patch_functor.apply(self.transformer, *args, **kwargs)
+            else:
+                assert hasattr(self.pipe, "transformer")
+                self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
+
+    @staticmethod
+    def auto_block_adapter(
+        adapter: "BlockAdapter",
+    ) -> "BlockAdapter":
+        assert adapter.auto, (
+            "Please manually set `auto` to True, or, manually "
+            "set all the transformer blocks configuration."
+        )
+        assert adapter.pipe is not None, "adapter.pipe can not be None."
+        assert (
+            adapter.forward_pattern is not None
+        ), "adapter.forward_pattern can not be None."
+        pipe = adapter.pipe
+
+        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
+
+        transformer = pipe.transformer
+
+        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
+        blocks, blocks_name = BlockAdapter.find_blocks(
+            transformer=transformer,
+            allow_prefixes=adapter.allow_prefixes,
+            allow_suffixes=adapter.allow_suffixes,
+            check_prefixes=adapter.check_prefixes,
+            check_suffixes=adapter.check_suffixes,
+            blocks_policy=adapter.blocks_policy,
+            forward_pattern=adapter.forward_pattern,
+            check_num_outputs=adapter.check_num_outputs,
+        )
+
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=transformer,
+            blocks=blocks,
+            blocks_name=blocks_name,
+            forward_pattern=adapter.forward_pattern,
+        )
+
+    @staticmethod
+    def check_block_adapter(
+        adapter: "BlockAdapter",
+    ) -> bool:
+        def _check_warning(attr: str):
+            if getattr(adapter, attr, None) is None:
+                logger.warning(f"{attr} is None!")
+                return False
+            return True
+
+        if not _check_warning("pipe"):
+            return False
+
+        if not _check_warning("transformer"):
+            return False
+
+        if not _check_warning("blocks"):
+            return False
+
+        if not _check_warning("blocks_name"):
+            return False
+
+        if not _check_warning("forward_pattern"):
+            return False
+
+        if isinstance(adapter.blocks, list):
+            for i, blocks in enumerate(adapter.blocks):
+                if not isinstance(blocks, torch.nn.ModuleList):
+                    logger.warning(f"blocks[{i}] is not ModuleList.")
+                    return False
+        else:
+            if not isinstance(adapter.blocks, torch.nn.ModuleList):
+                logger.warning("blocks is not ModuleList.")
+                return False
+
+        return True
+
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+        allow_prefixes: List[str] = [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+            "single_stream_blocks",
+            "double_stream_blocks",
+        ],
+        allow_suffixes: List[str] = [
+            "TransformerBlock",
+        ],
+        check_prefixes: bool = True,
+        check_suffixes: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.nn.ModuleList, str]:
+        # Check prefixes
+        if check_prefixes:
+            blocks_names = []
+            for attr_name in dir(transformer):
+                for prefix in allow_prefixes:
+                    if attr_name.startswith(prefix):
+                        blocks_names.append(attr_name)
+        else:
+            blocks_names = dir(transformer)
+
+        # Check ModuleList
+        valid_names = []
+        valid_count = []
+        forward_pattern = kwargs.pop("forward_pattern", None)
+        for blocks_name in blocks_names:
+            if blocks := getattr(transformer, blocks_name, None):
+                if isinstance(blocks, torch.nn.ModuleList):
+                    block = blocks[0]
+                    block_cls_name = block.__class__.__name__
+                    # Check suffixes
+                    if isinstance(block, torch.nn.Module) and (
+                        any(
+                            (
+                                block_cls_name.endswith(allow_suffix)
+                                for allow_suffix in allow_suffixes
+                            )
+                        )
+                        or (not check_suffixes)
+                    ):
+                        # May check forward pattern
+                        if forward_pattern is not None:
+                            if BlockAdapter.match_blocks_pattern(
+                                blocks,
+                                forward_pattern,
+                                logging=False,
+                                **kwargs,
+                            ):
+                                valid_names.append(blocks_name)
+                                valid_count.append(len(blocks))
+                        else:
+                            valid_names.append(blocks_name)
+                            valid_count.append(len(blocks))
+
+        if not valid_names:
+            raise ValueError(
+                "Auto selected transformer blocks failed, please set it manually."
+            )
+
+        final_name = valid_names[0]
+        final_count = valid_count[0]
+        block_policy = kwargs.get("blocks_policy", "max")
+
+        for blocks_name, count in zip(valid_names, valid_count):
+            blocks = getattr(transformer, blocks_name)
+            logger.info(
+                f"Auto selected transformer blocks: {blocks_name}, "
+                f"class: {blocks[0].__class__.__name__}, "
+                f"num blocks: {count}"
+            )
+            if block_policy == "max":
+                if final_count < count:
+                    final_count = count
+                    final_name = blocks_name
+            else:
+                if final_count > count:
+                    final_count = count
+                    final_name = blocks_name
+
+        final_blocks = getattr(transformer, final_name)
+
+        logger.info(
+            f"Final selected transformer blocks: {final_name}, "
+            f"class: {final_blocks[0].__class__.__name__}, "
+            f"num blocks: {final_count}, block_policy: {block_policy}."
+        )
+
+        return final_blocks, final_name
+
+    @staticmethod
+    def match_block_pattern(
+        block: torch.nn.Module,
+        forward_pattern: ForwardPattern,
+        **kwargs,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        forward_parameters = set(
+            inspect.signature(block.forward).parameters.keys()
+        )
+
+        in_matched = True
+        out_matched = True
+
+        if kwargs.get("check_num_outputs", True):
+            num_outputs = str(
+                inspect.signature(block.forward).return_annotation
+            ).count("torch.Tensor")
+
+            if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
+                # output pattern not match
+                out_matched = False
+
+        for required_param in forward_pattern.In:
+            if required_param not in forward_parameters:
+                in_matched = False
+
+        return in_matched and out_matched
+
+    @staticmethod
+    def match_blocks_pattern(
+        transformer_blocks: torch.nn.ModuleList,
+        forward_pattern: ForwardPattern,
+        logging: bool = True,
+        **kwargs,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        assert isinstance(transformer_blocks, torch.nn.ModuleList)
+
+        pattern_matched_states = []
+        for block in transformer_blocks:
+            pattern_matched_states.append(
+                BlockAdapter.match_block_pattern(
+                    block,
+                    forward_pattern,
+                    **kwargs,
+                )
+            )
+
+        pattern_matched = all(pattern_matched_states)  # all block match
+        if pattern_matched and logging:
+            block_cls_names = [
+                block.__class__.__name__ for block in transformer_blocks
+            ]
+            block_cls_names = list(set(block_cls_names))
+            if len(block_cls_names) == 1:
+                block_cls_names = block_cls_names[0]
+            logger.info(
+                f"Match Block Forward Pattern: {block_cls_names}, {forward_pattern}"
+                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
+            )
+
+        return pattern_matched
+
+    @staticmethod
+    def normalize(
+        adapter: "BlockAdapter",
+    ) -> "BlockAdapter":
+        if not isinstance(adapter.blocks, list):
+            adapter.blocks = [adapter.blocks]
+        if not isinstance(adapter.blocks_name, list):
+            adapter.blocks_name = [adapter.blocks_name]
+        if not isinstance(adapter.forward_pattern, list):
+            adapter.forward_pattern = [adapter.forward_pattern]
+
+        assert len(adapter.blocks) == len(adapter.blocks_name)
+        assert len(adapter.blocks) == len(adapter.forward_pattern)
+
+        return adapter
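
For context, the new BlockAdapter dataclass added above can either be filled in by hand or resolved automatically via auto_block_adapter(). The sketch below is not part of the diff: FluxPipeline, the "transformer_blocks" attribute, and ForwardPattern.Pattern_1 are illustrative assumptions about how the API might be wired up.

# Hedged sketch, not from the diff: FluxPipeline, "transformer_blocks" and
# ForwardPattern.Pattern_1 are assumptions used only for illustration.
import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory.forward_pattern import ForwardPattern
from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16  # assumed model id
)

# Manual wiring: point the adapter at the pipeline, its transformer and
# the ModuleList of blocks that should be cached.
adapter = BlockAdapter(
    pipe=pipe,
    transformer=pipe.transformer,
    blocks=pipe.transformer.transformer_blocks,
    blocks_name="transformer_blocks",
    forward_pattern=ForwardPattern.Pattern_1,  # assumed enum member
)

# Automatic wiring: find_blocks() scans the transformer's attributes and
# picks the ModuleList whose block forward signature matches the pattern.
auto_adapter = BlockAdapter.auto_block_adapter(
    BlockAdapter(pipe=pipe, auto=True, forward_pattern=ForwardPattern.Pattern_1)
)
assert BlockAdapter.check_block_adapter(auto_adapter)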
cache_dit/cache_factory/block_adapters/block_registers.py (new file)
@@ -0,0 +1,77 @@
+from typing import Any, Tuple, List, Dict
+
+from diffusers import DiffusionPipeline
+from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class BlockAdapterRegistry:
+    _adapters: Dict[str, BlockAdapter] = {}
+    _predefined_adapters_has_spearate_cfg: List[str] = {
+        "QwenImage",
+        "Wan",
+        "CogView4",
+        "Cosmos",
+        "SkyReelsV2",
+        "Chroma",
+    }
+
+    @classmethod
+    def register(cls, name):
+        def decorator(func):
+            cls._adapters[name] = func
+            return func
+
+        return decorator
+
+    @classmethod
+    def get_adapter(
+        cls,
+        pipe: DiffusionPipeline | str | Any,
+        **kwargs,
+    ) -> BlockAdapter:
+        if not isinstance(pipe, str):
+            pipe_cls_name: str = pipe.__class__.__name__
+        else:
+            pipe_cls_name = pipe
+
+        for name in cls._adapters:
+            if pipe_cls_name.startswith(name):
+                return cls._adapters[name](pipe, **kwargs)
+
+        return BlockAdapter()
+
+    @classmethod
+    def has_separate_cfg(
+        cls,
+        pipe: DiffusionPipeline | str | Any,
+    ) -> bool:
+        if cls.get_adapter(
+            pipe,
+            disable_patch=True,
+        ).has_separate_cfg:
+            return True
+
+        pipe_cls_name = pipe.__class__.__name__
+        for name in cls._predefined_adapters_has_spearate_cfg:
+            if pipe_cls_name.startswith(name):
+                return True
+
+        return False
+
+    @classmethod
+    def is_supported(cls, pipe) -> bool:
+        pipe_cls_name: str = pipe.__class__.__name__
+
+        for name in cls._adapters:
+            if pipe_cls_name.startswith(name):
+                return True
+        return False
+
+    @classmethod
+    def supported_pipelines(cls, **kwargs) -> Tuple[int, List[str]]:
+        val_pipelines = cls._adapters.keys()
+        return len(val_pipelines), [p + "*" for p in val_pipelines]
+ return len(val_pipelines), [p + "*" for p in val_pipelines]