blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,15 +1,17 @@
 import torch
 from torch import Tensor
+from torch._library import triton_op

-from blksprs.ops.flow import flow_forward
+from blksprs.ops.flow import flow_pull_forward
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-          dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+          dim: int, sparsity_block_size: int, lut: dict = None) -> (
+        BlksprsTensor, Tensor):
     """Splits a block-sparse tensor in compressed form along the last dimension into partitions.

     Args:
@@ -18,7 +20,7 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
         partitions (int): The number of partitions to split the block-sparse tensor into.
         dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The block-sparse tensor split into partitions in compressed form.
@@ -32,63 +34,89 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)

-    sparsity_layout_output = (sparsity_layout
-                              .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
-                                       sparsity_layout.size(2) // partitions)
-                              .permute(0, 2, 1, 3)
-                              .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
-                                       sparsity_layout.size(2) // partitions).contiguous())
+    adjusted_dim = dim % 3
+    if adjusted_dim != 2:
+        raise NotImplementedError("Currently only supports dim=2")

-    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+    lut = split_build_lut(lut, sparsity_layout, partitions)

-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
-                                     sparsity_layout.size(2) // partitions)
-                            .permute(0, 2, 1, 3).reshape(-1).contiguous())
+    return BlksprsTensor(split_forward(
+        x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+        partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]

-    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()

-    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+@triton_op("blksprs::split_forward", mutates_args={})
+def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                  _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks)

-    adjusted_dim = dim % 3
-    if adjusted_dim != 2:
-        raise NotImplementedError("Currently only supports dim=2")

-    return BlksprsTensor(_BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
-                                                 adjusted_dim, sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
+def split_wrapper_backward(ctx, grad_output):
+    sparsity_layout = ctx.saved_tensors[0]
+    num_partitions = ctx.num_partitions
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return merge(grad_output, sparsity_layout, num_partitions, dim,
+                 sparsity_block_size)[0], None, None, None, None, None, None, None
+
+
+def split_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_layout_output" not in lut:
+        sparsity_layout_output = (sparsity_layout
+                                  .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                           sparsity_layout.size(2) // partitions)
+                                  .permute(0, 2, 1, 3)
+                                  .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
+                                           sparsity_layout.size(2) // partitions).contiguous())
+        lut["sparsity_layout_output"] = sparsity_layout_output
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout.reshape(-1)
+        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+                                .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                         sparsity_layout.size(2) // partitions)
+                                .permute(0, 2, 1, 3).reshape(-1).contiguous())
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks

+    validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])

-class _BlocksparseSplit(torch.autograd.Function):
+    return lut

-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        ctx.save_for_backward(sparsity_layout_o)
-        ctx.num_partitions = num_partitions
-        ctx.dim = dim

-        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)
+# noinspection PyUnusedLocal
+def split_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs

-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout = ctx.saved_tensors[0]
-        num_partitions = ctx.num_partitions
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
+    ctx.save_for_backward(sparsity_layout_o)
+    ctx.num_partitions = num_partitions
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size

-        return merge(grad_output, sparsity_layout, num_partitions, dim,
-                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None

+split_forward.register_autograd(split_wrapper_backward, setup_context=split_setup_context)

+
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-          dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+          dim: int, sparsity_block_size: int, lut: dict = None) -> (
+        BlksprsTensor, Tensor):
     """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.

     Args:
@@ -97,7 +125,7 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
         partitions (int): The number of partitions to be merged.
         dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The merged block-sparse tensor in compressed form.
@@ -111,60 +139,83 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)

-    sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
-                                                      sparsity_layout.size(1), sparsity_layout.size(2))
-                              .permute(0, 2, 1, 3)
-                              .reshape(sparsity_layout.size(0) // partitions,
-                                       sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
+    adjusted_dim = dim % 3
+    if adjusted_dim != 2:
+        raise NotImplementedError("Currently only supports dim=2")

-    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+    lut = merge_build_lut(lut, sparsity_layout, partitions)

-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout.size(0) // partitions, partitions,
-                                     sparsity_layout.size(1), sparsity_layout.size(2))
-                            .permute(0, 2, 1, 3)
-                            .reshape(sparsity_layout.size(0) // partitions,
-                                     sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
-                            .reshape(-1).contiguous())
+    return BlksprsTensor(merge_forward(
+        x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+        partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]

-    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()

-    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+@triton_op("blksprs::merge_forward", mutates_args={})
+def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                  _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks)

-    adjusted_dim = dim % 3
-    if adjusted_dim != 2:
-        raise NotImplementedError("Currently only supports dim=2")

-    return BlksprsTensor(_BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
-                                                 adjusted_dim, sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
+def merge_wrapper_backward(ctx, grad_output):
+    sparsity_layout = ctx.saved_tensors[0]
+    num_partitions = ctx.num_partitions
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return split(grad_output, sparsity_layout, num_partitions, dim,
+                 sparsity_block_size)[0], None, None, None, None, None, None, None
+
+
+def merge_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_layout_output" not in lut:
+        sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
+                                                          sparsity_layout.size(1), sparsity_layout.size(2))
+                                  .permute(0, 2, 1, 3)
+                                  .reshape(sparsity_layout.size(0) // partitions,
+                                           sparsity_layout.size(1),
+                                           sparsity_layout.size(2) * partitions).contiguous())
+        lut["sparsity_layout_output"] = sparsity_layout_output
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout.reshape(-1)
+        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+                                .reshape(sparsity_layout.size(0) // partitions, partitions,
+                                         sparsity_layout.size(1), sparsity_layout.size(2))
+                                .permute(0, 2, 1, 3)
+                                .reshape(sparsity_layout.size(0) // partitions,
+                                         sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
+                                .reshape(-1).contiguous())
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut

+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks

-class _BlocksparseMerge(torch.autograd.Function):
+    validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])

-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        ctx.save_for_backward(sparsity_layout_o)
-        ctx.num_partitions = num_partitions
-        ctx.dim = dim
+    return lut

-        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)

-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout = ctx.saved_tensors[0]
-        num_partitions = ctx.num_partitions
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
+# noinspection PyUnusedLocal
+def merge_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs

-        return split(grad_output, sparsity_layout, num_partitions, dim,
-                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
+    ctx.save_for_backward(sparsity_layout_o)
+    ctx.num_partitions = num_partitions
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size


+merge_forward.register_autograd(merge_wrapper_backward, setup_context=merge_setup_context)
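
Note: the caller-visible change in the hunks above is that `triton_block_size` is gone from the public signatures and each op instead takes an optional `lut` dict that caches the derived look-up tables (`sparsity_layout_output`, `sparsity_lut`, `sparsity_reverse_lut`, `n_sparse_blocks`) between calls. Because `split_build_lut` and `merge_build_lut` only fill in missing keys and mutate the dict they are handed, passing the same dict on every call amortises the LUT construction across iterations. A minimal usage sketch follows; the `bs.split` import path, shapes, and values are illustrative assumptions, not taken from this diff:

    import torch
    import blksprs as bs  # assumed top-level re-export; adjust to the actual package layout

    sparsity_block_size = 32
    # Layout of present blocks: 2 batches of 4x4 block positions, all present.
    sparsity_layout = torch.ones(2, 4, 4, dtype=torch.bool, device="cuda")
    # Compressed form: one 32x32 tile per present block (2 * 4 * 4 = 32 tiles).
    x = torch.randn(32, sparsity_block_size, sparsity_block_size, device="cuda")

    lut_cache = {}  # filled by split_build_lut on the first call, reused afterwards
    for _ in range(3):
        # Later iterations find all four keys in lut_cache and skip rebuilding them.
        x_split, sparsity_layout_split = bs.split(x, sparsity_layout, 2, 2,
                                                  sparsity_block_size, lut=lut_cache)
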
blksprs/ops/repeat.py CHANGED
@@ -1,16 +1,16 @@
 import torch
-import triton
 from torch import Tensor
+from torch._library import triton_op

-from blksprs.ops.flow import kernel_blocksparse_flow_push, flow_forward
+from blksprs.ops.flow import flow_pull_forward, flow_push_forward
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
-           sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
+           sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
         BlksprsTensor, Tensor):
     """Repeats a block-sparse tensor in compressed form according to the given repeats.

@@ -29,7 +29,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
             third dimension respectively.
         sparsity_block_size (int): The size of the sparsity blocks.
         sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: A block-sparse tensor in compressed form containing the repeated values.
@@ -43,35 +43,17 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)

-    sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
+    lut = repeat_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)

-    if sparsity_layout_output is not None:
-        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
-
-    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
-
-    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout_x.size())
-                            .repeat(repeats[0], repeats[1], repeats[2])
-                            .reshape(-1).contiguous())
-
-    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
-
-    return BlksprsTensor(
-        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
+    return BlksprsTensor(repeat_forward(
+        x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
+        lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
-                      sparsity_block_size: int, sparsity_layout_output: Tensor = None,
-                      triton_block_size: int = None) -> (
+                      sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
         BlksprsTensor, Tensor):
     """Repeats and interleaves the block-sparse tensor in compressed form.

@@ -88,7 +70,7 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
         repeats (int): The number of times to repeat the matrices.
         sparsity_block_size (int): The size of the sparsity blocks.
         sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
@@ -102,83 +84,111 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)

-    sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
+    lut = repeat_interleave_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
+
+    return BlksprsTensor(repeat_forward(
+        x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
+        lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
+
+
+@triton_op("blksprs::repeat_forward", mutates_args={})
+def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
+                   sparsity_reverse_lut: Tensor,
+                   sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks)
+
+
+def repeat_wrapper_backward(ctx, grad_output):
+    sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
+    sparsity_block_size = ctx.sparsity_block_size
+    n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
+
+    return flow_push_forward(grad_output, sparsity_layout_o, sparsity_lut,
+                             sparsity_reverse_lut, sparsity_block_size,
+                             n_sparse_blocks), None, None, None, None, None, None
+
+
+def repeat_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
+                     sparsity_layout_output: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_layout_o" not in lut:
+        sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
+        lut["sparsity_layout_o"] = sparsity_layout_o

     if sparsity_layout_output is not None:
-        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
-
-    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
-
-    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout_x.size())
-                            .repeat_interleave(repeats, dim=0)
-                            .reshape(-1).contiguous())
-
-    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
-
-    return BlksprsTensor(
-        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
-
-
-class _BlocksparseRepeat(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
-                sparsity_reverse_lut: Tensor,
-                sparsity_block_size: int, n_sparse_blocks: int,
-                triton_block_size: int) -> Tensor:
-        ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
-        ctx.x_size = x.size()
-        ctx.x_stride = stride(x)
-
-        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
-        x_size = ctx.x_size
-        x_stride = ctx.x_stride
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
-
-        n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
-
-        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                             dtype=grad_output.dtype, device=grad_output.device)
-
-        x_b, x_r, x_c = grad_output.size()
-        x_b_s, x_r_s, x_c_s = stride(grad_output)
-        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_o.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_o)
-        s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-        o_b, o_r, o_c = x_size
-        o_b_s, o_r_s, o_c_s = x_stride
-
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
-        triton_grid = lambda meta: [x_b,
-                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-        (kernel_blocksparse_flow_push[triton_grid]
-         (grad_output,
-          x_b, x_b_s, x_r_s, x_c_s,
-          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-          sparsity_reverse_lut,
-          output,
-          o_b, o_b_s, o_r_s, o_c_s,
-          triton_block_size))
-
-        return output, None, None, None, None, None, None, None
+        sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
+        lut["sparsity_layout_o"] = sparsity_layout_o
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+                                .reshape(sparsity_layout_x.size())
+                                .repeat(repeats[0], repeats[1], repeats[2])
+                                .reshape(-1).contiguous())
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+    return lut
+
+
+def repeat_interleave_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: int,
+                                sparsity_layout_output: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_layout_o" not in lut:
+        sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
+        lut["sparsity_layout_o"] = sparsity_layout_o
+
+    if sparsity_layout_output is not None:
+        sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
+        lut["sparsity_layout_o"] = sparsity_layout_o
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+                                .reshape(sparsity_layout_x.size())
+                                .repeat_interleave(repeats, dim=0)
+                                .reshape(-1).contiguous())
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def repeat_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size, _) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+repeat_forward.register_autograd(repeat_wrapper_backward, setup_context=repeat_setup_context)
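
Note: all three ops in this release follow the same registration pattern: the `torch.autograd.Function` subclasses are replaced by module-level forward functions registered as custom operators, with the backward attached via `register_autograd` and the autograd context populated in a separate `setup_context` hook. Below is a self-contained sketch of that wiring using the public `torch.library.custom_op` decorator; the diff itself uses the internal `torch._library.triton_op` variant, and `example::scale` with its `factor` argument is invented purely for illustration:

    import torch
    from torch import Tensor


    @torch.library.custom_op("example::scale", mutates_args=())
    def scale(x: Tensor, factor: float) -> Tensor:
        # The registered forward runs without autograd tracking; gradients
        # come from the separately attached backward, as with split_forward.
        return x * factor


    # noinspection PyUnusedLocal
    def scale_setup_context(ctx, inputs, output):
        # Mirrors split_setup_context/merge_setup_context: stash what backward needs.
        _, factor = inputs
        ctx.factor = factor


    def scale_backward(ctx, grad_output):
        # One gradient slot per op input; the non-tensor `factor` gets None,
        # matching the trailing None placeholders in the wrapper backwards above.
        return grad_output * ctx.factor, None


    scale.register_autograd(scale_backward, setup_context=scale_setup_context)

    x = torch.randn(4, requires_grad=True)
    scale(x, 3.0).sum().backward()  # x.grad is now a tensor of 3.0s

One plausible motivation for the migration is that operators registered this way are visible to `torch.compile` and the dispatcher, whereas ad-hoc `autograd.Function` wrappers around Triton kernels generally are not.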