cache_dit-0.1.1.dev2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/__init__.py +0 -0
- cache_dit/_version.py +21 -0
- cache_dit/cache_factory/__init__.py +166 -0
- cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
- cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
- cache_dit/cache_factory/taylorseer.py +76 -0
- cache_dit/cache_factory/utils.py +0 -0
- cache_dit/logger.py +97 -0
- cache_dit/primitives.py +152 -0
- cache_dit-0.1.1.dev2.dist-info/METADATA +31 -0
- cache_dit-0.1.1.dev2.dist-info/RECORD +30 -0
- cache_dit-0.1.1.dev2.dist-info/WHEEL +5 -0
- cache_dit-0.1.1.dev2.dist-info/top_level.txt +1 -0
cache_dit/primitives.py
ADDED
@@ -0,0 +1,152 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/primitives.py
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+
+if dist.is_available():
+    import torch.distributed._functional_collectives as ft_c
+    import torch.distributed.distributed_c10d as c10d
+else:
+    ft_c = None
+    c10d = None
+
+
+def get_group(group=None):
+    if group is None:
+        group = c10d._get_default_group()
+
+    if isinstance(group, dist.ProcessGroup):
+        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = group
+    else:
+        pg = group.get_group()
+
+    return pg
+
+
+def get_world_size(group=None):
+    pg = get_group(group)
+    return dist.get_world_size(pg)
+
+
+def get_rank(group=None):
+    pg = get_group(group)
+    return dist.get_rank(pg)
+
+
+def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    When tracing the code, the result tensor is not an AsyncCollectiveTensor,
+    so we cannot call ``wait()``.
+    """
+    if isinstance(tensor, ft_c.AsyncCollectiveTensor):
+        return tensor.wait()
+    return tensor
+
+
+def all_gather_tensor_sync(x, *args, group=None, **kwargs):
+    group = get_group(group)
+    x_shape = x.shape
+    x = x.flatten()
+    x_numel = x.numel()
+    x = ft_c.all_gather_tensor(x, *args, group=group, **kwargs)
+    x = _maybe_wait(x)
+    x_shape = list(x_shape)
+    x_shape[0] *= x.numel() // x_numel
+    x = x.reshape(x_shape)
+    return x
+
+
+def all_gather_tensor_autograd_sync(x, *args, group=None, **kwargs):
+    group = get_group(group)
+    x_shape = x.shape
+    x = x.flatten()
+    x_numel = x.numel()
+    x = ft_c.all_gather_tensor_autograd(x, *args, group=group, **kwargs)
+    x = _maybe_wait(x)
+    x_shape = list(x_shape)
+    x_shape[0] *= x.numel() // x_numel
+    x = x.reshape(x_shape)
+    return x
+
+
+def all_to_all_single_sync(x, *args, **kwargs):
+    x_shape = x.shape
+    x = x.flatten()
+    x = ft_c.all_to_all_single(x, *args, **kwargs)
+    x = _maybe_wait(x)
+    x = x.reshape(x_shape)
+    return x
+
+
+def all_to_all_single_autograd_sync(x, *args, **kwargs):
+    x_shape = x.shape
+    x = x.flatten()
+    x = ft_c.all_to_all_single_autograd(x, *args, **kwargs)
+    x = _maybe_wait(x)
+    x = x.reshape(x_shape)
+    return x
+
+
+def all_reduce_sync(x, *args, group=None, **kwargs):
+    group = get_group(group)
+    x = ft_c.all_reduce(x, *args, group=group, **kwargs)
+    x = _maybe_wait(x)
+    return x
+
+
+def get_buffer(
+    shape_or_tensor: Union[Tuple[int], torch.Tensor],
+    *,
+    repeats: int = 1,
+    dim: int = 0,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    group=None,
+) -> torch.Tensor:
+    if repeats is None:
+        repeats = get_world_size(group)
+
+    if isinstance(shape_or_tensor, torch.Tensor):
+        shape = shape_or_tensor.shape
+        dtype = shape_or_tensor.dtype
+        device = shape_or_tensor.device
+
+    assert dtype is not None
+    assert device is not None
+
+    shape = list(shape)
+    if repeats > 1:
+        shape[dim] *= repeats
+
+    buffer = torch.empty(shape, dtype=dtype, device=device)
+    return buffer
+
+
+def get_assigned_chunk(
+    tensor: torch.Tensor,
+    dim: int = 0,
+    idx: Optional[int] = None,
+    group=None,
+) -> torch.Tensor:
+    if idx is None:
+        idx = get_rank(group)
+    world_size = get_world_size(group)
+    total_size = tensor.shape[dim]
+    assert (
+        total_size % world_size == 0
+    ), f"tensor.shape[{dim}]={total_size} is not divisible by world_size={world_size}"
+    return tensor.chunk(world_size, dim=dim)[idx]
+
+
+def get_complete_tensor(
+    tensor: torch.Tensor,
+    *,
+    dim: int = 0,
+    group=None,
+) -> torch.Tensor:
+    tensor = tensor.transpose(0, dim).contiguous()
+    output_tensor = all_gather_tensor_sync(tensor, gather_dim=0, group=group)
+    output_tensor = output_tensor.transpose(0, dim)
+    return output_tensor
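For orientation, here is a minimal usage sketch (not part of the wheel) of the chunk/gather pair above: each rank keeps its assigned slice of a tensor via `get_assigned_chunk`, and `get_complete_tensor` all-gathers the slices back along the same dim. It assumes a process group launched with `torchrun` (e.g. `torchrun --nproc_per_node=2 demo.py`); the script name and the `gloo` backend are illustrative choices, not part of the package.

```python
# demo.py -- illustrative only; run as: torchrun --nproc_per_node=2 demo.py
import torch
import torch.distributed as dist

from cache_dit.primitives import (
    get_assigned_chunk,
    get_complete_tensor,
    get_rank,
    get_world_size,
)


def main():
    # torchrun sets RANK/WORLD_SIZE/MASTER_ADDR, so no explicit init args needed.
    dist.init_process_group(backend="gloo")

    full = torch.arange(8.0).reshape(4, 2)

    # With 2 ranks, rank 0 keeps rows 0..1 and rank 1 keeps rows 2..3.
    local = get_assigned_chunk(full, dim=0)

    # All-gather the per-rank chunks back along dim 0; round-trips exactly.
    restored = get_complete_tensor(local, dim=0)
    assert torch.equal(restored, full)

    print(f"rank {get_rank()}/{get_world_size()} ok")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```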
cache_dit-0.1.1.dev2.dist-info/METADATA
ADDED
@@ -0,0 +1,31 @@
+Metadata-Version: 2.4
+Name: cache_dit
+Version: 0.1.1.dev2
+Summary: ⚡️DBCache: A Training-free UNet-style Cache Acceleration for Diffusion Transformers
+Author: DefTruth, vipshop.com, etc.
+Maintainer: DefTruth, vipshop.com, etc
+Project-URL: Repository, https://github.com/vipshop/DBCache.git
+Project-URL: Homepage, https://github.com/vipshop/DBCache.git
+Requires-Python: >=3.10
+Requires-Dist: packaging
+Requires-Dist: torch
+Requires-Dist: transformers
+Requires-Dist: diffusers
+Provides-Extra: all
+Provides-Extra: dev
+Requires-Dist: pre-commit; extra == "dev"
+Requires-Dist: pytest<8.0.0,>=7.0.0; extra == "dev"
+Requires-Dist: pytest-html; extra == "dev"
+Requires-Dist: expecttest; extra == "dev"
+Requires-Dist: hypothesis; extra == "dev"
+Requires-Dist: transformers; extra == "dev"
+Requires-Dist: diffusers; extra == "dev"
+Requires-Dist: accelerate; extra == "dev"
+Requires-Dist: peft; extra == "dev"
+Requires-Dist: protobuf; extra == "dev"
+Requires-Dist: sentencepiece; extra == "dev"
+Requires-Dist: opencv-python-headless; extra == "dev"
+Requires-Dist: ftfy; extra == "dev"
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
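As a side note, the fields in this METADATA file can be read back from an installed copy with the standard library. A small sketch (assumes the wheel is installed; Python >=3.10 per the Requires-Python above):

```python
from importlib.metadata import metadata, requires, version

# "cache-dit" and "cache_dit" normalize to the same distribution name.
print(version("cache_dit"))              # -> 0.1.1.dev2
print(metadata("cache_dit")["Summary"])  # the Summary field above

# Requirement strings, including the '; extra == "dev"' markers listed above.
for req in requires("cache_dit") or []:
    print(req)
```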
cache_dit-0.1.1.dev2.dist-info/RECORD
ADDED
@@ -0,0 +1,30 @@
+cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/_version.py,sha256=lIErBp1sZ_uGq2rboUGas8Ch-hnYN8OtqGh3G0mtds0,524
+cache_dit/logger.py,sha256=dKfNe_RRk9HJwfgHGeRR1f0LbskJpKdGmISCbL9roQs,3443
+cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
+cache_dit/cache_factory/__init__.py,sha256=plAOUMsne-dTYA-cq1RLbE7dlH-kFA_Hst9MzbWPqiI,5224
+cache_dit/cache_factory/taylorseer.py,sha256=0W29ykJg3MnyLAB2KFicsl11Xe41cDYPgI60bquG_NY,2495
+cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=EJ-uhA2-sWMW1jNDhcBtjHDqSn8lUzfKbYoPfZDQhZU,49665
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=C6tfXHpdY8YFV3gk74dr_IpYH4bO4ItbPCQYud3NgAM,1667
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=g3ua-hmTpVeTOQNVYjUX2gsHuG2NV0B81iKHGa51wwk,2401
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=UbE6nIF-EtA92QxIZVMzIssdZKQSPAVX1hchF9R8drU,2754
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=qxMu1L3ycT8F-uxpGsmFQBY_BH1vDiGIOXgS_Qbb7dM,2391
+cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=cE27f5NPgQ_COmTnF__e85Uz5pWyXID0ut-tmtSQfVQ,34597
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=8IjJjZOs5XRzsj7Ni2MXpR2Z1PUyRSONIhmfAn1G0eM,1667
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=ORJpdkXkgziDUo-rpebC6pUemgYaDCoeu0cwwLz175U,2407
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=KbEkLSsHtS6xwLWNh3jlOlXRyGRdrI2pWV1zyQxMTj4,2757
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=rgeXfww-7WX6URSDg7mF1HuxSmYmoJVjMVoNGuxjwxc,2395
+cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=DpDhtK095PlrvACf7sbjOt2-QpVkV1arr1qGEKJqgaQ,23502
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=h80pgoZ3l8qC4rbm9KY0jSN8hOsmGgyvvFxD-xznHdw,1959
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=IVH-lroOzvYb4XKLk9MOw54EtijBtuzVaKcVGz0KlBA,2656
+cache_dit-0.1.1.dev2.dist-info/METADATA,sha256=pQr1yJwVuMWqB9b-IfPUq5x9UrpzvUauKqAzJWmQIZ0,1150
+cache_dit-0.1.1.dev2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.1.1.dev2.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.1.1.dev2.dist-info/RECORD,,
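Each RECORD row has the form `path,sha256=<digest>,size`, where the digest is the urlsafe-base64 encoding of the file's SHA-256 with trailing `=` padding stripped, per the wheel spec; the recurring `47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU` value with size 0 is simply the hash of an empty file (the empty `__init__.py` modules). A small sketch of recomputing such an entry (the file path is illustrative):

```python
import base64
import hashlib
from pathlib import Path


def record_entry(path: str) -> str:
    """Build a RECORD-style row for a file: path,sha256=<digest>,size."""
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
    return f"{path},sha256={digest.rstrip(b'=').decode()},{len(data)}"


# An empty file reproduces the repeated __init__.py hash above.
Path("empty_example").write_bytes(b"")
print(record_entry("empty_example"))
# -> empty_example,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
```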
cache_dit-0.1.1.dev2.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+cache_dit