blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
@@ -1,15 +1,16 @@
 import torch
 from torch import Tensor
+from torch._library import triton_op
 
-from blksprs.ops.flow import flow_forward
+from blksprs.ops.flow import flow_pull_forward
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size
 
 
 def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-          dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+          dim: int, sparsity_block_size: int, lut: dict = None) -> (
+        BlksprsTensor, Tensor):
     """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
 
     Args:
@@ -18,7 +19,7 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
         partitions (int): The number of partitions to split the block-sparse tensor into.
         dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The block-sparse tensor split into partitions in compressed form.
@@ -32,63 +33,87 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_output = (sparsity_layout
-                              .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
-                                       sparsity_layout.size(2) // partitions)
-                              .permute(0, 2, 1, 3)
-                              .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
-                                       sparsity_layout.size(2) // partitions).contiguous())
+    adjusted_dim = dim % 3
+    if adjusted_dim != 2:
+        raise NotImplementedError("Currently only supports dim=2")
 
-    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+    lut = split_build_lut(lut, sparsity_layout, partitions)
 
-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
-                                     sparsity_layout.size(2) // partitions)
-                            .permute(0, 2, 1, 3).reshape(-1).contiguous())
+    return BlksprsTensor(split_forward(
+        x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+        partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
 
-    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
 
-    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+@triton_op("blksprs::split", mutates_args={})
+def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                  _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                             n_sparse_blocks)
 
-    adjusted_dim = dim % 3
-    if adjusted_dim != 2:
-        raise NotImplementedError("Currently only supports dim=2")
 
-    return BlksprsTensor(_BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
-                                                 adjusted_dim, sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
+def split_backward(ctx, grad_output):
+    sparsity_layout = ctx.saved_tensors[0]
+    num_partitions = ctx.num_partitions
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return merge(grad_output, sparsity_layout, num_partitions, dim,
+                 sparsity_block_size)[0], None, None, None, None, None, None, None
+
+
+def split_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_layout_output" not in lut:
+        sparsity_layout_output = (sparsity_layout
+                                  .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                           sparsity_layout.size(2) // partitions)
+                                  .permute(0, 2, 1, 3)
+                                  .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
+                                           sparsity_layout.size(2) // partitions).contiguous())
+        lut["sparsity_layout_output"] = sparsity_layout_output
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout.reshape(-1)
+        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+                                .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                         sparsity_layout.size(2) // partitions)
+                                .permute(0, 2, 1, 3).reshape(-1).contiguous())
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
 
+    validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
 
-class _BlocksparseSplit(torch.autograd.Function):
+    return lut
 
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        ctx.save_for_backward(sparsity_layout_o)
-        ctx.num_partitions = num_partitions
-        ctx.dim = dim
 
-        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)
+# noinspection PyUnusedLocal
+def split_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
 
-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout = ctx.saved_tensors[0]
-        num_partitions = ctx.num_partitions
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
+    ctx.save_for_backward(sparsity_layout_o)
+    ctx.num_partitions = num_partitions
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
 
-        return merge(grad_output, sparsity_layout, num_partitions, dim,
-                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
+
+split_forward.register_autograd(split_backward, setup_context=split_setup_context)
 
 
 def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-          dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+          dim: int, sparsity_block_size: int, lut: dict = None) -> (
+        BlksprsTensor, Tensor):
     """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
 
     Args:
@@ -97,7 +122,7 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
         partitions (int): The number of partitions to be merged.
         dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The merged block-sparse tensor in compressed form.
@@ -111,60 +136,82 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
-                                                      sparsity_layout.size(1), sparsity_layout.size(2))
-                              .permute(0, 2, 1, 3)
-                              .reshape(sparsity_layout.size(0) // partitions,
-                                       sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
+    adjusted_dim = dim % 3
+    if adjusted_dim != 2:
+        raise NotImplementedError("Currently only supports dim=2")
 
-    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+    lut = merge_build_lut(lut, sparsity_layout, partitions)
 
-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout.size(0) // partitions, partitions,
-                                     sparsity_layout.size(1), sparsity_layout.size(2))
-                            .permute(0, 2, 1, 3)
-                            .reshape(sparsity_layout.size(0) // partitions,
-                                     sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
-                            .reshape(-1).contiguous())
+    return BlksprsTensor(merge_forward(
+        x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+        partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
 
-    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
 
-    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+@triton_op("blksprs::merge", mutates_args={})
+def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                  _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                             n_sparse_blocks)
 
-    adjusted_dim = dim % 3
-    if adjusted_dim != 2:
-        raise NotImplementedError("Currently only supports dim=2")
 
-    return BlksprsTensor(_BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
-                                                 adjusted_dim, sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
+def merge_backward(ctx, grad_output):
+    sparsity_layout = ctx.saved_tensors[0]
+    num_partitions = ctx.num_partitions
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return split(grad_output, sparsity_layout, num_partitions, dim,
+                 sparsity_block_size)[0], None, None, None, None, None, None, None
+
+
+def merge_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_layout_output" not in lut:
+        sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
+                                                          sparsity_layout.size(1), sparsity_layout.size(2))
+                                  .permute(0, 2, 1, 3)
+                                  .reshape(sparsity_layout.size(0) // partitions,
+                                           sparsity_layout.size(1),
+                                           sparsity_layout.size(2) * partitions).contiguous())
+        lut["sparsity_layout_output"] = sparsity_layout_output
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout.reshape(-1)
+        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+                                .reshape(sparsity_layout.size(0) // partitions, partitions,
+                                         sparsity_layout.size(1), sparsity_layout.size(2))
+                                .permute(0, 2, 1, 3)
+                                .reshape(sparsity_layout.size(0) // partitions,
+                                         sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
+                                .reshape(-1).contiguous())
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
 
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
 
-class _BlocksparseMerge(torch.autograd.Function):
+    validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
 
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        ctx.save_for_backward(sparsity_layout_o)
-        ctx.num_partitions = num_partitions
-        ctx.dim = dim
+    return lut
 
-        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)
 
-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout = ctx.saved_tensors[0]
-        num_partitions = ctx.num_partitions
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
+# noinspection PyUnusedLocal
+def merge_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
 
-        return split(grad_output, sparsity_layout, num_partitions, dim,
-                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
+    ctx.save_for_backward(sparsity_layout_o)
+    ctx.num_partitions = num_partitions
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
 
 
+merge_forward.register_autograd(merge_backward, setup_context=merge_setup_context)
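
Note on the new lut parameter: split and merge now accept an optional dictionary of precomputed look-up tables. On the first call, split_build_lut/merge_build_lut fill in any missing entries in place, so passing the same dict to later calls with an identical layout and partition count skips the reshape/cumsum work. A minimal sketch of this caching pattern, assuming a CUDA device, an all-dense layout, and that split is re-exported at the blksprs package root (otherwise import it from its defining module):

    import torch
    from blksprs import split  # assumption: re-exported at the package root
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    sparsity_block_size = 32
    # Dense layout: batch of 2, with 2 x 4 sparsity blocks per matrix.
    sparsity_layout = torch.ones(2, 2, 4, dtype=torch.int, device="cuda")

    # Compressed form: one (32, 32) block per nonzero layout entry, 16 total.
    x = BlksprsTensor(torch.randn(16, 32, 32, device="cuda"))

    # The first call populates the dict with "sparsity_layout_output",
    # "sparsity_lut", "sparsity_reverse_lut", and "n_sparse_blocks".
    lut = {}
    x_split, layout_split = split(x, sparsity_layout, partitions=2, dim=2,
                                  sparsity_block_size=sparsity_block_size, lut=lut)

    # Reusing the populated dict avoids rebuilding the look-up tables.
    x_split2, _ = split(x, sparsity_layout, partitions=2, dim=2,
                        sparsity_block_size=sparsity_block_size, lut=lut)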
blksprs/ops/repeat.py CHANGED
@@ -1,16 +1,15 @@
 import torch
-import triton
 from torch import Tensor
+from torch._library import triton_op
 
-from blksprs.ops.flow import kernel_blocksparse_flow_push, flow_forward
+from blksprs.ops.flow import flow_pull_forward, flow_push_forward
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size
 
 
 def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
-           sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
+           sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
         BlksprsTensor, Tensor):
     """Repeats a block-spare tensor in compressed form according to the given repeats.
 
@@ -29,7 +28,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
             third dimension respectively.
         sparsity_block_size (int): The size of the sparsity blocks.
         sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: A block-sparse tensor in compressed form containing the repeated values.
@@ -43,35 +42,16 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
+    lut = repeat_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
-    if sparsity_layout_output is not None:
-        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
-
-    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
-
-    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout_x.size())
-                            .repeat(repeats[0], repeats[1], repeats[2])
-                            .reshape(-1).contiguous())
-
-    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
-
-    return BlksprsTensor(
-        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
+    return BlksprsTensor(repeat_forward(
+        x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
+        lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
 
 
 def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
-                      sparsity_block_size: int, sparsity_layout_output: Tensor = None,
-                      triton_block_size: int = None) -> (
+                      sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
        BlksprsTensor, Tensor):
     """Repeats and interleaves the block-sparse tensor in compressed form.
 
@@ -88,7 +68,7 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
         repeats (int): The number of times to repeat the matrices.
         sparsity_block_size (int): The size of the sparsity blocks.
         sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
@@ -102,83 +82,110 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
+    lut = repeat_interleave_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
+
+    return BlksprsTensor(repeat_forward(
+        x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
+        lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
+
+
+@triton_op("blksprs::repeat", mutates_args={})
+def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
+                   sparsity_reverse_lut: Tensor,
+                   sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                             n_sparse_blocks)
+
+
+def repeat_backward(ctx, grad_output):
+    sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
+    sparsity_block_size = ctx.sparsity_block_size
+    n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
+
+    return flow_push_forward(grad_output, sparsity_layout_o, sparsity_lut,
+                             sparsity_reverse_lut, sparsity_block_size,
+                             n_sparse_blocks), None, None, None, None, None, None
+
+
+def repeat_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
+                     sparsity_layout_output: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_layout_o" not in lut:
+        sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
+        lut["sparsity_layout_o"] = sparsity_layout_o
 
     if sparsity_layout_output is not None:
-        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
-
-    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
-
-    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout_x.size())
-                            .repeat_interleave(repeats, dim=0)
-                            .reshape(-1).contiguous())
-
-    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
-
-    return BlksprsTensor(
-        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
-
-
-class _BlocksparseRepeat(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
-                sparsity_reverse_lut: Tensor,
-                sparsity_block_size: int, n_sparse_blocks: int,
-                triton_block_size: int) -> Tensor:
-        ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
-        ctx.x_size = x.size()
-        ctx.x_stride = stride(x)
-
-        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
-        x_size = ctx.x_size
-        x_stride = ctx.x_stride
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
-
-        n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
-
-        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                             dtype=grad_output.dtype, device=grad_output.device)
-
-        x_b, x_r, x_c = grad_output.size()
-        x_b_s, x_r_s, x_c_s = stride(grad_output)
-        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_o.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_o)
-        s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-        o_b, o_r, o_c = x_size
-        o_b_s, o_r_s, o_c_s = x_stride
-
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
-        triton_grid = lambda meta: [x_b,
-                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-        (kernel_blocksparse_flow_push[triton_grid]
-         (grad_output,
-          x_b, x_b_s, x_r_s, x_c_s,
-          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-          sparsity_reverse_lut,
-          output,
-          o_b, o_b_s, o_r_s, o_c_s,
-          triton_block_size))
-
-        return output, None, None, None, None, None, None, None
+        sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
+        lut["sparsity_layout_o"] = sparsity_layout_o
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+                                .reshape(sparsity_layout_x.size())
+                                .repeat(repeats[0], repeats[1], repeats[2])
+                                .reshape(-1).contiguous())
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+    return lut
+
+
+def repeat_interleave_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: int,
+                                sparsity_layout_output: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_layout_o" not in lut:
+        sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
+        lut["sparsity_layout_o"] = sparsity_layout_o
+
+    if sparsity_layout_output is not None:
+        sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
+        lut["sparsity_layout_o"] = sparsity_layout_o
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+                                .reshape(sparsity_layout_x.size())
+                                .repeat_interleave(repeats, dim=0)
+                                .reshape(-1).contiguous())
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def repeat_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size, _) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+repeat_forward.register_autograd(repeat_backward, setup_context=repeat_setup_context)
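
Both files apply the same migration: each torch.autograd.Function subclass is replaced by a module-level forward registered as a custom op with @triton_op, a separate setup_context callback that stashes saved tensors and attributes, and a backward attached via register_autograd. A minimal sketch of that registration pattern on a toy op (the myops::scale name and the op itself are illustrative, not part of blksprs):

    import torch
    from torch import Tensor
    from torch._library import triton_op  # same private import the diff uses

    # Toy op: scale a tensor by a constant factor.
    @triton_op("myops::scale", mutates_args={})
    def scale_forward(x: Tensor, factor: float) -> Tensor:
        return x * factor

    # Receives the op's inputs and output and stashes what backward needs,
    # mirroring split_setup_context/merge_setup_context/repeat_setup_context.
    def scale_setup_context(ctx, inputs, output):
        _, factor = inputs
        ctx.factor = factor

    # Returns one gradient per input; non-tensor inputs get None, which is
    # why the backward functions above end with a run of Nones.
    def scale_backward(ctx, grad_output):
        return grad_output * ctx.factor, None

    scale_forward.register_autograd(scale_backward, setup_context=scale_setup_context)

    # The registered op then behaves like any differentiable function:
    x = torch.randn(4, requires_grad=True)
    scale_forward(x, 2.0).sum().backward()  # x.grad is filled with 2.0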