cache-dit 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic; see the registry's advisory page for more details.

Files changed (43)
  1. cache_dit/__init__.py +12 -0
  2. cache_dit/_version.py +16 -3
  3. cache_dit/cache_factory/.gitignore +2 -0
  4. cache_dit/cache_factory/__init__.py +52 -2
  5. cache_dit/cache_factory/cache_adapters.py +654 -0
  6. cache_dit/cache_factory/cache_blocks.py +487 -0
  7. cache_dit/cache_factory/{dual_block_cache/cache_context.py → cache_context.py} +11 -862
  8. cache_dit/cache_factory/patch/flux.py +249 -0
  9. cache_dit/cache_factory/utils.py +1 -1
  10. cache_dit/compile/__init__.py +1 -1
  11. cache_dit/compile/utils.py +1 -1
  12. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/METADATA +87 -204
  13. cache_dit-0.2.17.dist-info/RECORD +30 -0
  14. cache_dit/cache_factory/adapters.py +0 -169
  15. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -55
  16. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -87
  17. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -98
  18. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -294
  19. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -87
  20. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py +0 -88
  21. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +0 -97
  22. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  23. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -51
  24. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -87
  25. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -98
  26. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -294
  27. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -87
  28. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -97
  29. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +0 -1005
  30. cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  31. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  32. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  33. cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -89
  34. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  35. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  36. cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -89
  37. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  38. cache_dit-0.2.15.dist-info/RECORD +0 -50
  39. cache_dit/cache_factory/{dual_block_cache → patch}/__init__.py +0 -0
  40. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/WHEEL +0 -0
  41. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/entry_points.txt +0 -0
  42. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/licenses/LICENSE +0 -0
  43. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/top_level.txt +0 -0
@@ -1,98 +0,0 @@
1
- # Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/wan.py
2
-
3
- import functools
4
- import unittest
5
-
6
- import torch
7
- from diffusers import DiffusionPipeline, WanTransformer3DModel
8
-
9
- from cache_dit.cache_factory.first_block_cache import cache_context
10
-
11
-
12
- def apply_cache_on_transformer(
13
- transformer: WanTransformer3DModel,
14
- ):
15
- if getattr(transformer, "_is_cached", False):
16
- return transformer
17
-
18
- blocks = torch.nn.ModuleList(
19
- [
20
- cache_context.CachedTransformerBlocks(
21
- transformer.blocks,
22
- transformer=transformer,
23
- return_hidden_states_only=True,
24
- )
25
- ]
26
- )
27
-
28
- original_forward = transformer.forward
29
-
30
- @functools.wraps(transformer.__class__.forward)
31
- def new_forward(
32
- self,
33
- *args,
34
- **kwargs,
35
- ):
36
- with unittest.mock.patch.object(
37
- self,
38
- "blocks",
39
- blocks,
40
- ):
41
- return original_forward(
42
- *args,
43
- **kwargs,
44
- )
45
-
46
- transformer.forward = new_forward.__get__(transformer)
47
-
48
- transformer._is_cached = True
49
-
50
- return transformer
51
-
52
-
53
- def apply_cache_on_pipe(
54
- pipe: DiffusionPipeline,
55
- *,
56
- shallow_patch: bool = False,
57
- residual_diff_threshold=0.03,
58
- downsample_factor=1,
59
- slg_layers=None,
60
- slg_start: float = 0.0,
61
- slg_end: float = 0.1,
62
- warmup_steps=0,
63
- max_cached_steps=-1,
64
- **kwargs,
65
- ):
66
- cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
67
- default_attrs={
68
- "residual_diff_threshold": residual_diff_threshold,
69
- "downsample_factor": downsample_factor,
70
- "enable_alter_cache": True,
71
- "slg_layers": slg_layers,
72
- "slg_start": slg_start,
73
- "slg_end": slg_end,
74
- "num_inference_steps": kwargs.get("num_inference_steps", 50),
75
- "warmup_steps": warmup_steps,
76
- "max_cached_steps": max_cached_steps,
77
- },
78
- **kwargs,
79
- )
80
- if not getattr(pipe, "_is_cached", False):
81
- original_call = pipe.__class__.__call__
82
-
83
- @functools.wraps(original_call)
84
- def new_call(self, *args, **kwargs):
85
- with cache_context.cache_context(
86
- cache_context.create_cache_context(
87
- **cache_kwargs,
88
- )
89
- ):
90
- return original_call(self, *args, **kwargs)
91
-
92
- pipe.__class__.__call__ = new_call
93
- pipe.__class__._is_cached = True
94
-
95
- if not shallow_patch:
96
- apply_cache_on_transformer(pipe.transformer, **kwargs)
97
-
98
- return pipe
@@ -1,50 +0,0 @@
1
- cache_dit/__init__.py,sha256=0-B173-fLi3IA8nJXoS71zK0zD33Xplysd9skmLfEOY,171
2
- cache_dit/_version.py,sha256=TspyaLI34cMH7HNr-lrUKFnY4CF39yGfG_421vC3fhg,513
3
- cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
- cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
- cache_dit/utils.py,sha256=4cFNh0asch6Zgsixq0bS1ElfwBu_6BG5ZSmaa1khjyg,144
6
- cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
7
- cache_dit/cache_factory/adapters.py,sha256=QMCaXnmqM7NT7sx4bCF1mMLn-QcXX9h1RmgLAypDedg,5256
8
- cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
9
- cache_dit/cache_factory/utils.py,sha256=V-Mb5Jn07geEUUWo4QAfh6pmSzkL-2OGDn0VAXbG6hQ,1799
10
- cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=sJ9yxQlcrX4qkPln94FrL0WDe2WIn3_UD2-Mk8YtjSw,73301
12
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=uSqF5aD2-feHB25vEbx1STBQVjWVAOn_wYTdAEmS4NU,2045
13
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=3xUjvDzor9AkBkDUc0N7kZqM86MIdajuigesnicNzXE,2260
14
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=cIsov6Pf0dRyddqkzTA2CU-jSDotof8LQr-HIoY9T9M,2615
15
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=SO4q39PQuQ5QVHy5Z-ubiKdstzvQPedONN2J5oiGUh0,9955
16
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=8W9m-WeEVE2ytYi9udKEA8Wtb0EnvP3eT2A1Tu-d29k,2252
17
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py,sha256=ZIjB4GFIL0_xY8FkXdOJJ9-Xcft54rnBCrz43VWZLi0,2296
18
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=EH2PpYFJ76KQLb35za6T-nM8Q10owLNatv6cd480ydE,2584
19
- cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=1qarKAsEFiaaN2_ghko2dqGz_R7BTQSOyGtb_eQq38Y,35716
21
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=hVBTXj9MMGFGVezT3j8MntFRBiphSaUL4YhSOd8JtuY,1870
22
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=KP8NxtHAKzzBOoX0lhvlMgY_5dmP4Z3T5TOfwl4SSyg,2273
23
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=kCB7lL4OIq8TZn-baMIF8D_PVPTFW60omCMVQCb8ebs,2628
24
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,sha256=xAkd40BGsfuCKdW3Abrx35VwgZQg4CZFz13P4VY71eY,9968
25
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=zXgoRDDjus3a2WSjtNh4ERtQp20ceb6nzohHMDlo2zY,2265
26
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=PA7nuLgfAelnaI8usQx0Kxi8XATzMapyR1WndEdFoZA,2604
27
- cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
- cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=qn4zWJ_eEMIPYzrxXoslunxbzK0WueuNtC54Pp5Q57k,23241
29
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=-FFgA2MoudEo7uDacg4aWgm1KwfLZFsEDTVxatgbq9M,2146
30
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
31
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
32
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sha256=OL7W4ukYlZz0IDmBR1zVV6XT3Mgciglj9Hqzv1wUAkQ,10092
33
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
34
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
35
- cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k,63
36
- cache_dit/compile/utils.py,sha256=N4A55_8uIbEd-S4xyJPcrdKceI2MGM9BTIhJE63jyL4,3786
37
- cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
- cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
- cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
40
- cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
41
- cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
42
- cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
43
- cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
44
- cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
45
- cache_dit-0.2.15.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
46
- cache_dit-0.2.15.dist-info/METADATA,sha256=7cxw9ZaYpCFbtA1iDt-CWhk07mDXuRPECxfuO-wB0IE,25153
47
- cache_dit-0.2.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
48
- cache_dit-0.2.15.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
49
- cache_dit-0.2.15.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
50
- cache_dit-0.2.15.dist-info/RECORD,,