cache-dit 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic.

cache_dit/primitives.py DELETED
@@ -1,152 +0,0 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/primitives.py
-
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.distributed as dist
-
-if dist.is_available():
-    import torch.distributed._functional_collectives as ft_c
-    import torch.distributed.distributed_c10d as c10d
-else:
-    ft_c = None
-    c10d = None
-
-
-def get_group(group=None):
-    if group is None:
-        group = c10d._get_default_group()
-
-    if isinstance(group, dist.ProcessGroup):
-        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = group
-    else:
-        pg = group.get_group()
-
-    return pg
-
-
-def get_world_size(group=None):
-    pg = get_group(group)
-    return dist.get_world_size(pg)
-
-
-def get_rank(group=None):
-    pg = get_group(group)
-    return dist.get_rank(pg)
-
-
-def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
-    """
-    When tracing the code, the result tensor is not an AsyncCollectiveTensor,
-    so we cannot call ``wait()``.
-    """
-    if isinstance(tensor, ft_c.AsyncCollectiveTensor):
-        return tensor.wait()
-    return tensor
-
-
-def all_gather_tensor_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x_shape = x.shape
-    x = x.flatten()
-    x_numel = x.numel()
-    x = ft_c.all_gather_tensor(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    x_shape = list(x_shape)
-    x_shape[0] *= x.numel() // x_numel
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_gather_tensor_autograd_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x_shape = x.shape
-    x = x.flatten()
-    x_numel = x.numel()
-    x = ft_c.all_gather_tensor_autograd(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    x_shape = list(x_shape)
-    x_shape[0] *= x.numel() // x_numel
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_to_all_single_sync(x, *args, **kwargs):
-    x_shape = x.shape
-    x = x.flatten()
-    x = ft_c.all_to_all_single(x, *args, **kwargs)
-    x = _maybe_wait(x)
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_to_all_single_autograd_sync(x, *args, **kwargs):
-    x_shape = x.shape
-    x = x.flatten()
-    x = ft_c.all_to_all_single_autograd(x, *args, **kwargs)
-    x = _maybe_wait(x)
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_reduce_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x = ft_c.all_reduce(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    return x
-
-
-def get_buffer(
-    shape_or_tensor: Union[Tuple[int], torch.Tensor],
-    *,
-    repeats: int = 1,
-    dim: int = 0,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    group=None,
-) -> torch.Tensor:
-    if repeats is None:
-        repeats = get_world_size(group)
-
-    if isinstance(shape_or_tensor, torch.Tensor):
-        shape = shape_or_tensor.shape
-        dtype = shape_or_tensor.dtype
-        device = shape_or_tensor.device
-
-    assert dtype is not None
-    assert device is not None
-
-    shape = list(shape)
-    if repeats > 1:
-        shape[dim] *= repeats
-
-    buffer = torch.empty(shape, dtype=dtype, device=device)
-    return buffer
-
-
-def get_assigned_chunk(
-    tensor: torch.Tensor,
-    dim: int = 0,
-    idx: Optional[int] = None,
-    group=None,
-) -> torch.Tensor:
-    if idx is None:
-        idx = get_rank(group)
-    world_size = get_world_size(group)
-    total_size = tensor.shape[dim]
-    assert (
-        total_size % world_size == 0
-    ), f"tensor.shape[{dim}]={total_size} is not divisible by world_size={world_size}"
-    return tensor.chunk(world_size, dim=dim)[idx]
-
-
-def get_complete_tensor(
-    tensor: torch.Tensor,
-    *,
-    dim: int = 0,
-    group=None,
-) -> torch.Tensor:
-    tensor = tensor.transpose(0, dim).contiguous()
-    output_tensor = all_gather_tensor_sync(tensor, gather_dim=0, group=group)
-    output_tensor = output_tensor.transpose(0, dim)
-    return output_tensor