cache-dit 0.2.27-py3-none-any.whl → 0.2.28-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cache_dit/__init__.py CHANGED
@@ -10,6 +10,7 @@ from cache_dit.cache_factory import cache_type
 from cache_dit.cache_factory import block_range
 from cache_dit.cache_factory import CacheType
 from cache_dit.cache_factory import BlockAdapter
+from cache_dit.cache_factory import ParamsModifier
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory import PatchFunctor
 from cache_dit.cache_factory import supported_pipelines
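0.2.28 re-exports the new `ParamsModifier` from the package root, next to `BlockAdapter`. A minimal sketch of the new import surface; the kwargs below are illustrative placeholders, not confirmed by this diff — `ParamsModifier` simply stores whatever cache-context overrides you pass, as its definition further down shows:

```python
from cache_dit import BlockAdapter, ParamsModifier

# ParamsModifier copies **kwargs into _context_kwargs for later use as
# per-blocks cache-context overrides. The kwarg names here are assumed.
modifier = ParamsModifier(max_warmup_steps=4, residual_diff_threshold=0.05)
```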
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.27'
-__version_tuple__ = version_tuple = (0, 2, 27)
+__version__ = version = '0.2.28'
+__version_tuple__ = version_tuple = (0, 2, 28)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py CHANGED
@@ -7,9 +7,11 @@ from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.patch_functors import PatchFunctor
 
 from cache_dit.cache_factory.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters import ParamsModifier
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 
 from cache_dit.cache_factory.cache_contexts import CachedContext
+from cache_dit.cache_factory.cache_contexts import CachedContextManager
 from cache_dit.cache_factory.cache_blocks import CachedBlocks
 
 from cache_dit.cache_factory.cache_adapters import CachedAdapter
cache_dit/cache_factory/block_adapters/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters.block_adapters import ParamsModifier
 from cache_dit.cache_factory.block_adapters.block_registers import (
     BlockAdapterRegistry,
 )
@@ -69,14 +70,30 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
         (WanTransformer3DModel, WanVACETransformer3DModel),
     )
     if getattr(pipe, "transformer_2", None):
-        # Wan 2.2, cache for low-noise transformer
+        assert isinstance(
+            pipe.transformer_2,
+            (WanTransformer3DModel, WanVACETransformer3DModel),
+        )
+        # Wan 2.2 MoE
         return BlockAdapter(
             pipe=pipe,
-            transformer=pipe.transformer_2,
-            blocks=pipe.transformer_2.blocks,
-            blocks_name="blocks",
+            transformer=[
+                pipe.transformer,
+                pipe.transformer_2,
+            ],
+            blocks=[
+                pipe.transformer.blocks,
+                pipe.transformer_2.blocks,
+            ],
+            blocks_name=[
+                "blocks",
+                "blocks",
+            ],
+            forward_pattern=[
+                ForwardPattern.Pattern_2,
+                ForwardPattern.Pattern_2,
+            ],
             dummy_blocks_names=[],
-            forward_pattern=ForwardPattern.Pattern_2,
             has_separate_cfg=True,
         )
    else:
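Before 0.2.28 the Wan 2.2 path cached only the low-noise `transformer_2`; this hunk makes the adapter wrap both MoE experts, passing parallel lists for `transformer`, `blocks`, `blocks_name`, and `forward_pattern`. A hedged sketch of how this lands for users, assuming cache-dit's `enable_cache` entry point; the checkpoint id is illustrative:

```python
import cache_dit
from diffusers import WanPipeline

# Wan 2.2 MoE pipelines expose two experts: pipe.transformer
# (high-noise) and pipe.transformer_2 (low-noise). With 0.2.28 the
# registered adapter caches both, not just transformer_2.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
cache_dit.enable_cache(pipe)
```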
cache_dit/cache_factory/block_adapters/block_adapters.py CHANGED
@@ -3,7 +3,7 @@ import torch
 import inspect
 import dataclasses
 
-from typing import Any, Tuple, List, Optional
+from typing import Any, Tuple, List, Optional, Union
 
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
@@ -14,21 +14,79 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
+class ParamsModifier:
+    def __init__(self, **kwargs):
+        self._context_kwargs = kwargs.copy()
+
+
 @dataclasses.dataclass
 class BlockAdapter:
+
     # Transformer configurations.
-    pipe: DiffusionPipeline | Any = None
-    transformer: torch.nn.Module = None
+    pipe: Union[
+        DiffusionPipeline,
+        Any,
+    ] = None
+
+    # Single transformer (most cases) or list of transformers (Wan 2.2, etc.)
+    transformer: Union[
+        torch.nn.Module,
+        List[torch.nn.Module],
+    ] = None
+
+    # Block Level Flags
+    # Each transformer contains a list of blocks-list,
+    # blocks_name-list, dummy_blocks_names-list, etc.
+    blocks: Union[
+        torch.nn.ModuleList,
+        List[torch.nn.ModuleList],
+        List[List[torch.nn.ModuleList]],
+    ] = None
 
-    # ------------ Block Level Flags ------------
-    blocks: torch.nn.ModuleList | List[torch.nn.ModuleList] = None
     # transformer_blocks, blocks, etc.
-    blocks_name: str | List[str] = None
-    dummy_blocks_names: List[str] = dataclasses.field(default_factory=list)
-    forward_pattern: ForwardPattern | List[ForwardPattern] = None
+    blocks_name: Union[
+        str,
+        List[str],
+        List[List[str]],
+    ] = None
+
+    unique_blocks_name: Union[
+        str,
+        List[str],
+        List[List[str]],
+    ] = dataclasses.field(default_factory=list)
+
+    dummy_blocks_names: Union[
+        List[str],
+        List[List[str]],
+    ] = dataclasses.field(default_factory=list)
+
+    forward_pattern: Union[
+        ForwardPattern,
+        List[ForwardPattern],
+        List[List[ForwardPattern]],
+    ] = None
+
+    # Modify cache context params for specific blocks.
+    params_modifiers: Union[
+        ParamsModifier,
+        List[ParamsModifier],
+        List[List[ParamsModifier]],
+    ] = None
+
     check_num_outputs: bool = True
 
+    # Pipeline Level Flags
+    # Patch Functor: Flux, etc.
+    patch_functor: Optional[PatchFunctor] = None
+    # Flags for separate cfg
+    has_separate_cfg: bool = False
+
+    # Other Flags
+    disable_patch: bool = False
+
     # Flags to control auto block adapter
+    # NOTE: NOT supported for multi-transformers.
     auto: bool = False
     allow_prefixes: List[str] = dataclasses.field(
         default_factory=lambda: [
@@ -49,15 +107,6 @@ class BlockAdapter:
         default="max", metadata={"allowed_values": ["max", "min"]}
     )
 
-    # NOTE: Other flags.
-    disable_patch: bool = False
-
-    # ------------ Pipeline Level Flags ------------
-    # Patch Functor: Flux, etc.
-    patch_functor: Optional[PatchFunctor] = None
-    # Flags for separate cfg
-    has_separate_cfg: bool = False
-
     def __post_init__(self):
         assert any((self.pipe is not None, self.transformer is not None))
         self.patchify()
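Taken together, every block-level field now accepts a scalar, a per-transformer list, or a list of lists, and `params_modifiers` lets each transformer carry its own cache-context overrides. A hedged sketch of a manual two-transformer adapter mirroring the Wan 2.2 hunk above; the pipeline attributes and modifier kwargs are placeholders:

```python
adapter = BlockAdapter(
    pipe=pipe,  # assumed: a pipeline exposing two transformers
    transformer=[pipe.transformer, pipe.transformer_2],
    blocks=[pipe.transformer.blocks, pipe.transformer_2.blocks],
    blocks_name=["blocks", "blocks"],
    forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
    # Per-transformer cache-context overrides; kwarg names are illustrative.
    params_modifiers=[
        ParamsModifier(max_warmup_steps=4),
        ParamsModifier(max_warmup_steps=2),
    ],
    has_separate_cfg=True,
)
```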
@@ -320,14 +369,170 @@ class BlockAdapter:
     def normalize(
         adapter: "BlockAdapter",
     ) -> "BlockAdapter":
-        if not isinstance(adapter.blocks, list):
-            adapter.blocks = [adapter.blocks]
-        if not isinstance(adapter.blocks_name, list):
-            adapter.blocks_name = [adapter.blocks_name]
-        if not isinstance(adapter.forward_pattern, list):
-            adapter.forward_pattern = [adapter.forward_pattern]
 
-        assert len(adapter.blocks) == len(adapter.blocks_name)
-        assert len(adapter.blocks) == len(adapter.forward_pattern)
+        if getattr(adapter, "_is_normalized", False):
+            return adapter
+
+        if not isinstance(adapter.transformer, list):
+            adapter.transformer = [adapter.transformer]
+
+        if isinstance(adapter.blocks, torch.nn.ModuleList):
+            # blocks_0 -> [[blocks_0,],] -> match [TRN_0,]
+            adapter.blocks = [[adapter.blocks]]
+        elif isinstance(adapter.blocks, list):
+            if isinstance(adapter.blocks[0], torch.nn.ModuleList):
+                # [blocks_0, blocks_1] -> [[blocks_0, blocks_1],] -> match [TRN_0,]
+                if len(adapter.blocks) == len(adapter.transformer):
+                    adapter.blocks = [[blocks] for blocks in adapter.blocks]
+                else:
+                    adapter.blocks = [adapter.blocks]
+            elif isinstance(adapter.blocks[0], list):
+                # [[blocks_0, blocks_1], [blocks_2, blocks_3],] -> match [TRN_0, TRN_1,]
+                pass
+
+        if isinstance(adapter.blocks_name, str):
+            adapter.blocks_name = [[adapter.blocks_name]]
+        elif isinstance(adapter.blocks_name, list):
+            if isinstance(adapter.blocks_name[0], str):
+                if len(adapter.blocks_name) == len(adapter.transformer):
+                    adapter.blocks_name = [
+                        [blocks_name] for blocks_name in adapter.blocks_name
+                    ]
+                else:
+                    adapter.blocks_name = [adapter.blocks_name]
+            elif isinstance(adapter.blocks_name[0], list):
+                pass
+
+        if isinstance(adapter.forward_pattern, ForwardPattern):
+            adapter.forward_pattern = [[adapter.forward_pattern]]
+        elif isinstance(adapter.forward_pattern, list):
+            if isinstance(adapter.forward_pattern[0], ForwardPattern):
+                if len(adapter.forward_pattern) == len(adapter.transformer):
+                    adapter.forward_pattern = [
+                        [forward_pattern]
+                        for forward_pattern in adapter.forward_pattern
+                    ]
+                else:
+                    adapter.forward_pattern = [adapter.forward_pattern]
+            elif isinstance(adapter.forward_pattern[0], list):
+                pass
+
+        if isinstance(adapter.dummy_blocks_names, list):
+            if len(adapter.dummy_blocks_names) > 0:
+                if isinstance(adapter.dummy_blocks_names[0], str):
+                    if len(adapter.dummy_blocks_names) == len(
+                        adapter.transformer
+                    ):
+                        adapter.dummy_blocks_names = [
+                            [dummy_blocks_names]
+                            for dummy_blocks_names in adapter.dummy_blocks_names
+                        ]
+                    else:
+                        adapter.dummy_blocks_names = [
+                            adapter.dummy_blocks_names
+                        ]
+                elif isinstance(adapter.dummy_blocks_names[0], list):
+                    pass
+            else:
+                # Empty dummy_blocks_names
+                adapter.dummy_blocks_names = [
+                    [] for _ in range(len(adapter.transformer))
+                ]
+
+        if adapter.params_modifiers is not None:
+            if isinstance(adapter.params_modifiers, ParamsModifier):
+                adapter.params_modifiers = [[adapter.params_modifiers]]
+            elif isinstance(adapter.params_modifiers, list):
+                if isinstance(adapter.params_modifiers[0], ParamsModifier):
+                    if len(adapter.params_modifiers) == len(
+                        adapter.transformer
+                    ):
+                        adapter.params_modifiers = [
+                            [params_modifiers]
+                            for params_modifiers in adapter.params_modifiers
+                        ]
+                    else:
+                        adapter.params_modifiers = [adapter.params_modifiers]
+                elif isinstance(adapter.params_modifiers[0], list):
+                    pass
+
+        assert len(adapter.transformer) == len(adapter.blocks)
+        assert len(adapter.transformer) == len(adapter.blocks_name)
+        assert len(adapter.transformer) == len(adapter.forward_pattern)
+        assert len(adapter.transformer) == len(adapter.dummy_blocks_names)
+        if adapter.params_modifiers is not None:
+            assert len(adapter.transformer) == len(adapter.params_modifiers)
+
+        for i in range(len(adapter.blocks)):
+            assert len(adapter.blocks[i]) == len(adapter.blocks_name[i])
+            assert len(adapter.blocks[i]) == len(adapter.forward_pattern[i])
+
+        if len(adapter.unique_blocks_name) == 0:
+            for i in range(len(adapter.transformer)):
+                # Generate unique blocks names
+                adapter.unique_blocks_name.append(
+                    [
+                        f"{name}_{hash(id(blocks))}"
+                        for blocks, name in zip(
+                            adapter.blocks[i],
+                            adapter.blocks_name[i],
+                        )
+                    ]
+                )
+
+        assert len(adapter.transformer) == len(adapter.unique_blocks_name)
+
+        # Match forward pattern
+        for i in range(len(adapter.transformer)):
+            for forward_pattern, blocks in zip(
+                adapter.forward_pattern[i], adapter.blocks[i]
+            ):
+                assert BlockAdapter.match_blocks_pattern(
+                    blocks,
+                    forward_pattern=forward_pattern,
+                    check_num_outputs=adapter.check_num_outputs,
+                ), (
+                    "No block forward pattern matched, "
+                    f"supported lists: {ForwardPattern.supported_patterns()}"
+                )
+
+        adapter._is_normalized = True
 
         return adapter
+
+    @classmethod
+    def assert_normalized(cls, adapter: "BlockAdapter"):
+        if not getattr(adapter, "_is_normalized", False):
+            raise RuntimeError("block_adapter must be normalized.")
+
+    @classmethod
+    def is_cached(cls, adapter: Any) -> bool:
+        if isinstance(adapter, cls):
+            cls.assert_normalized(adapter)
+            return all(
+                (
+                    getattr(adapter.pipe, "_is_cached", False),
+                    getattr(adapter.transformer[0], "_is_cached", False),
+                )
+            )
+        elif isinstance(
+            adapter,
+            (DiffusionPipeline, torch.nn.Module),
+        ):
+            return getattr(adapter, "_is_cached", False)
+        elif isinstance(adapter, list):  # [TRN_0,...]
+            assert isinstance(adapter[0], torch.nn.Module)
+            return getattr(adapter[0], "_is_cached", False)
+        else:
+            raise TypeError(f"Can't check this type: {adapter}!")
+
+    @classmethod
+    def flatten(cls, attr: List[List[Any]]):
+        if isinstance(attr, list):
+            if not isinstance(attr[0], list):
+                return attr
+            flatten_attr = []
+            for i in range(len(attr)):
+                flatten_attr.extend(attr[i])
+            return flatten_attr
+        return attr
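`normalize()` settles every field into the same per-transformer list-of-lists layout and stamps `_is_normalized`, so downstream code can rely on one shape; `flatten()` is the matching helper for walking all blocks lists regardless of nesting depth. A quick sketch of its behavior, derived from the code above:

```python
BlockAdapter.flatten([["a", "b"], ["c"]])  # -> ["a", "b", "c"]
BlockAdapter.flatten(["a", "b"])           # -> ["a", "b"] (pass-through)
# Note: an empty list would hit attr[0] and raise IndexError; callers
# are expected to pass non-empty attributes.
```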