cache-dit 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/__init__.py +1 -3
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -1
- cache_dit/cache_factory/cache_adapters.py +298 -123
- cache_dit/cache_factory/cache_blocks.py +9 -3
- cache_dit/cache_factory/cache_context.py +85 -15
- cache_dit/cache_factory/cache_interface.py +18 -11
- cache_dit/cache_factory/taylorseer.py +5 -4
- cache_dit/cache_factory/utils.py +1 -1
- cache_dit/utils.py +25 -22
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/METADATA +19 -10
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/RECORD +16 -17
- cache_dit/primitives.py +0 -152
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/top_level.txt +0 -0
cache_dit/primitives.py
DELETED
@@ -1,152 +0,0 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/primitives.py
-
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.distributed as dist
-
-if dist.is_available():
-    import torch.distributed._functional_collectives as ft_c
-    import torch.distributed.distributed_c10d as c10d
-else:
-    ft_c = None
-    c10d = None
-
-
-def get_group(group=None):
-    if group is None:
-        group = c10d._get_default_group()
-
-    if isinstance(group, dist.ProcessGroup):
-        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = group
-    else:
-        pg = group.get_group()
-
-    return pg
-
-
-def get_world_size(group=None):
-    pg = get_group(group)
-    return dist.get_world_size(pg)
-
-
-def get_rank(group=None):
-    pg = get_group(group)
-    return dist.get_rank(pg)
-
-
-def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
-    """
-    When tracing the code, the result tensor is not an AsyncCollectiveTensor,
-    so we cannot call ``wait()``.
-    """
-    if isinstance(tensor, ft_c.AsyncCollectiveTensor):
-        return tensor.wait()
-    return tensor
-
-
-def all_gather_tensor_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x_shape = x.shape
-    x = x.flatten()
-    x_numel = x.numel()
-    x = ft_c.all_gather_tensor(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    x_shape = list(x_shape)
-    x_shape[0] *= x.numel() // x_numel
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_gather_tensor_autograd_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x_shape = x.shape
-    x = x.flatten()
-    x_numel = x.numel()
-    x = ft_c.all_gather_tensor_autograd(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    x_shape = list(x_shape)
-    x_shape[0] *= x.numel() // x_numel
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_to_all_single_sync(x, *args, **kwargs):
-    x_shape = x.shape
-    x = x.flatten()
-    x = ft_c.all_to_all_single(x, *args, **kwargs)
-    x = _maybe_wait(x)
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_to_all_single_autograd_sync(x, *args, **kwargs):
-    x_shape = x.shape
-    x = x.flatten()
-    x = ft_c.all_to_all_single_autograd(x, *args, **kwargs)
-    x = _maybe_wait(x)
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_reduce_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x = ft_c.all_reduce(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    return x
-
-
-def get_buffer(
-    shape_or_tensor: Union[Tuple[int], torch.Tensor],
-    *,
-    repeats: int = 1,
-    dim: int = 0,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    group=None,
-) -> torch.Tensor:
-    if repeats is None:
-        repeats = get_world_size(group)
-
-    if isinstance(shape_or_tensor, torch.Tensor):
-        shape = shape_or_tensor.shape
-        dtype = shape_or_tensor.dtype
-        device = shape_or_tensor.device
-
-    assert dtype is not None
-    assert device is not None
-
-    shape = list(shape)
-    if repeats > 1:
-        shape[dim] *= repeats
-
-    buffer = torch.empty(shape, dtype=dtype, device=device)
-    return buffer
-
-
-def get_assigned_chunk(
-    tensor: torch.Tensor,
-    dim: int = 0,
-    idx: Optional[int] = None,
-    group=None,
-) -> torch.Tensor:
-    if idx is None:
-        idx = get_rank(group)
-    world_size = get_world_size(group)
-    total_size = tensor.shape[dim]
-    assert (
-        total_size % world_size == 0
-    ), f"tensor.shape[{dim}]={total_size} is not divisible by world_size={world_size}"
-    return tensor.chunk(world_size, dim=dim)[idx]
-
-
-def get_complete_tensor(
-    tensor: torch.Tensor,
-    *,
-    dim: int = 0,
-    group=None,
-) -> torch.Tensor:
-    tensor = tensor.transpose(0, dim).contiguous()
-    output_tensor = all_gather_tensor_sync(tensor, gather_dim=0, group=group)
-    output_tensor = output_tensor.transpose(0, dim)
-    return output_tensor
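For context, the removed module was a set of thin wrappers around torch.distributed functional collectives (shard/gather helpers). The sketch below is illustrative only and not part of the package: it assumes a process group launched with torchrun, a hypothetical script name demo_primitives.py, and a cache-dit release old enough (0.2.22 or earlier) to still ship cache_dit/primitives.py. It simply round-trips a tensor through get_assigned_chunk and get_complete_tensor, based on the signatures visible in the deleted file above.

```python
# Illustrative sketch only; exercises the helpers removed in this release.
# Assumed launch command (hypothetical script name):
#   torchrun --nproc_per_node=2 demo_primitives.py
import torch
import torch.distributed as dist

# These imports only resolve on cache-dit releases that still ship
# primitives.py (0.2.22 and earlier); the module is deleted as of 0.2.24.
from cache_dit.primitives import (
    get_assigned_chunk,
    get_complete_tensor,
    get_rank,
    get_world_size,
)


def main():
    # CPU-friendly backend for the demo; the default process group is then
    # what get_group(None) resolves to inside the helpers.
    dist.init_process_group(backend="gloo")

    full = torch.arange(8, dtype=torch.float32)

    # Each rank keeps only its slice along dim 0
    # (8 must be divisible by the world size).
    local = get_assigned_chunk(full, dim=0)

    # All-gather the slices back; rank ordering restores the original layout.
    restored = get_complete_tensor(local, dim=0)
    assert torch.equal(restored, full)

    if get_rank() == 0:
        print(f"world_size={get_world_size()}, restored={restored.tolist()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Downstream code that imported these helpers from cache_dit.primitives will need to vendor them or depend on the upstream ParaAttention source linked at the top of the deleted file.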