cache_dit-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic.

Files changed (31)
  1. cache_dit/__init__.py +0 -0
  2. cache_dit/_version.py +21 -0
  3. cache_dit/cache_factory/__init__.py +166 -0
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
  6. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
  10. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  11. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
  12. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
  13. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
  14. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
  15. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
  16. cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  17. cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
  18. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
  19. cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
  20. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
  21. cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
  22. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
  23. cache_dit/cache_factory/taylorseer.py +76 -0
  24. cache_dit/cache_factory/utils.py +0 -0
  25. cache_dit/logger.py +97 -0
  26. cache_dit/primitives.py +152 -0
  27. cache_dit-0.1.0.dist-info/METADATA +350 -0
  28. cache_dit-0.1.0.dist-info/RECORD +31 -0
  29. cache_dit-0.1.0.dist-info/WHEEL +5 -0
  30. cache_dit-0.1.0.dist-info/licenses/LICENSE +53 -0
  31. cache_dit-0.1.0.dist-info/top_level.txt +1 -0

cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py
@@ -0,0 +1,45 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/__init__.py
+
+import importlib
+
+from diffusers import DiffusionPipeline
+
+
+def apply_db_cache_on_transformer(transformer, *args, **kwargs):
+    transformer_cls_name: str = transformer.__class__.__name__
+    if transformer_cls_name.startswith("Flux"):
+        adapter_name = "flux"
+    elif transformer_cls_name.startswith("Mochi"):
+        adapter_name = "mochi"
+    elif transformer_cls_name.startswith("CogVideoX"):
+        adapter_name = "cogvideox"
+    else:
+        raise ValueError(
+            f"Unknown transformer class name: {transformer_cls_name}"
+        )
+
+    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+    apply_db_cache_on_transformer_fn = getattr(
+        adapter_module, "apply_db_cache_on_transformer"
+    )
+    return apply_db_cache_on_transformer_fn(transformer, *args, **kwargs)
+
+
+def apply_db_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
+    assert isinstance(pipe, DiffusionPipeline)
+
+    pipe_cls_name: str = pipe.__class__.__name__
+    if pipe_cls_name.startswith("Flux"):
+        adapter_name = "flux"
+    elif pipe_cls_name.startswith("Mochi"):
+        adapter_name = "mochi"
+    elif pipe_cls_name.startswith("CogVideoX"):
+        adapter_name = "cogvideox"
+    else:
+        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
+
+    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+    apply_db_cache_on_pipe_fn = getattr(
+        adapter_module, "apply_db_cache_on_pipe"
+    )
+    return apply_db_cache_on_pipe_fn(pipe, *args, **kwargs)
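
For orientation, a minimal usage sketch of this dispatch entry point follows. It is not part of the package diff; the checkpoint id, dtype, and prompt are illustrative assumptions, and keyword arguments are simply forwarded to the per-model adapter (flux.py in this case).

    import torch
    from diffusers import FluxPipeline

    from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
        apply_db_cache_on_pipe,
    )

    # Illustrative checkpoint; any pipeline whose class name starts with
    # "Flux", "Mochi", or "CogVideoX" is routed to the matching adapter.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Keyword arguments are forwarded to flux.apply_db_cache_on_pipe.
    apply_db_cache_on_pipe(pipe, residual_diff_threshold=0.05)

    image = pipe("a photo of a cat", num_inference_steps=28).images[0]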

cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py
@@ -0,0 +1,89 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/cogvideox.py
+
+import functools
+import unittest
+
+import torch
+from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline
+
+from cache_dit.cache_factory.dual_block_cache import cache_context
+
+
+def apply_db_cache_on_transformer(
+    transformer: CogVideoXTransformer3DModel,
+):
+    if getattr(transformer, "_is_cached", False):
+        return transformer
+
+    cached_transformer_blocks = torch.nn.ModuleList(
+        [
+            cache_context.DBCachedTransformerBlocks(
+                transformer.transformer_blocks,
+                transformer=transformer,
+            )
+        ]
+    )
+
+    original_forward = transformer.forward
+
+    @functools.wraps(transformer.__class__.forward)
+    def new_forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        with unittest.mock.patch.object(
+            self,
+            "transformer_blocks",
+            cached_transformer_blocks,
+        ):
+            return original_forward(
+                *args,
+                **kwargs,
+            )
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    transformer._is_cached = True
+
+    return transformer
+
+
+def apply_db_cache_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    residual_diff_threshold=0.04,
+    downsample_factor=1,
+    warmup_steps=0,
+    max_cached_steps=-1,
+    **kwargs,
+):
+    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
+        default_attrs={
+            "residual_diff_threshold": residual_diff_threshold,
+            "downsample_factor": downsample_factor,
+            "warmup_steps": warmup_steps,
+            "max_cached_steps": max_cached_steps,
+        },
+        **kwargs,
+    )
+    if not getattr(pipe, "_is_cached", False):
+        original_call = pipe.__class__.__call__
+
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with cache_context.cache_context(
+                cache_context.create_cache_context(
+                    **cache_kwargs,
+                )
+            ):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        pipe.__class__._is_cached = True
+
+    if not shallow_patch:
+        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+
+    return pipe
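
The per-instance patch above hinges on descriptor binding: `new_forward.__get__(transformer)` turns the plain function into a method bound to that one object, so the class and all other instances keep their original `forward`. A standalone sketch of the same pattern, with illustrative names:

    class Model:
        def forward(self, x):
            return x + 1

    m = Model()
    original_forward = m.forward  # bound method, already tied to m

    def new_forward(self, x):
        # Wrap the original call; caching logic would go here.
        return original_forward(x) * 2

    # Bind the wrapper to this instance only.
    m.forward = new_forward.__get__(m)

    assert m.forward(1) == 4        # patched instance
    assert Model().forward(1) == 2  # other instances are untouched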

cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py
@@ -0,0 +1,100 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/flux.py
+
+import functools
+import unittest
+
+import torch
+from diffusers import DiffusionPipeline, FluxTransformer2DModel
+
+from cache_dit.cache_factory.dual_block_cache import cache_context
+
+
+def apply_db_cache_on_transformer(
+    transformer: FluxTransformer2DModel,
+):
+    if getattr(transformer, "_is_cached", False):
+        return transformer
+
+    cached_transformer_blocks = torch.nn.ModuleList(
+        [
+            cache_context.DBCachedTransformerBlocks(
+                transformer.transformer_blocks,
+                transformer.single_transformer_blocks,
+                transformer=transformer,
+                return_hidden_states_first=False,
+            )
+        ]
+    )
+    dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+    original_forward = transformer.forward
+
+    @functools.wraps(original_forward)
+    def new_forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        with (
+            unittest.mock.patch.object(
+                self,
+                "transformer_blocks",
+                cached_transformer_blocks,
+            ),
+            unittest.mock.patch.object(
+                self,
+                "single_transformer_blocks",
+                dummy_single_transformer_blocks,
+            ),
+        ):
+            return original_forward(
+                *args,
+                **kwargs,
+            )
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    transformer._is_cached = True
+
+    return transformer
+
+
+def apply_db_cache_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    residual_diff_threshold=0.05,
+    downsample_factor=1,
+    warmup_steps=0,
+    max_cached_steps=-1,
+    **kwargs,
+):
+    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
+        default_attrs={
+            "residual_diff_threshold": residual_diff_threshold,
+            "downsample_factor": downsample_factor,
+            "warmup_steps": warmup_steps,
+            "max_cached_steps": max_cached_steps,
+        },
+        **kwargs,
+    )
+
+    if not getattr(pipe, "_is_cached", False):
+        original_call = pipe.__class__.__call__
+
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with cache_context.cache_context(
+                cache_context.create_cache_context(
+                    **cache_kwargs,
+                )
+            ):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        pipe.__class__._is_cached = True
+
+    if not shallow_patch:
+        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+
+    return pipe
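
Continuing the earlier Flux sketch: `shallow_patch=True` installs only the cache context around the pipeline's `__call__` and skips the transformer patch, which the caller can then apply explicitly (for example after the transformer has been wrapped by other tooling). The sequence below is illustrative, not taken from the package:

    from cache_dit.cache_factory.dual_block_cache.diffusers_adapters.flux import (
        apply_db_cache_on_pipe,
        apply_db_cache_on_transformer,
    )

    # Patch the pipeline call only, then patch the transformer by hand.
    apply_db_cache_on_pipe(pipe, shallow_patch=True, residual_diff_threshold=0.06)
    apply_db_cache_on_transformer(pipe.transformer)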

cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py
@@ -0,0 +1,88 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/mochi.py
+import functools
+import unittest
+
+import torch
+from diffusers import DiffusionPipeline, MochiTransformer3DModel
+
+from cache_dit.cache_factory.dual_block_cache import cache_context
+
+
+def apply_db_cache_on_transformer(
+    transformer: MochiTransformer3DModel,
+):
+    if getattr(transformer, "_is_cached", False):
+        return transformer
+
+    cached_transformer_blocks = torch.nn.ModuleList(
+        [
+            cache_context.DBCachedTransformerBlocks(
+                transformer.transformer_blocks,
+                transformer=transformer,
+            )
+        ]
+    )
+
+    original_forward = transformer.forward
+
+    @functools.wraps(transformer.__class__.forward)
+    def new_forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        with unittest.mock.patch.object(
+            self,
+            "transformer_blocks",
+            cached_transformer_blocks,
+        ):
+            return original_forward(
+                *args,
+                **kwargs,
+            )
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    transformer._is_cached = True
+
+    return transformer
+
+
+def apply_db_cache_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    residual_diff_threshold=0.06,
+    downsample_factor=1,
+    warmup_steps=0,
+    max_cached_steps=-1,
+    **kwargs,
+):
+    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
+        default_attrs={
+            "residual_diff_threshold": residual_diff_threshold,
+            "downsample_factor": downsample_factor,
+            "warmup_steps": warmup_steps,
+            "max_cached_steps": max_cached_steps,
+        },
+        **kwargs,
+    )
+    if not getattr(pipe, "_is_cached", False):
+        original_call = pipe.__class__.__call__
+
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with cache_context.cache_context(
+                cache_context.create_cache_context(
+                    **cache_kwargs,
+                )
+            ):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        pipe.__class__._is_cached = True
+
+    if not shallow_patch:
+        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+
+    return pipe
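
All of the transformer adapters above share one trick: `unittest.mock.patch.object` swaps `transformer_blocks` for the cached wrapper only while `forward` is running, so `state_dict()`, `.to()`, and other traversals of the module tree still see the original blocks. A minimal standalone sketch of that pattern, with illustrative names:

    import unittest.mock

    import torch

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = torch.nn.ModuleList([torch.nn.Identity()])

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    model = Tiny()
    replacement = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    # The swap is visible only inside the `with` block.
    with unittest.mock.patch.object(model, "blocks", replacement):
        y = model(torch.randn(2, 4))  # runs through the replacement blocks

    assert isinstance(model.blocks[0], torch.nn.Identity)  # restored afterwards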

cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py
@@ -0,0 +1,45 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/__init__.py
+
+import importlib
+
+from diffusers import DiffusionPipeline
+
+
+def apply_db_prune_on_transformer(transformer, *args, **kwargs):
+    transformer_cls_name: str = transformer.__class__.__name__
+    if transformer_cls_name.startswith("Flux"):
+        adapter_name = "flux"
+    elif transformer_cls_name.startswith("Mochi"):
+        adapter_name = "mochi"
+    elif transformer_cls_name.startswith("CogVideoX"):
+        adapter_name = "cogvideox"
+    else:
+        raise ValueError(
+            f"Unknown transformer class name: {transformer_cls_name}"
+        )
+
+    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+    apply_db_cache_on_transformer_fn = getattr(
+        adapter_module, "apply_db_prune_on_transformer"
+    )
+    return apply_db_cache_on_transformer_fn(transformer, *args, **kwargs)
+
+
+def apply_db_prune_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
+    assert isinstance(pipe, DiffusionPipeline)
+
+    pipe_cls_name: str = pipe.__class__.__name__
+    if pipe_cls_name.startswith("Flux"):
+        adapter_name = "flux"
+    elif pipe_cls_name.startswith("Mochi"):
+        adapter_name = "mochi"
+    elif pipe_cls_name.startswith("CogVideoX"):
+        adapter_name = "cogvideox"
+    else:
+        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
+
+    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+    apply_db_cache_on_pipe_fn = getattr(
+        adapter_module, "apply_db_prune_on_pipe"
+    )
+    return apply_db_cache_on_pipe_fn(pipe, *args, **kwargs)
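
As with the caching dispatcher, a minimal usage sketch; the CogVideoX checkpoint, dtype, and prompt are illustrative assumptions, not taken from the package:

    import torch
    from diffusers import CogVideoXPipeline

    from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
        apply_db_prune_on_pipe,
    )

    # Illustrative checkpoint; dispatch keys off the pipeline class name.
    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
    ).to("cuda")

    apply_db_prune_on_pipe(pipe, residual_diff_threshold=0.04, warmup_steps=2)

    video = pipe("a panda playing guitar", num_inference_steps=50).frames[0]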

cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py
@@ -0,0 +1,89 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/cogvideox.py
+
+import functools
+import unittest
+
+import torch
+from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline
+
+from cache_dit.cache_factory.dynamic_block_prune import prune_context
+
+
+def apply_db_prune_on_transformer(
+    transformer: CogVideoXTransformer3DModel,
+):
+    if getattr(transformer, "_is_pruned", False):
+        return transformer
+
+    cached_transformer_blocks = torch.nn.ModuleList(
+        [
+            prune_context.DBPrunedTransformerBlocks(
+                transformer.transformer_blocks,
+                transformer=transformer,
+            )
+        ]
+    )
+
+    original_forward = transformer.forward
+
+    @functools.wraps(transformer.__class__.forward)
+    def new_forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        with unittest.mock.patch.object(
+            self,
+            "transformer_blocks",
+            cached_transformer_blocks,
+        ):
+            return original_forward(
+                *args,
+                **kwargs,
+            )
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    transformer._is_pruned = True
+
+    return transformer
+
+
+def apply_db_prune_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    residual_diff_threshold=0.04,
+    downsample_factor=1,
+    warmup_steps=0,
+    max_pruned_steps=-1,
+    **kwargs,
+):
+    cache_kwargs, kwargs = prune_context.collect_prune_kwargs(
+        default_attrs={
+            "residual_diff_threshold": residual_diff_threshold,
+            "downsample_factor": downsample_factor,
+            "warmup_steps": warmup_steps,
+            "max_pruned_steps": max_pruned_steps,
+        },
+        **kwargs,
+    )
+    if not getattr(pipe, "_is_pruned", False):
+        original_call = pipe.__class__.__call__
+
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with prune_context.prune_context(
+                prune_context.create_prune_context(
+                    **cache_kwargs,
+                )
+            ):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        pipe.__class__._is_pruned = True
+
+    if not shallow_patch:
+        apply_db_prune_on_transformer(pipe.transformer, **kwargs)
+
+    return pipe

cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py
@@ -0,0 +1,100 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/flux.py
+
+import functools
+import unittest
+
+import torch
+from diffusers import DiffusionPipeline, FluxTransformer2DModel
+
+from cache_dit.cache_factory.dynamic_block_prune import prune_context
+
+
+def apply_db_prune_on_transformer(
+    transformer: FluxTransformer2DModel,
+):
+    if getattr(transformer, "_is_pruned", False):
+        return transformer
+
+    cached_transformer_blocks = torch.nn.ModuleList(
+        [
+            prune_context.DBPrunedTransformerBlocks(
+                transformer.transformer_blocks,
+                transformer.single_transformer_blocks,
+                transformer=transformer,
+                return_hidden_states_first=False,
+            )
+        ]
+    )
+    dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+    original_forward = transformer.forward
+
+    @functools.wraps(original_forward)
+    def new_forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        with (
+            unittest.mock.patch.object(
+                self,
+                "transformer_blocks",
+                cached_transformer_blocks,
+            ),
+            unittest.mock.patch.object(
+                self,
+                "single_transformer_blocks",
+                dummy_single_transformer_blocks,
+            ),
+        ):
+            return original_forward(
+                *args,
+                **kwargs,
+            )
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    transformer._is_pruned = True
+
+    return transformer
+
+
+def apply_db_prune_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    residual_diff_threshold=0.05,
+    downsample_factor=1,
+    warmup_steps=0,
+    max_pruned_steps=-1,
+    **kwargs,
+):
+    cache_kwargs, kwargs = prune_context.collect_prune_kwargs(
+        default_attrs={
+            "residual_diff_threshold": residual_diff_threshold,
+            "downsample_factor": downsample_factor,
+            "warmup_steps": warmup_steps,
+            "max_pruned_steps": max_pruned_steps,
+        },
+        **kwargs,
+    )
+
+    if not getattr(pipe, "_is_pruned", False):
+        original_call = pipe.__class__.__call__
+
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with prune_context.prune_context(
+                prune_context.create_prune_context(
+                    **cache_kwargs,
+                )
+            ):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        pipe.__class__._is_pruned = True
+
+    if not shallow_patch:
+        apply_db_prune_on_transformer(pipe.transformer, **kwargs)
+
+    return pipe
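
The `_is_pruned` and `_is_cached` guards make these adapters idempotent, so they can be called defensively from more than one code path. A small sketch, assuming a Flux pipeline `pipe` as in the earlier examples:

    from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters.flux import (
        apply_db_prune_on_pipe,
    )

    # The second call is a no-op: the pipeline and its transformer are
    # only patched once.
    apply_db_prune_on_pipe(pipe, residual_diff_threshold=0.05)
    apply_db_prune_on_pipe(pipe, residual_diff_threshold=0.05)
    assert pipe._is_pruned and pipe.transformer._is_pruned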

cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py
@@ -0,0 +1,89 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/mochi.py
+
+import functools
+import unittest
+
+import torch
+from diffusers import DiffusionPipeline, MochiTransformer3DModel
+
+from cache_dit.cache_factory.dynamic_block_prune import prune_context
+
+
+def apply_db_prune_on_transformer(
+    transformer: MochiTransformer3DModel,
+):
+    if getattr(transformer, "_is_pruned", False):
+        return transformer
+
+    cached_transformer_blocks = torch.nn.ModuleList(
+        [
+            prune_context.DBPrunedTransformerBlocks(
+                transformer.transformer_blocks,
+                transformer=transformer,
+            )
+        ]
+    )
+
+    original_forward = transformer.forward
+
+    @functools.wraps(transformer.__class__.forward)
+    def new_forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        with unittest.mock.patch.object(
+            self,
+            "transformer_blocks",
+            cached_transformer_blocks,
+        ):
+            return original_forward(
+                *args,
+                **kwargs,
+            )
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    transformer._is_pruned = True
+
+    return transformer
+
+
+def apply_db_prune_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    residual_diff_threshold=0.06,
+    downsample_factor=1,
+    warmup_steps=0,
+    max_pruned_steps=-1,
+    **kwargs,
+):
+    cache_kwargs, kwargs = prune_context.collect_prune_kwargs(
+        default_attrs={
+            "residual_diff_threshold": residual_diff_threshold,
+            "downsample_factor": downsample_factor,
+            "warmup_steps": warmup_steps,
+            "max_pruned_steps": max_pruned_steps,
+        },
+        **kwargs,
+    )
+    if not getattr(pipe, "_is_pruned", False):
+        original_call = pipe.__class__.__call__
+
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with prune_context.prune_context(
+                prune_context.create_prune_context(
+                    **cache_kwargs,
+                )
+            ):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        pipe.__class__._is_pruned = True
+
+    if not shallow_patch:
+        apply_db_prune_on_transformer(pipe.transformer, **kwargs)
+
+    return pipe
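
Across the eight adapters above, only the default `residual_diff_threshold` differs per model family (0.04 for CogVideoX, 0.05 for Flux, 0.06 for Mochi); `downsample_factor`, `warmup_steps`, and the step limits share the same defaults. A small helper sketch (not part of the package) that mirrors those documented defaults:

    # Defaults as they appear in the per-model adapters above.
    DEFAULT_RESIDUAL_DIFF_THRESHOLD = {
        "CogVideoX": 0.04,
        "Flux": 0.05,
        "Mochi": 0.06,
    }

    def default_threshold_for(pipe) -> float:
        name = pipe.__class__.__name__
        for prefix, value in DEFAULT_RESIDUAL_DIFF_THRESHOLD.items():
            if name.startswith(prefix):
                return value
        raise ValueError(f"Unknown pipeline class name: {name}")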