cache_dit-0.1.1.dev2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic.

Files changed (30)
  1. cache_dit/__init__.py +0 -0
  2. cache_dit/_version.py +21 -0
  3. cache_dit/cache_factory/__init__.py +166 -0
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
  6. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
  10. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  11. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
  12. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
  13. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
  14. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
  15. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
  16. cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  17. cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
  18. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
  19. cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
  20. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
  21. cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
  22. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
  23. cache_dit/cache_factory/taylorseer.py +76 -0
  24. cache_dit/cache_factory/utils.py +0 -0
  25. cache_dit/logger.py +97 -0
  26. cache_dit/primitives.py +154 -0
  27. cache_dit-0.1.1.dev2.dist-info/METADATA +31 -0
  28. cache_dit-0.1.1.dev2.dist-info/RECORD +30 -0
  29. cache_dit-0.1.1.dev2.dist-info/WHEEL +5 -0
  30. cache_dit-0.1.1.dev2.dist-info/top_level.txt +1 -0
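The list above covers the package's three acceleration strategies (dual_block_cache, i.e. DBCache; dynamic_block_prune; and first_block_cache), each wired to specific diffusers pipelines through small adapter modules (CogVideoX, Flux, Mochi, plus Wan for first-block cache). Going only by this module layout and the project README, applying a cache to a pipeline presumably looks like the sketch below; `apply_cache_on_pipe`, `CacheType`, and `CacheType.default_options` are inferred names and have not been verified against this exact release:

    # Hedged sketch: the cache_factory API names below are assumptions
    # inferred from the module layout and upstream README, not verified.
    import torch
    from diffusers import FluxPipeline

    from cache_dit.cache_factory import CacheType, apply_cache_on_pipe

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Patch the pipeline's transformer in place with DBCache defaults.
    apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))

    image = pipe("a lone lighthouse at dusk", num_inference_steps=28).images[0]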
cache_dit/primitives.py
@@ -0,0 +1,154 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/primitives.py
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+
+if dist.is_available():
+    import torch.distributed._functional_collectives as ft_c
+    import torch.distributed.distributed_c10d as c10d
+else:
+    ft_c = None
+    c10d = None
+
+
+def get_group(group=None):
+    if group is None:
+        group = c10d._get_default_group()
+
+    if isinstance(group, dist.ProcessGroup):
+        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = group
+    else:
+        pg = group.get_group()
+
+    return pg
+
+
+def get_world_size(group=None):
+    pg = get_group(group)
+    return dist.get_world_size(pg)
+
+
+def get_rank(group=None):
+    pg = get_group(group)
+    return dist.get_rank(pg)
+
+
+def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    When tracing the code, the result tensor is not an AsyncCollectiveTensor,
+    so we cannot call ``wait()``.
+    """
+    if isinstance(tensor, ft_c.AsyncCollectiveTensor):
+        return tensor.wait()
+    return tensor
+
+
+def all_gather_tensor_sync(x, *args, group=None, **kwargs):
+    group = get_group(group)
+    x_shape = x.shape
+    x = x.flatten()  # gather as a 1-D buffer, restore the shape below
+    x_numel = x.numel()
+    x = ft_c.all_gather_tensor(x, *args, group=group, **kwargs)
+    x = _maybe_wait(x)
+    x_shape = list(x_shape)
+    x_shape[0] *= x.numel() // x_numel  # leading dim grows by the gather factor
+    x = x.reshape(x_shape)
+    return x
+
+
+def all_gather_tensor_autograd_sync(x, *args, group=None, **kwargs):
+    group = get_group(group)
+    x_shape = x.shape
+    x = x.flatten()
+    x_numel = x.numel()
+    x = ft_c.all_gather_tensor_autograd(x, *args, group=group, **kwargs)
+    x = _maybe_wait(x)
+    x_shape = list(x_shape)
+    x_shape[0] *= x.numel() // x_numel
+    x = x.reshape(x_shape)
+    return x
+
+
+def all_to_all_single_sync(x, *args, **kwargs):
+    x_shape = x.shape
+    x = x.flatten()
+    x = ft_c.all_to_all_single(x, *args, **kwargs)
+    x = _maybe_wait(x)
+    x = x.reshape(x_shape)
+    return x
+
+
+def all_to_all_single_autograd_sync(x, *args, **kwargs):
+    x_shape = x.shape
+    x = x.flatten()
+    x = ft_c.all_to_all_single_autograd(x, *args, **kwargs)
+    x = _maybe_wait(x)
+    x = x.reshape(x_shape)
+    return x
+
+
+def all_reduce_sync(x, *args, group=None, **kwargs):
+    group = get_group(group)
+    x = ft_c.all_reduce(x, *args, group=group, **kwargs)
+    x = _maybe_wait(x)
+    return x
+
+
+def get_buffer(
+    shape_or_tensor: Union[Tuple[int], torch.Tensor],
+    *,
+    repeats: int = 1,
+    dim: int = 0,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    group=None,
+) -> torch.Tensor:
+    if repeats is None:
+        repeats = get_world_size(group)
+
+    if isinstance(shape_or_tensor, torch.Tensor):
+        shape = shape_or_tensor.shape
+        dtype = shape_or_tensor.dtype
+        device = shape_or_tensor.device
+    else:
+        shape = shape_or_tensor  # a plain shape tuple/list was passed
+
+    assert dtype is not None
+    assert device is not None
+
+    shape = list(shape)
+    if repeats > 1:
+        shape[dim] *= repeats
+
+    buffer = torch.empty(shape, dtype=dtype, device=device)
+    return buffer
+
+
+def get_assigned_chunk(
+    tensor: torch.Tensor,
+    dim: int = 0,
+    idx: Optional[int] = None,
+    group=None,
+) -> torch.Tensor:
+    if idx is None:
+        idx = get_rank(group)
+    world_size = get_world_size(group)
+    total_size = tensor.shape[dim]
+    assert (
+        total_size % world_size == 0
+    ), f"tensor.shape[{dim}]={total_size} is not divisible by world_size={world_size}"
+    return tensor.chunk(world_size, dim=dim)[idx]
+
+
+def get_complete_tensor(
+    tensor: torch.Tensor,
+    *,
+    dim: int = 0,
+    group=None,
+) -> torch.Tensor:
+    tensor = tensor.transpose(0, dim).contiguous()  # bring `dim` to the front
+    output_tensor = all_gather_tensor_sync(tensor, gather_dim=0, group=group)
+    output_tensor = output_tensor.transpose(0, dim)
+    return output_tensor
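These primitives wrap torch.distributed's functional collectives so that every call returns a plain, materialized tensor: `_maybe_wait` resolves the `AsyncCollectiveTensor` that eager-mode collectives return and passes traced tensors through untouched, so call sites never juggle `wait()` themselves. The sketch below is a minimal, hypothetical round trip, not part of the package, assuming a default process group has been initialized (e.g. by `torchrun`); it shows how `get_assigned_chunk` and `get_complete_tensor` compose:

    # Hypothetical usage sketch, e.g. run as:
    #   torchrun --nproc_per_node=2 sketch.py
    import torch
    import torch.distributed as dist

    from cache_dit.primitives import (
        all_reduce_sync,
        get_assigned_chunk,
        get_complete_tensor,
        get_world_size,
    )

    dist.init_process_group(backend="gloo")  # "nccl" for GPU tensors

    # A tensor replicated on every rank; dim 1 plays the sequence dimension.
    x = torch.arange(2 * 8 * 4, dtype=torch.float32).reshape(2, 8, 4)

    # Each rank keeps only its evenly sized slice along dim 1 ...
    local = get_assigned_chunk(x, dim=1)  # shape (2, 8 // world_size, 4)

    # ... and the slices can later be all-gathered back into the full tensor.
    full = get_complete_tensor(local, dim=1)  # shape (2, 8, 4) on every rank
    assert torch.equal(full, x)

    # all_reduce_sync returns the reduced tensor directly, no wait() needed.
    total = all_reduce_sync(torch.ones(1), "sum")
    assert total.item() == get_world_size()

    dist.destroy_process_group()

The flatten-then-reshape dance in `all_gather_tensor_sync` exists because the functional collective gathers along one dimension of a flat buffer; `get_complete_tensor` therefore transposes the target `dim` to the front first, so the per-rank blocks concatenate in the right order before being transposed back.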
cache_dit-0.1.1.dev2.dist-info/METADATA
@@ -0,0 +1,31 @@
+Metadata-Version: 2.4
+Name: cache_dit
+Version: 0.1.1.dev2
+Summary: ⚡️DBCache: A Training-free UNet-style Cache Acceleration for Diffusion Transformers
+Author: DefTruth, vipshop.com, etc.
+Maintainer: DefTruth, vipshop.com, etc
+Project-URL: Repository, https://github.com/vipshop/DBCache.git
+Project-URL: Homepage, https://github.com/vipshop/DBCache.git
+Requires-Python: >=3.10
+Requires-Dist: packaging
+Requires-Dist: torch
+Requires-Dist: transformers
+Requires-Dist: diffusers
+Provides-Extra: all
+Provides-Extra: dev
+Requires-Dist: pre-commit; extra == "dev"
+Requires-Dist: pytest<8.0.0,>=7.0.0; extra == "dev"
+Requires-Dist: pytest-html; extra == "dev"
+Requires-Dist: expecttest; extra == "dev"
+Requires-Dist: hypothesis; extra == "dev"
+Requires-Dist: transformers; extra == "dev"
+Requires-Dist: diffusers; extra == "dev"
+Requires-Dist: accelerate; extra == "dev"
+Requires-Dist: peft; extra == "dev"
+Requires-Dist: protobuf; extra == "dev"
+Requires-Dist: sentencepiece; extra == "dev"
+Requires-Dist: opencv-python-headless; extra == "dev"
+Requires-Dist: ftfy; extra == "dev"
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
cache_dit-0.1.1.dev2.dist-info/RECORD
@@ -0,0 +1,30 @@
+cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/_version.py,sha256=lIErBp1sZ_uGq2rboUGas8Ch-hnYN8OtqGh3G0mtds0,524
+cache_dit/logger.py,sha256=dKfNe_RRk9HJwfgHGeRR1f0LbskJpKdGmISCbL9roQs,3443
+cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
+cache_dit/cache_factory/__init__.py,sha256=plAOUMsne-dTYA-cq1RLbE7dlH-kFA_Hst9MzbWPqiI,5224
+cache_dit/cache_factory/taylorseer.py,sha256=0W29ykJg3MnyLAB2KFicsl11Xe41cDYPgI60bquG_NY,2495
+cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=EJ-uhA2-sWMW1jNDhcBtjHDqSn8lUzfKbYoPfZDQhZU,49665
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=C6tfXHpdY8YFV3gk74dr_IpYH4bO4ItbPCQYud3NgAM,1667
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=g3ua-hmTpVeTOQNVYjUX2gsHuG2NV0B81iKHGa51wwk,2401
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=UbE6nIF-EtA92QxIZVMzIssdZKQSPAVX1hchF9R8drU,2754
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=qxMu1L3ycT8F-uxpGsmFQBY_BH1vDiGIOXgS_Qbb7dM,2391
+cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=cE27f5NPgQ_COmTnF__e85Uz5pWyXID0ut-tmtSQfVQ,34597
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=8IjJjZOs5XRzsj7Ni2MXpR2Z1PUyRSONIhmfAn1G0eM,1667
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=ORJpdkXkgziDUo-rpebC6pUemgYaDCoeu0cwwLz175U,2407
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=KbEkLSsHtS6xwLWNh3jlOlXRyGRdrI2pWV1zyQxMTj4,2757
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=rgeXfww-7WX6URSDg7mF1HuxSmYmoJVjMVoNGuxjwxc,2395
+cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=DpDhtK095PlrvACf7sbjOt2-QpVkV1arr1qGEKJqgaQ,23502
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=h80pgoZ3l8qC4rbm9KY0jSN8hOsmGgyvvFxD-xznHdw,1959
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=IVH-lroOzvYb4XKLk9MOw54EtijBtuzVaKcVGz0KlBA,2656
+cache_dit-0.1.1.dev2.dist-info/METADATA,sha256=pQr1yJwVuMWqB9b-IfPUq5x9UrpzvUauKqAzJWmQIZ0,1150
+cache_dit-0.1.1.dev2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.1.1.dev2.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.1.1.dev2.dist-info/RECORD,,
cache_dit-0.1.1.dev2.dist-info/WHEEL
@@ -0,0 +1,5 @@
+Wheel-Version: 1.0
+Generator: setuptools (80.9.0)
+Root-Is-Purelib: true
+Tag: py3-none-any
+
cache_dit-0.1.1.dev2.dist-info/top_level.txt
@@ -0,0 +1 @@
+cache_dit