cache-dit 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic.

@@ -3,14 +3,16 @@ import torch
  import unittest
  import functools

- from typing import Dict
  from contextlib import ExitStack
+ from typing import Dict, List, Tuple, Any
+
  from diffusers import DiffusionPipeline
+
  from cache_dit.cache_factory import CacheType
- from cache_dit.cache_factory import CachedContext
- from cache_dit.cache_factory import ForwardPattern
  from cache_dit.cache_factory import BlockAdapter
+ from cache_dit.cache_factory import ParamsModifier
  from cache_dit.cache_factory import BlockAdapterRegistry
+ from cache_dit.cache_factory import CachedContextManager
  from cache_dit.cache_factory import CachedBlocks

  from cache_dit.logger import init_logger
@@ -27,37 +29,39 @@ class CachedAdapter:
  @classmethod
  def apply(
  cls,
- pipe: DiffusionPipeline = None,
- block_adapter: BlockAdapter = None,
- # forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
  **cache_context_kwargs,
- ) -> DiffusionPipeline:
+ ) -> BlockAdapter:
  assert (
- pipe is not None or block_adapter is not None
+ pipe_or_adapter is not None
  ), "pipe or block_adapter can not both None!"

- if pipe is not None:
- if BlockAdapterRegistry.is_supported(pipe):
+ if isinstance(pipe_or_adapter, DiffusionPipeline):
+ if BlockAdapterRegistry.is_supported(pipe_or_adapter):
  logger.info(
- f"{pipe.__class__.__name__} is officially supported by cache-dit. "
- "Use it's pre-defined BlockAdapter directly!"
+ f"{pipe_or_adapter.__class__.__name__} is officially "
+ "supported by cache-dit. Use it's pre-defined BlockAdapter "
+ "directly!"
+ )
+ block_adapter = BlockAdapterRegistry.get_adapter(
+ pipe_or_adapter
  )
- block_adapter = BlockAdapterRegistry.get_adapter(pipe)
  return cls.cachify(
  block_adapter,
  **cache_context_kwargs,
  )
  else:
  raise ValueError(
- f"{pipe.__class__.__name__} is not officially supported "
+ f"{pipe_or_adapter.__class__.__name__} is not officially supported "
  "by cache-dit, please set BlockAdapter instead!"
  )
  else:
+ assert isinstance(pipe_or_adapter, BlockAdapter)
  logger.info(
- "Adapting cache acceleration using custom BlockAdapter!"
+ "Adapting Cache Acceleration using custom BlockAdapter!"
  )
  return cls.cachify(
- block_adapter,
+ pipe_or_adapter,
  **cache_context_kwargs,
  )

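The hunk above changes the public entry point: CachedAdapter.apply now takes a single positional pipe_or_adapter (either a DiffusionPipeline or a BlockAdapter) instead of separate pipe/block_adapter keywords, and returns the adapter rather than the pipeline. A minimal sketch of calling the new signature; the CachedAdapter import path, the checkpoint id, and the BlockAdapter constructor fields are assumptions for illustration, not shown in this diff:

    # Sketch only: the CachedAdapter module path, the checkpoint id, and the
    # BlockAdapter keyword fields below are assumptions (the diff only shows
    # that .pipe and .auto exist as attributes).
    from diffusers import DiffusionPipeline

    from cache_dit.cache_factory import BlockAdapter
    from cache_dit.cache_factory.cache_adapters import CachedAdapter  # path assumed

    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

    # Officially supported pipelines: pass the pipeline itself and let the
    # registry resolve its pre-defined BlockAdapter.
    CachedAdapter.apply(pipe)

    # Unsupported pipelines: build and pass a custom BlockAdapter instead.
    adapter = BlockAdapter(pipe=pipe, auto=True)
    CachedAdapter.apply(adapter)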
@@ -66,7 +70,7 @@ class CachedAdapter:
  cls,
  block_adapter: BlockAdapter,
  **cache_context_kwargs,
- ) -> DiffusionPipeline:
+ ) -> BlockAdapter:

  if block_adapter.auto:
  block_adapter = BlockAdapter.auto_block_adapter(
@@ -74,56 +78,80 @@ class CachedAdapter:
  )

  if BlockAdapter.check_block_adapter(block_adapter):
+
+ # 0. Must normalize block_adapter before apply cache
  block_adapter = BlockAdapter.normalize(block_adapter)
- # 0. Apply cache on pipeline: wrap cache context
+ if BlockAdapter.is_cached(block_adapter):
+ return block_adapter.pipe
+
+ # 1. Apply cache on pipeline: wrap cache context, must
+ # call create_context before mock_blocks.
  cls.create_context(
  block_adapter,
  **cache_context_kwargs,
  )
- # 1. Apply cache on transformer: mock cached transformer blocks
+
+ # 2. Apply cache on transformer: mock cached blocks
  cls.mock_blocks(
  block_adapter,
  )
- cls.patch_params(
- block_adapter,
- **cache_context_kwargs,
- )
- return block_adapter.pipe
+
+ return block_adapter

  @classmethod
  def patch_params(
  cls,
  block_adapter: BlockAdapter,
- **cache_context_kwargs,
+ contexts_kwargs: List[Dict],
  ):
- block_adapter.transformer._forward_pattern = (
- block_adapter.forward_pattern
- )
- block_adapter.transformer._has_separate_cfg = (
- block_adapter.has_separate_cfg
- )
- block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
- block_adapter.pipe.__class__._cache_context_kwargs = (
- cache_context_kwargs
- )
- for blocks, forward_pattern in zip(
- block_adapter.blocks, block_adapter.forward_pattern
- ):
- blocks._forward_pattern = forward_pattern
- blocks._cache_context_kwargs = cache_context_kwargs
+ block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+
+ params_shift = 0
+ for i in range(len(block_adapter.transformer)):
+
+ block_adapter.transformer[i]._forward_pattern = (
+ block_adapter.forward_pattern
+ )
+ block_adapter.transformer[i]._has_separate_cfg = (
+ block_adapter.has_separate_cfg
+ )
+ block_adapter.transformer[i]._cache_context_kwargs = (
+ contexts_kwargs[params_shift]
+ )
+
+ blocks = block_adapter.blocks[i]
+ for j in range(len(blocks)):
+ blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+ blocks[j]._cache_context_kwargs = contexts_kwargs[
+ params_shift + j
+ ]
+
+ params_shift += len(blocks)

  @classmethod
- def check_context_kwargs(cls, pipe, **cache_context_kwargs):
+ def check_context_kwargs(
+ cls,
+ block_adapter: BlockAdapter,
+ **cache_context_kwargs,
+ ):
  # Check cache_context_kwargs
- if not cache_context_kwargs["do_separate_cfg"]:
+ if not cache_context_kwargs["enable_spearate_cfg"]:
  # Check cfg for some specific case if users don't set it as True
- cache_context_kwargs["do_separate_cfg"] = (
- BlockAdapterRegistry.has_separate_cfg(pipe)
- )
- logger.info(
- f"Use default 'do_separate_cfg': {cache_context_kwargs['do_separate_cfg']}, "
- f"Pipeline: {pipe.__class__.__name__}."
- )
+ if BlockAdapterRegistry.has_separate_cfg(block_adapter):
+ cache_context_kwargs["enable_spearate_cfg"] = True
+ logger.info(
+ f"Use custom 'enable_spearate_cfg' from BlockAdapter: True. "
+ f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+ )
+ else:
+ cache_context_kwargs["enable_spearate_cfg"] = (
+ BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
+ )
+ logger.info(
+ f"Use default 'enable_spearate_cfg' from block adapter "
+ f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
+ f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+ )

  if cache_type := cache_context_kwargs.pop("cache_type", None):
  assert (
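Two behavioural changes land in this hunk: patch_params now receives one kwargs dict per blocks list, indexed by a running params_shift across possibly several transformers, and the cfg flag is renamed to enable_spearate_cfg (spelling as in the source) with the BlockAdapter-level setting taking precedence over the registry default. A self-contained toy sketch of the params_shift indexing, using made-up names:

    # Toy illustration of the params_shift bookkeeping above: two transformers,
    # the first exposing two blocks lists and the second one, so three flattened
    # context-kwargs dicts are handed out in order. All names are made up.
    blocks_per_transformer = [
        ["transformer_blocks", "single_transformer_blocks"],
        ["blocks"],
    ]
    contexts_kwargs = [{"name": f"ctx_{i}"} for i in range(3)]

    params_shift = 0
    for blocks_names in blocks_per_transformer:
        for j, blocks_name in enumerate(blocks_names):
            print(blocks_name, "->", contexts_kwargs[params_shift + j]["name"])
        params_shift += len(blocks_names)
    # transformer_blocks -> ctx_0
    # single_transformer_blocks -> ctx_1
    # blocks -> ctx_2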
@@ -138,31 +166,46 @@ class CachedAdapter:
  block_adapter: BlockAdapter,
  **cache_context_kwargs,
  ) -> DiffusionPipeline:
- if getattr(block_adapter.pipe, "_is_cached", False):
+
+ BlockAdapter.assert_normalized(block_adapter)
+
+ if BlockAdapter.is_cached(block_adapter.pipe):
  return block_adapter.pipe

  # Check cache_context_kwargs
  cache_context_kwargs = cls.check_context_kwargs(
- block_adapter.pipe,
- **cache_context_kwargs,
+ block_adapter, **cache_context_kwargs
  )
  # Apply cache on pipeline: wrap cache context
- cache_kwargs, _ = CachedContext.collect_cache_kwargs(
- default_attrs={},
- **cache_context_kwargs,
+ pipe_cls_name = block_adapter.pipe.__class__.__name__
+
+ # Each Pipeline should have it's own context manager instance.
+ # Different transformers (Wan2.2, etc) should shared the same
+ # cache manager but with different cache context (according
+ # to their unique instance id).
+ cache_manager = CachedContextManager(
+ name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
  )
+ block_adapter.pipe._cache_manager = cache_manager # instance level
+
+ flatten_contexts, contexts_kwargs = cls.modify_context_params(
+ block_adapter, cache_manager, **cache_context_kwargs
+ )
+
  original_call = block_adapter.pipe.__class__.__call__

  @functools.wraps(original_call)
  def new_call(self, *args, **kwargs):
  with ExitStack() as stack:
- # cache context will reset for each pipe inference
- for blocks_name in block_adapter.blocks_name:
+ # cache context will be reset for each pipe inference
+ for context_name, context_kwargs in zip(
+ flatten_contexts, contexts_kwargs
+ ):
  stack.enter_context(
- CachedContext.cache_context(
- CachedContext.reset_cache_context(
- blocks_name,
- **cache_kwargs,
+ cache_manager.enter_context(
+ cache_manager.reset_context(
+ context_name,
+ **context_kwargs,
  ),
  )
  )
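The hunk above replaces the module-level CachedContext helpers with a per-pipeline CachedContextManager (named after the pipeline class plus its instance id), and the wrapped __call__ now resets one named context per blocks list on every inference. A standalone sketch of that ExitStack pattern with a stand-in context manager; nothing below is cache-dit's API except the shape of the loop, and the kwargs are placeholders:

    # Stand-in sketch of "reset one context per blocks list for each call".
    # toy_context imitates cache_manager.enter_context(cache_manager.reset_context(...)).
    from contextlib import ExitStack, contextmanager

    @contextmanager
    def toy_context(name: str, **kwargs):
        print("reset + enter", name, kwargs)
        try:
            yield
        finally:
            print("exit", name)

    flatten_contexts = ["transformer_blocks_0", "single_transformer_blocks_0"]
    contexts_kwargs = [{"threshold": 0.1}, {"threshold": 0.1}]  # placeholder kwargs

    def new_call():
        with ExitStack() as stack:
            for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs):
                stack.enter_context(toy_context(context_name, **context_kwargs))
            return "outputs"

    print(new_call())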
@@ -171,109 +214,194 @@ class CachedAdapter:
  return outputs

  block_adapter.pipe.__class__.__call__ = new_call
+ block_adapter.pipe.__class__._original_call = original_call
  block_adapter.pipe.__class__._is_cached = True
+
+ cls.patch_params(block_adapter, contexts_kwargs)
+
  return block_adapter.pipe

  @classmethod
- def patch_stats(cls, block_adapter: BlockAdapter):
+ def modify_context_params(
+ cls,
+ block_adapter: BlockAdapter,
+ cache_manager: CachedContextManager,
+ **cache_context_kwargs,
+ ) -> Tuple[List[str], List[Dict[str, Any]]]:
+
+ flatten_contexts = BlockAdapter.flatten(
+ block_adapter.unique_blocks_name
+ )
+ contexts_kwargs = [
+ cache_context_kwargs.copy()
+ for _ in range(
+ len(flatten_contexts),
+ )
+ ]
+
+ for i in range(len(contexts_kwargs)):
+ contexts_kwargs[i]["name"] = flatten_contexts[i]
+
+ if block_adapter.params_modifiers is None:
+ return flatten_contexts, contexts_kwargs
+
+ flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
+ block_adapter.params_modifiers,
+ )
+
+ for i in range(
+ min(len(contexts_kwargs), len(flatten_modifiers)),
+ ):
+ contexts_kwargs[i].update(
+ flatten_modifiers[i]._context_kwargs,
+ )
+ contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
+ default_attrs={}, **contexts_kwargs[i]
+ )
+
+ return flatten_contexts, contexts_kwargs
+
+ @classmethod
+ def patch_stats(
+ cls,
+ block_adapter: BlockAdapter,
+ ):
  from cache_dit.cache_factory.cache_blocks.utils import (
  patch_cached_stats,
  )

- patch_cached_stats(block_adapter.transformer)
- for blocks, blocks_name in zip(
- block_adapter.blocks, block_adapter.blocks_name
- ):
- patch_cached_stats(blocks, blocks_name)
+ cache_manager = block_adapter.pipe._cache_manager
+
+ for i in range(len(block_adapter.transformer)):
+ patch_cached_stats(
+ block_adapter.transformer[i],
+ cache_context=block_adapter.unique_blocks_name[i][-1],
+ cache_manager=cache_manager,
+ )
+ for blocks, unique_name in zip(
+ block_adapter.blocks[i],
+ block_adapter.unique_blocks_name[i],
+ ):
+ patch_cached_stats(
+ blocks,
+ cache_context=unique_name,
+ cache_manager=cache_manager,
+ )

  @classmethod
  def mock_blocks(
  cls,
  block_adapter: BlockAdapter,
- ) -> torch.nn.Module:
+ ) -> List[torch.nn.Module]:
+
+ BlockAdapter.assert_normalized(block_adapter)

- if getattr(block_adapter.transformer, "_is_cached", False):
+ if BlockAdapter.is_cached(block_adapter.transformer):
  return block_adapter.transformer

- # Check block forward pattern matching
- block_adapter = BlockAdapter.normalize(block_adapter)
- for forward_pattern, blocks in zip(
- block_adapter.forward_pattern, block_adapter.blocks
+ # Apply cache on transformer: mock cached transformer blocks
+ for (
+ cached_blocks,
+ transformer,
+ blocks_name,
+ unique_blocks_name,
+ dummy_blocks_names,
+ ) in zip(
+ cls.collect_cached_blocks(block_adapter),
+ block_adapter.transformer,
+ block_adapter.blocks_name,
+ block_adapter.unique_blocks_name,
+ block_adapter.dummy_blocks_names,
  ):
- assert BlockAdapter.match_blocks_pattern(
- blocks,
- forward_pattern=forward_pattern,
- check_num_outputs=block_adapter.check_num_outputs,
- ), (
- "No block forward pattern matched, "
- f"supported lists: {ForwardPattern.supported_patterns()}"
+ cls.mock_transformer(
+ cached_blocks,
+ transformer,
+ blocks_name,
+ unique_blocks_name,
+ dummy_blocks_names,
  )

- # Apply cache on transformer: mock cached transformer blocks
- # TODO: Use blocks_name to spearate cached context for different
- # blocks list. For example, single_transformer_blocks and
- # transformer_blocks should have different cached context and
- # forward pattern.
- cached_blocks = cls.collect_cached_blocks(
- block_adapter=block_adapter,
- )
+ return block_adapter.transformer
+
+ @classmethod
+ def mock_transformer(
+ cls,
+ cached_blocks: Dict[str, torch.nn.ModuleList],
+ transformer: torch.nn.Module,
+ blocks_name: List[str],
+ unique_blocks_name: List[str],
+ dummy_blocks_names: List[str],
+ ) -> torch.nn.Module:
  dummy_blocks = torch.nn.ModuleList()

- original_forward = block_adapter.transformer.forward
+ original_forward = transformer.forward

- assert isinstance(block_adapter.dummy_blocks_names, list)
+ assert isinstance(dummy_blocks_names, list)

  @functools.wraps(original_forward)
  def new_forward(self, *args, **kwargs):
  with ExitStack() as stack:
- for blocks_name in block_adapter.blocks_name:
+ for name, context_name in zip(
+ blocks_name,
+ unique_blocks_name,
+ ):
  stack.enter_context(
  unittest.mock.patch.object(
- self,
- blocks_name,
- cached_blocks[blocks_name],
+ self, name, cached_blocks[context_name]
  )
  )
- for dummy_name in block_adapter.dummy_blocks_names:
+ for dummy_name in dummy_blocks_names:
  stack.enter_context(
  unittest.mock.patch.object(
- self,
- dummy_name,
- dummy_blocks,
+ self, dummy_name, dummy_blocks
  )
  )
  return original_forward(*args, **kwargs)

- block_adapter.transformer.forward = new_forward.__get__(
- block_adapter.transformer
- )
- block_adapter.transformer._is_cached = True
+ transformer.forward = new_forward.__get__(transformer)
+ transformer._original_forward = original_forward
+ transformer._is_cached = True

- return block_adapter.transformer
+ return transformer

  @classmethod
  def collect_cached_blocks(
  cls,
  block_adapter: BlockAdapter,
- ) -> Dict[str, torch.nn.ModuleList]:
- block_adapter = BlockAdapter.normalize(block_adapter)
+ ) -> List[Dict[str, torch.nn.ModuleList]]:

- cached_blocks_bind_context = {}
+ BlockAdapter.assert_normalized(block_adapter)

- for i in range(len(block_adapter.blocks)):
- cached_blocks_bind_context[block_adapter.blocks_name[i]] = (
- torch.nn.ModuleList(
+ total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
+ assert hasattr(block_adapter.pipe, "_cache_manager")
+ assert isinstance(
+ block_adapter.pipe._cache_manager, CachedContextManager
+ )
+
+ for i in range(len(block_adapter.transformer)):
+
+ cached_blocks_bind_context = {}
+ for j in range(len(block_adapter.blocks[i])):
+ cached_blocks_bind_context[
+ block_adapter.unique_blocks_name[i][j]
+ ] = torch.nn.ModuleList(
  [
  CachedBlocks(
- block_adapter.blocks[i],
- block_adapter.blocks_name[i],
- block_adapter.blocks_name[i], # context name
- transformer=block_adapter.transformer,
- forward_pattern=block_adapter.forward_pattern[i],
+ # 0. Transformer blocks configuration
+ block_adapter.blocks[i][j],
+ transformer=block_adapter.transformer[i],
+ forward_pattern=block_adapter.forward_pattern[i][j],
  check_num_outputs=block_adapter.check_num_outputs,
+ # 1. Cache context configuration
+ cache_prefix=block_adapter.blocks_name[i][j],
+ cache_context=block_adapter.unique_blocks_name[i][
+ j
+ ],
+ cache_manager=block_adapter.pipe._cache_manager,
  )
  ]
  )
- )

- return cached_blocks_bind_context
+ total_cached_blocks.append(cached_blocks_bind_context)
+
+ return total_cached_blocks
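Most of this hunk is bookkeeping for multiple transformers, but the core trick is unchanged: during a forward call, each named blocks attribute is swapped for its cached counterpart with unittest.mock.patch.object and restored on exit. A minimal standalone sketch of that technique on a toy module (not the library's transformer; the class and replacement blocks are illustrative only):

    # Minimal sketch of the attribute-swapping trick used by mock_transformer.
    import unittest.mock
    from contextlib import ExitStack

    import torch

    class ToyTransformer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer_blocks = torch.nn.ModuleList([torch.nn.Identity()])

        def forward(self, x):
            for block in self.transformer_blocks:
                x = block(x)
            return x

    model = ToyTransformer()
    cached_blocks = torch.nn.ModuleList([torch.nn.Identity()])  # stand-in for CachedBlocks

    original_forward = model.forward

    def new_forward(self, *args, **kwargs):
        with ExitStack() as stack:
            # Swap the blocks only for the duration of this call; patch.object
            # restores the original attribute when the stack unwinds.
            stack.enter_context(
                unittest.mock.patch.object(self, "transformer_blocks", cached_blocks)
            )
            return original_forward(*args, **kwargs)

    model.forward = new_forward.__get__(model)
    print(model(torch.ones(1)))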
@@ -1,3 +1,11 @@
+ import torch
+
+ from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+ from cache_dit.cache_factory.cache_contexts.cache_manager import (
+ CachedContextManager,
+ )
+
  from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
  CachedBlocks_Pattern_0_1_2,
  )
@@ -5,14 +13,57 @@ from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
  CachedBlocks_Pattern_3_4_5,
  )

+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+

  class CachedBlocks:
- def __new__(cls, *args, **kwargs):
- forward_pattern = kwargs.get("forward_pattern", None)
+ def __new__(
+ cls,
+ # 0. Transformer blocks configuration
+ transformer_blocks: torch.nn.ModuleList,
+ transformer: torch.nn.Module = None,
+ forward_pattern: ForwardPattern = None,
+ check_num_outputs: bool = True,
+ # 1. Cache context configuration
+ # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+ # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+ cache_prefix: str = None, # cache_prefix maybe un-need.
+ # Usually, blocks_name, etc.
+ cache_context: CachedContext | str = None,
+ cache_manager: CachedContextManager = None,
+ **kwargs,
+ ):
+ assert transformer is not None, "transformer can't be None."
  assert forward_pattern is not None, "forward_pattern can't be None."
+ assert cache_context is not None, "cache_context can't be None."
+ assert cache_manager is not None, "cache_manager can't be None."
  if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
- return CachedBlocks_Pattern_0_1_2(*args, **kwargs)
+ return CachedBlocks_Pattern_0_1_2(
+ # 0. Transformer blocks configuration
+ transformer_blocks,
+ transformer=transformer,
+ forward_pattern=forward_pattern,
+ check_num_outputs=check_num_outputs,
+ # 1. Cache context configuration
+ cache_prefix=cache_prefix,
+ cache_context=cache_context,
+ cache_manager=cache_manager,
+ **kwargs,
+ )
  elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
- return CachedBlocks_Pattern_3_4_5(*args, **kwargs)
+ return CachedBlocks_Pattern_3_4_5(
+ # 0. Transformer blocks configuration
+ transformer_blocks,
+ transformer=transformer,
+ forward_pattern=forward_pattern,
+ check_num_outputs=check_num_outputs,
+ # 1. Cache context configuration
+ cache_prefix=cache_prefix,
+ cache_context=cache_context,
+ cache_manager=cache_manager,
+ **kwargs,
+ )
  else:
  raise ValueError(f"Pattern {forward_pattern} is not supported now!")
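CachedBlocks.__new__ now has an explicit keyword signature and still acts as a factory, returning a concrete CachedBlocks_Pattern_* instance based on forward_pattern. A toy sketch of that dispatch-in-__new__ technique with placeholder classes and pattern strings (not cache-dit's ForwardPattern values):

    # Toy factory-in-__new__ dispatch; classes and pattern names are placeholders.
    class _PatternA:
        _supported_patterns = {"Pattern_0", "Pattern_1", "Pattern_2"}

        def __init__(self, pattern: str, **kwargs):
            self.pattern = pattern

    class _PatternB:
        _supported_patterns = {"Pattern_3", "Pattern_4", "Pattern_5"}

        def __init__(self, pattern: str, **kwargs):
            self.pattern = pattern

    class ToyBlocks:
        def __new__(cls, pattern: str, **kwargs):
            # Always return an instance of a concrete class, never ToyBlocks itself.
            if pattern in _PatternA._supported_patterns:
                return _PatternA(pattern, **kwargs)
            elif pattern in _PatternB._supported_patterns:
                return _PatternB(pattern, **kwargs)
            raise ValueError(f"Pattern {pattern} is not supported now!")

    print(type(ToyBlocks("Pattern_3")).__name__)  # _PatternB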