blksprs 1.10.1__py3-none-any.whl → 1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 
-from blksprs.ops.flow import flow_forward
+from blksprs.ops.flow import flow_forward_pull
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
@@ -9,7 +9,8 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
 
 
 def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-          dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+          dim: int, sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> (
+        BlksprsTensor, Tensor):
     """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
 
     Args:
@@ -19,6 +20,7 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
         dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The block-sparse tensor split into partitions in compressed form.
@@ -34,46 +36,66 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_output = (sparsity_layout
-                              .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
-                                       sparsity_layout.size(2) // partitions)
-                              .permute(0, 2, 1, 3)
-                              .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
-                                       sparsity_layout.size(2) // partitions).contiguous())
-
-    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
-
-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
-                                     sparsity_layout.size(2) // partitions)
-                            .permute(0, 2, 1, 3).reshape(-1).contiguous())
-
-    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
-
     adjusted_dim = dim % 3
     if adjusted_dim != 2:
         raise NotImplementedError("Currently only supports dim=2")
 
-    return BlksprsTensor(_BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
-                                                 adjusted_dim, sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
+    lut = _BlocksparseSplit.build_lut(lut, sparsity_layout, partitions)
+
+    return BlksprsTensor(
+        _BlocksparseSplit.apply(x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+                                partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"],
+                                triton_block_size)), lut["sparsity_layout_output"]
 
 
 class _BlocksparseSplit(torch.autograd.Function):
 
+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_layout_output" not in lut:
+            sparsity_layout_output = (sparsity_layout
+                                      .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                               sparsity_layout.size(2) // partitions)
+                                      .permute(0, 2, 1, 3)
+                                      .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
+                                               sparsity_layout.size(2) // partitions).contiguous())
+            lut["sparsity_layout_output"] = sparsity_layout_output
+
+        if "sparsity_lut" not in lut:
+            sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
+            lut["sparsity_lut"] = sparsity_lut
+
+        if "sparsity_reverse_lut" not in lut:
+            sparsity_layout_flat = sparsity_layout.reshape(-1)
+            sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                     (sparsity_layout_flat == 1) -
+                                     (1 * (sparsity_layout_flat == 0)))
+                                    .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                             sparsity_layout.size(2) // partitions)
+                                    .permute(0, 2, 1, 3).reshape(-1).contiguous())
+            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
+                triton_block_size: int) -> Tensor:
         ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions
         ctx.dim = dim
 
-        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)
+        return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks, triton_block_size)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -88,7 +110,8 @@ class _BlocksparseSplit(torch.autograd.Function):
 
 
 def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-          dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+          dim: int, sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> (
+        BlksprsTensor, Tensor):
     """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
 
     Args:
@@ -98,6 +121,7 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
         dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The merged block-sparse tensor in compressed form.
@@ -113,48 +137,69 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
-                                                      sparsity_layout.size(1), sparsity_layout.size(2))
-                              .permute(0, 2, 1, 3)
-                              .reshape(sparsity_layout.size(0) // partitions,
-                                       sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
-
-    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
-
-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout.size(0) // partitions, partitions,
-                                     sparsity_layout.size(1), sparsity_layout.size(2))
-                            .permute(0, 2, 1, 3)
-                            .reshape(sparsity_layout.size(0) // partitions,
-                                     sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
-                            .reshape(-1).contiguous())
-
-    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
-
     adjusted_dim = dim % 3
     if adjusted_dim != 2:
         raise NotImplementedError("Currently only supports dim=2")
 
-    return BlksprsTensor(_BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
-                                                 adjusted_dim, sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
+    lut = _BlocksparseMerge.build_lut(lut, sparsity_layout, partitions)
+
+    return BlksprsTensor(
+        _BlocksparseMerge.apply(x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+                                partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"],
+                                triton_block_size)), lut["sparsity_layout_output"]
 
 
 class _BlocksparseMerge(torch.autograd.Function):
 
+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_layout_output" not in lut:
+            sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
+                                                              sparsity_layout.size(1), sparsity_layout.size(2))
+                                      .permute(0, 2, 1, 3)
+                                      .reshape(sparsity_layout.size(0) // partitions,
+                                               sparsity_layout.size(1),
+                                               sparsity_layout.size(2) * partitions).contiguous())
+            lut["sparsity_layout_output"] = sparsity_layout_output
+
+        if "sparsity_lut" not in lut:
+            sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
+            lut["sparsity_lut"] = sparsity_lut
+
+        if "sparsity_reverse_lut" not in lut:
+            sparsity_layout_flat = sparsity_layout.reshape(-1)
+            sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                     (sparsity_layout_flat == 1) -
+                                     (1 * (sparsity_layout_flat == 0)))
+                                    .reshape(sparsity_layout.size(0) // partitions, partitions,
+                                             sparsity_layout.size(1), sparsity_layout.size(2))
+                                    .permute(0, 2, 1, 3)
+                                    .reshape(sparsity_layout.size(0) // partitions,
+                                             sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
+                                    .reshape(-1).contiguous())
+            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
+                triton_block_size: int) -> Tensor:
         ctx.save_for_backward(sparsity_layout_o)
        ctx.num_partitions = num_partitions
         ctx.dim = dim
 
-        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)
+        return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks, triton_block_size)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -166,5 +211,3 @@ class _BlocksparseMerge(torch.autograd.Function):
 
         return split(grad_output, sparsity_layout, num_partitions, dim,
                      sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
-
-
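The practical effect of the new `lut` parameter on `split` and `merge` is LUT caching: `build_lut` only computes the entries missing from the dictionary it receives ("sparsity_layout_output", "sparsity_lut", "sparsity_reverse_lut", "n_sparse_blocks") and fills them into that same dictionary, so one dict can be shared across repeated calls that use the same sparsity layout and partition count, paying the construction cost once. A minimal usage sketch; the tensor names, partition count, and block size below are illustrative assumptions, and `split` is assumed to be imported from its blksprs module (the module path is not shown in this diff):

    # Hypothetical reuse of one LUT cache across calls sharing a layout.
    lut = {}  # empty cache; the first call fills in the missing entries in place
    for x in xs:  # compressed block-sparse tensors with the same sparsity_layout
        out, sparsity_layout_split = split(x, sparsity_layout, partitions=2, dim=2,
                                           sparsity_block_size=64, lut=lut)

Because the keys are checked individually, a partially filled cache is also valid: any entry already present is trusted as-is and not recomputed.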
blksprs/ops/repeat.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import triton
 from torch import Tensor
 
-from blksprs.ops.flow import kernel_blocksparse_flow_push, flow_forward
+from blksprs.ops.flow import kernel_blocksparse_flow_push, flow_forward_pull, flow_forward_push
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
@@ -10,7 +10,8 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
 
 
 def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
-           sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
+           sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None,
+           lut: dict = None) -> (
         BlksprsTensor, Tensor):
     """Repeats a block-sparse tensor in compressed form according to the given repeats.
 
@@ -30,6 +31,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
         sparsity_block_size (int): The size of the sparsity blocks.
         sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: A block-sparse tensor in compressed form containing the repeated values.
@@ -45,33 +47,17 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
-
-    if sparsity_layout_output is not None:
-        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
-
-    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
-
-    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout_x.size())
-                            .repeat(repeats[0], repeats[1], repeats[2])
-                            .reshape(-1).contiguous())
-
-    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+    lut = _BlocksparseRepeat.build_lut_repeat(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
     return BlksprsTensor(
-        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
+        _BlocksparseRepeat.apply(x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
+                                 lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"],
+                                 triton_block_size)), lut["sparsity_layout_o"]
 
 
 def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
                       sparsity_block_size: int, sparsity_layout_output: Tensor = None,
-                      triton_block_size: int = None) -> (
+                      triton_block_size: int = None, lut: dict = None) -> (
         BlksprsTensor, Tensor):
     """Repeats and interleaves the block-sparse tensor in compressed form.
 
@@ -89,6 +75,7 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
         sparsity_block_size (int): The size of the sparsity blocks.
         sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
@@ -104,31 +91,87 @@
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
+    lut = _BlocksparseRepeat.build_lut_repeat_interleave(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
-    if sparsity_layout_output is not None:
-        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
+    return BlksprsTensor(
+        _BlocksparseRepeat.apply(x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
+                                 lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"],
+                                 triton_block_size)), lut["sparsity_layout_o"]
 
-    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
 
-    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout_x.size())
-                            .repeat_interleave(repeats, dim=0)
-                            .reshape(-1).contiguous())
+class _BlocksparseRepeat(torch.autograd.Function):
 
-    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
+    @staticmethod
+    def build_lut_repeat(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
+                         sparsity_layout_output: Tensor):
+        if lut is None:
+            lut = dict()
 
-    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+        if "sparsity_layout_o" not in lut:
+            sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
+            lut["sparsity_layout_o"] = sparsity_layout_o
 
-    return BlksprsTensor(
-        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
+        if sparsity_layout_output is not None:
+            sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
+            lut["sparsity_layout_o"] = sparsity_layout_o
 
+        if "sparsity_lut" not in lut:
+            sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
+            lut["sparsity_lut"] = sparsity_lut
 
-class _BlocksparseRepeat(torch.autograd.Function):
+        if "sparsity_reverse_lut" not in lut:
+            sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+            sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                     (sparsity_layout_flat == 1) -
+                                     (1 * (sparsity_layout_flat == 0)))
+                                    .reshape(sparsity_layout_x.size())
+                                    .repeat(repeats[0], repeats[1], repeats[2])
+                                    .reshape(-1).contiguous())
+            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(lut["sparsity_layout_o"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+        return lut
+
+    @staticmethod
+    def build_lut_repeat_interleave(lut: dict, sparsity_layout_x: Tensor, repeats: int,
+                                    sparsity_layout_output: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_layout_o" not in lut:
+            sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
+            lut["sparsity_layout_o"] = sparsity_layout_o
+
+        if sparsity_layout_output is not None:
+            sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
+            lut["sparsity_layout_o"] = sparsity_layout_o
+
+        if "sparsity_lut" not in lut:
+            sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
+            lut["sparsity_lut"] = sparsity_lut
+
+        if "sparsity_reverse_lut" not in lut:
+            sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+            sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                     (sparsity_layout_flat == 1) -
+                                     (1 * (sparsity_layout_flat == 0)))
+                                    .reshape(sparsity_layout_x.size())
+                                    .repeat_interleave(repeats, dim=0)
+                                    .reshape(-1).contiguous())
+            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(lut["sparsity_layout_o"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+        return lut
 
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
@@ -136,49 +179,18 @@ class _BlocksparseRepeat(torch.autograd.Function):
                 sparsity_block_size: int, n_sparse_blocks: int,
                 triton_block_size: int) -> Tensor:
         ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
-        ctx.x_size = x.size()
-        ctx.x_stride = stride(x)
 
-        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)
+        return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks, triton_block_size)
 
     @staticmethod
     def backward(ctx, grad_output):
         sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
-        x_size = ctx.x_size
-        x_stride = ctx.x_stride
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
         n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
 
-        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                             dtype=grad_output.dtype, device=grad_output.device)
-
-        x_b, x_r, x_c = grad_output.size()
-        x_b_s, x_r_s, x_c_s = stride(grad_output)
-        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_o.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_o)
-        s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-        o_b, o_r, o_c = x_size
-        o_b_s, o_r_s, o_c_s = x_stride
-
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
-        triton_grid = lambda meta: [x_b,
-                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-        (kernel_blocksparse_flow_push[triton_grid]
-         (grad_output,
-          x_b, x_b_s, x_r_s, x_c_s,
-          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-          sparsity_reverse_lut,
-          output,
-          o_b, o_b_s, o_r_s, o_c_s,
-          triton_block_size))
-
-        return output, None, None, None, None, None, None, None
+        return flow_forward_push(None, grad_output, sparsity_layout_o, sparsity_lut,
+                                 sparsity_reverse_lut, sparsity_block_size, n_sparse_blocks,
+                                 triton_block_size), None, None, None, None, None, None, None
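Every build_lut variant in this diff reconstructs its sparsity_reverse_lut from the same cumsum expression, which maps each position of the flattened sparsity layout to the index of that position's block in compressed storage, with -1 marking positions that hold no block. A standalone toy illustration of the expression (the layout values are made up; this is not library code):

    import torch

    # 1 marks a stored block, 0 an empty slot, in a flattened sparsity layout.
    flat = torch.tensor([1, 0, 1, 1, 0])

    # (cumsum - 1) numbers the stored blocks 0, 1, 2, ...; multiplying by
    # (flat == 1) zeroes the empty slots, and subtracting 1 * (flat == 0)
    # turns those empty slots into the sentinel -1.
    reverse_lut = (torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - 1 * (flat == 0)

    print(reverse_lut)  # tensor([ 0, -1,  1,  2, -1])

The reshape/permute/repeat chains layered on top of this expression in both files then reorder the block indices to match the output layout; the result is what the renamed flow_forward_pull and the new flow_forward_push wrapper (which replaces the inline kernel_blocksparse_flow_push launch in _BlocksparseRepeat.backward) take as their sparsity_reverse_lut argument.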