cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.
Files changed (37)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/adapters.py +47 -5
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
  6. cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
  10. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
  11. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
  12. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
  13. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
  14. cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
  15. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
  16. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
  17. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
  18. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
  19. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
  20. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
  21. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
  22. cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
  23. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
  24. cache_dit/cache_factory/patch/flux.py +241 -0
  25. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
  26. cache_dit-0.2.16.dist-info/RECORD +47 -0
  27. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  28. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  29. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  30. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  31. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  32. cache_dit-0.2.14.dist-info/RECORD +0 -49
  33. /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
  34. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
  35. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
  36. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
  37. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
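
This release moves the cached/pruned block logic out of the context modules into the new cache_blocks.py and prune_blocks.py files, adds QwenImage adapters to both the dual-block cache and dynamic block prune backends, introduces a cache_dit.cache_factory.patch.flux helper, and removes the first_block_cache backend. A minimal usage sketch for the new QwenImage dual-block cache path, assuming the dispatch entry point shown in the adapter diffs below; the checkpoint id and dtype are assumptions, not taken from this diff:

import torch
from diffusers import QwenImagePipeline

from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
    apply_db_cache_on_pipe,
)

# Assumed checkpoint id; any QwenImagePipeline checkpoint should behave the same.
pipe = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
).to("cuda")

# Dispatches on the class-name prefix "QwenImage" (see the adapter diffs below)
# and wraps the transformer blocks in DBCachedTransformerBlocks.
apply_db_cache_on_pipe(pipe, residual_diff_threshold=0.06)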
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py
@@ -15,6 +15,8 @@ def apply_db_cache_on_transformer(transformer, *args, **kwargs):
         adapter_name = "wan"
     elif transformer_cls_name.startswith("HunyuanVideo"):
         adapter_name = "hunyuan_video"
+    elif transformer_cls_name.startswith("QwenImage"):
+        adapter_name = "qwen_image"
     else:
         raise ValueError(
             f"Unknown transformer class name: {transformer_cls_name}"
@@ -41,6 +43,8 @@ def apply_db_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
         adapter_name = "wan"
     elif pipe_cls_name.startswith("HunyuanVideo"):
         adapter_name = "hunyuan_video"
+    elif pipe_cls_name.startswith("QwenImage"):
+        adapter_name = "qwen_image"
     else:
         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py
@@ -4,7 +4,10 @@ import unittest
 import torch
 from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline

-from cache_dit.cache_factory.dual_block_cache import cache_context
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)


 def apply_db_cache_on_transformer(
@@ -15,7 +18,7 @@ def apply_db_cache_on_transformer(

     cached_transformer_blocks = torch.nn.ModuleList(
         [
-            cache_context.DBCachedTransformerBlocks(
+            DBCachedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
             )
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py
@@ -4,7 +4,15 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, FluxTransformer2DModel

-from cache_dit.cache_factory.dual_block_cache import cache_context
+from cache_dit.cache_factory.patch.flux import maybe_patch_flux_transformer
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)


 def apply_db_cache_on_transformer(
@@ -13,11 +21,13 @@ def apply_db_cache_on_transformer(
     if getattr(transformer, "_is_cached", False):
         return transformer

+    transformer = maybe_patch_flux_transformer(transformer)
+
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-            cache_context.DBCachedTransformerBlocks(
-                transformer.transformer_blocks,
-                transformer.single_transformer_blocks,
+            DBCachedTransformerBlocks(
+                transformer.transformer_blocks
+                + transformer.single_transformer_blocks,
                 transformer=transformer,
                 return_hidden_states_first=False,
             )
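
Both Flux adapters in this release now pass a single flattened block list (transformer_blocks + transformer.single_transformer_blocks) to one wrapper instead of two separate arguments, after first running maybe_patch_flux_transformer. The + relies on torch.nn.ModuleList concatenation, sketched below; this is a toy illustration of the PyTorch behavior the change assumes, not cache-dit code:

import torch
from torch import nn

# nn.ModuleList.__add__ (available in recent PyTorch releases) returns a new
# ModuleList with the left operand's blocks followed by the right operand's.
double_blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
single_blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

merged = double_blocks + single_blocks
assert len(merged) == 5  # double-stream blocks first, then single-stream blocks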
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py
@@ -11,7 +11,10 @@ from diffusers.utils import (
     USE_PEFT_BACKEND,
 )

-from cache_dit.cache_factory.dual_block_cache import cache_context
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)
 from cache_dit.logger import init_logger

 try:
@@ -44,7 +47,7 @@ def apply_db_cache_on_transformer(

     cached_transformer_blocks = torch.nn.ModuleList(
         [
-            cache_context.DBCachedTransformerBlocks(
+            DBCachedTransformerBlocks(
                 transformer.transformer_blocks
                 + transformer.single_transformer_blocks,
                 transformer=transformer,
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py
@@ -4,7 +4,10 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, MochiTransformer3DModel

-from cache_dit.cache_factory.dual_block_cache import cache_context
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)


 def apply_db_cache_on_transformer(
@@ -15,7 +18,7 @@ def apply_db_cache_on_transformer(

     cached_transformer_blocks = torch.nn.ModuleList(
         [
-            cache_context.DBCachedTransformerBlocks(
+            DBCachedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
             )
cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py}
@@ -1,25 +1,27 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/mochi.py
-
 import functools
 import unittest

 import torch
-from diffusers import DiffusionPipeline, MochiTransformer3DModel
+from diffusers import QwenImagePipeline, QwenImageTransformer2DModel

-from cache_dit.cache_factory.first_block_cache import cache_context
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)


-def apply_cache_on_transformer(
-    transformer: MochiTransformer3DModel,
+def apply_db_cache_on_transformer(
+    transformer: QwenImageTransformer2DModel,
 ):
     if getattr(transformer, "_is_cached", False):
         return transformer

-    cached_transformer_blocks = torch.nn.ModuleList(
+    transformer_blocks = torch.nn.ModuleList(
         [
-            cache_context.CachedTransformerBlocks(
+            DBCachedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
+                return_hidden_states_first=False,
             )
         ]
     )
@@ -35,7 +37,7 @@ def apply_cache_on_transformer(
         with unittest.mock.patch.object(
             self,
             "transformer_blocks",
-            cached_transformer_blocks,
+            transformer_blocks,
         ):
             return original_forward(
                 *args,
@@ -49,8 +51,8 @@ def apply_cache_on_transformer(
     return transformer


-def apply_cache_on_pipe(
-    pipe: DiffusionPipeline,
+def apply_db_cache_on_pipe(
+    pipe: QwenImagePipeline,
     *,
     shallow_patch: bool = False,
     residual_diff_threshold=0.06,
@@ -84,6 +86,6 @@ def apply_cache_on_pipe(
     pipe.__class__._is_cached = True

     if not shallow_patch:
-        apply_cache_on_transformer(pipe.transformer, **kwargs)
+        apply_db_cache_on_transformer(pipe.transformer)

     return pipe
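
The new qwen_image adapter keeps the same forward-patching idiom as the other adapters: the wrapped blocks are swapped in only for the duration of each forward call. A self-contained sketch of that idiom on a toy module (not cache-dit code; the wrapper here is a stand-in for DBCachedTransformerBlocks):

import functools
import unittest.mock

import torch


class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList([torch.nn.ReLU()])

    def forward(self, x):
        for block in self.transformer_blocks:
            x = block(x)
        return x


model = ToyTransformer()
# Stand-in for the DBCachedTransformerBlocks wrapper built by the adapter.
wrapped_blocks = torch.nn.ModuleList([torch.nn.Identity()])
original_forward = model.forward


@functools.wraps(original_forward)
def new_forward(self, *args, **kwargs):
    # Temporarily replace `transformer_blocks`, exactly like the adapter does,
    # so the wrapper runs instead of the raw blocks; restored on exit.
    with unittest.mock.patch.object(self, "transformer_blocks", wrapped_blocks):
        return original_forward(*args, **kwargs)


model.forward = new_forward.__get__(model)

x = torch.tensor([-1.0])
assert torch.equal(model(x), x)  # Identity ran instead of ReLU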
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py
@@ -2,9 +2,12 @@ import functools
 import unittest

 import torch
-from diffusers import DiffusionPipeline, WanTransformer3DModel
+from diffusers import WanPipeline, WanTransformer3DModel

-from cache_dit.cache_factory.dual_block_cache import cache_context
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)


 def apply_db_cache_on_transformer(
@@ -15,7 +18,7 @@ def apply_db_cache_on_transformer(

     blocks = torch.nn.ModuleList(
         [
-            cache_context.DBCachedTransformerBlocks(
+            DBCachedTransformerBlocks(
                 transformer.blocks,
                 transformer=transformer,
                 return_hidden_states_only=True,
@@ -49,7 +52,7 @@ def apply_db_cache_on_transformer(


 def apply_db_cache_on_pipe(
-    pipe: DiffusionPipeline,
+    pipe: WanPipeline,
     *,
     shallow_patch: bool = False,
     residual_diff_threshold=0.03,
cache_dit/cache_factory/dynamic_block_prune/__init__.py
@@ -0,0 +1,4 @@
+from cache_dit.cache_factory.dynamic_block_prune.prune_blocks import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)
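
This new package-level __init__ is what lets the adapters below shorten their imports: prune_context and DBPrunedTransformerBlocks can now be imported from the dynamic_block_prune package directly. Assuming the re-export shown above, both spellings resolve to the same object:

from cache_dit.cache_factory.dynamic_block_prune import DBPrunedTransformerBlocks
from cache_dit.cache_factory.dynamic_block_prune.prune_blocks import (
    DBPrunedTransformerBlocks as FromPruneBlocks,
)

assert DBPrunedTransformerBlocks is FromPruneBlocks  # same class, re-exported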
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py
@@ -15,6 +15,8 @@ def apply_db_prune_on_transformer(transformer, *args, **kwargs):
         adapter_name = "wan"
     elif transformer_cls_name.startswith("HunyuanVideo"):
         adapter_name = "hunyuan_video"
+    elif transformer_cls_name.startswith("QwenImage"):
+        adapter_name = "qwen_image"
     else:
         raise ValueError(
             f"Unknown transformer class name: {transformer_cls_name}"
@@ -41,6 +43,8 @@ def apply_db_prune_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
         adapter_name = "wan"
     elif pipe_cls_name.startswith("HunyuanVideo"):
         adapter_name = "hunyuan_video"
+    elif pipe_cls_name.startswith("QwenImage"):
+        adapter_name = "qwen_image"
     else:
         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py
@@ -4,7 +4,10 @@ import unittest
 import torch
 from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline

-from cache_dit.cache_factory.dynamic_block_prune import prune_context
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)


 def apply_db_prune_on_transformer(
@@ -15,7 +18,7 @@ def apply_db_prune_on_transformer(

     cached_transformer_blocks = torch.nn.ModuleList(
         [
-            prune_context.DBPrunedTransformerBlocks(
+            DBPrunedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
             )
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py
@@ -4,7 +4,11 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, FluxTransformer2DModel

-from cache_dit.cache_factory.dynamic_block_prune import prune_context
+from cache_dit.cache_factory.patch.flux import maybe_patch_flux_transformer
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)


 def apply_db_prune_on_transformer(
@@ -13,11 +17,13 @@ def apply_db_prune_on_transformer(
     if getattr(transformer, "_is_pruned", False):
         return transformer

+    transformer = maybe_patch_flux_transformer(transformer)
+
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-            prune_context.DBPrunedTransformerBlocks(
-                transformer.transformer_blocks,
-                transformer.single_transformer_blocks,
+            DBPrunedTransformerBlocks(
+                transformer.transformer_blocks
+                + transformer.single_transformer_blocks,
                 transformer=transformer,
                 return_hidden_states_first=False,
             )
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py
@@ -11,7 +11,10 @@ from diffusers.utils import (
     USE_PEFT_BACKEND,
 )

-from cache_dit.cache_factory.dynamic_block_prune import prune_context
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)
 from cache_dit.logger import init_logger

 try:
@@ -44,7 +47,7 @@ def apply_db_prune_on_transformer(

     cached_transformer_blocks = torch.nn.ModuleList(
         [
-            prune_context.DBPrunedTransformerBlocks(
+            DBPrunedTransformerBlocks(
                 transformer.transformer_blocks
                 + transformer.single_transformer_blocks,
                 transformer=transformer,
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py
@@ -4,7 +4,10 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, MochiTransformer3DModel

-from cache_dit.cache_factory.dynamic_block_prune import prune_context
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)


 def apply_db_prune_on_transformer(
@@ -15,7 +18,7 @@ def apply_db_prune_on_transformer(

     cached_transformer_blocks = torch.nn.ModuleList(
         [
-            prune_context.DBPrunedTransformerBlocks(
+            DBPrunedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
             )
cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py}
@@ -1,25 +1,30 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/cogvideox.py
-
 import functools
 import unittest

 import torch
-from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline
+from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
+
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)
+from cache_dit.logger import init_logger

-from cache_dit.cache_factory.first_block_cache import cache_context
+logger = init_logger(__name__)


-def apply_cache_on_transformer(
-    transformer: CogVideoXTransformer3DModel,
+def apply_db_prune_on_transformer(
+    transformer: QwenImageTransformer2DModel,
 ):
-    if getattr(transformer, "_is_cached", False):
+    if getattr(transformer, "_is_pruned", False):
         return transformer

-    cached_transformer_blocks = torch.nn.ModuleList(
+    transformer_blocks = torch.nn.ModuleList(
         [
-            cache_context.CachedTransformerBlocks(
+            DBPrunedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
+                return_hidden_states_first=False,
             )
         ]
     )
@@ -35,7 +40,7 @@ def apply_cache_on_transformer(
         with unittest.mock.patch.object(
             self,
             "transformer_blocks",
-            cached_transformer_blocks,
+            transformer_blocks,
         ):
             return original_forward(
                 *args,
@@ -44,46 +49,46 @@ def apply_cache_on_transformer(

     transformer.forward = new_forward.__get__(transformer)

-    transformer._is_cached = True
+    transformer._is_pruned = True

     return transformer


-def apply_cache_on_pipe(
-    pipe: DiffusionPipeline,
+def apply_db_prune_on_pipe(
+    pipe: QwenImagePipeline,
     *,
     shallow_patch: bool = False,
-    residual_diff_threshold=0.04,
+    residual_diff_threshold=0.06,
     downsample_factor=1,
     warmup_steps=0,
-    max_cached_steps=-1,
+    max_pruned_steps=-1,
     **kwargs,
 ):
-    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
+    prune_kwargs, kwargs = prune_context.collect_prune_kwargs(
         default_attrs={
             "residual_diff_threshold": residual_diff_threshold,
             "downsample_factor": downsample_factor,
             "warmup_steps": warmup_steps,
-            "max_cached_steps": max_cached_steps,
+            "max_pruned_steps": max_pruned_steps,
         },
         **kwargs,
     )
-    if not getattr(pipe, "_is_cached", False):
+    if not getattr(pipe, "_is_pruned", False):
         original_call = pipe.__class__.__call__

         @functools.wraps(original_call)
         def new_call(self, *args, **kwargs):
-            with cache_context.cache_context(
-                cache_context.create_cache_context(
-                    **cache_kwargs,
+            with prune_context.prune_context(
+                prune_context.create_prune_context(
+                    **prune_kwargs,
                 )
             ):
                 return original_call(self, *args, **kwargs)

         pipe.__class__.__call__ = new_call
-        pipe.__class__._is_cached = True
+        pipe.__class__._is_pruned = True

     if not shallow_patch:
-        apply_cache_on_transformer(pipe.transformer, **kwargs)
+        apply_db_prune_on_transformer(pipe.transformer)

     return pipe
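
A matching usage sketch for the new prune path; the entry point, parameter names, and defaults come from the hunks above, while the checkpoint id, dtype, and prompt are assumptions:

import torch
from diffusers import QwenImagePipeline

from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
    apply_db_prune_on_pipe,
)

pipe = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16  # assumed checkpoint id
).to("cuda")

apply_db_prune_on_pipe(
    pipe,
    residual_diff_threshold=0.06,  # the adapter's new default
    warmup_steps=0,
    max_pruned_steps=-1,  # no cap on pruned steps
)

image = pipe("a cup of coffee on a wooden desk").images[0]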
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py
@@ -4,7 +4,10 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, WanTransformer3DModel

-from cache_dit.cache_factory.dynamic_block_prune import prune_context
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)


 def apply_db_prune_on_transformer(
@@ -15,7 +18,7 @@ def apply_db_prune_on_transformer(

     blocks = torch.nn.ModuleList(
         [
-            prune_context.DBPrunedTransformerBlocks(
+            DBPrunedTransformerBlocks(
                 transformer.blocks,
                 transformer=transformer,
                 return_hidden_states_only=True,