blksprs-1.11-py3-none-any.whl → blksprs-2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,16 @@
  import torch
  from torch import Tensor
+ from torch._library import triton_op

- from blksprs.ops.flow import flow_forward_pull
+ from blksprs.ops.flow import flow_pull_forward
  from blksprs.utils.blksprs_tensor import BlksprsTensor
-
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity, validate_sparsity_block_size


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-           dim: int, sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> (
+           dim: int, sparsity_block_size: int, lut: dict = None) -> (
          BlksprsTensor, Tensor):
      """Splits a block-sparse tensor in compressed form along the last dimension into partitions.

@@ -19,7 +20,6 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
          partitions (int): The number of partitions to split the block-sparse tensor into.
          dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
          lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
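
The user-facing change in the hunks above is the dropped triton_block_size parameter: in 2.0 the kernel block size is handled internally. A minimal sketch of a 2.0 call site follows; the import alias, the to_sparse conversion helper, and all shapes are illustrative assumptions, not taken from this diff:

    import torch
    import blksprs as bs  # assumed alias; the module layout is not shown in this diff

    # Hypothetical setup: batch of 2 dense (64, 128) matrices with 32x32 sparsity
    # blocks, converted to compressed block-sparse form.
    x_dense = torch.randn(2, 64, 128, device="cuda")
    sparsity_layout = torch.ones(2, 2, 4, dtype=torch.int, device="cuda")
    x = bs.ops.to_sparse(x_dense, sparsity_layout, sparsity_block_size=32)

    # 1.11 accepted triton_block_size here; 2.0 removes it from the signature.
    x_split, layout_split = bs.ops.split(x, sparsity_layout, partitions=2, dim=2,
                                         sparsity_block_size=32)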
@@ -34,83 +34,88 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

      adjusted_dim = dim % 3
      if adjusted_dim != 2:
          raise NotImplementedError("Currently only supports dim=2")

-     lut = _BlocksparseSplit.build_lut(lut, sparsity_layout, partitions)
+     lut = split_build_lut(lut, sparsity_layout, partitions)
+
+     return BlksprsTensor(split_forward(
+         x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+         partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
+
+
+ @triton_op("blksprs::split_forward", mutates_args={})
+ def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                   _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     with torch.no_grad():
+         return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                  n_sparse_blocks)

-     return BlksprsTensor(
-         _BlocksparseSplit.apply(x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
-                                 partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"],
-                                 triton_block_size)), lut["sparsity_layout_output"]

+ def split_wrapper_backward(ctx, grad_output):
+     sparsity_layout = ctx.saved_tensors[0]
+     num_partitions = ctx.num_partitions
+     dim = ctx.dim
+     sparsity_block_size = ctx.sparsity_block_size

- class _BlocksparseSplit(torch.autograd.Function):
+     return merge(grad_output, sparsity_layout, num_partitions, dim,
+                  sparsity_block_size)[0], None, None, None, None, None, None, None

-     @staticmethod
-     def build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
-         if lut is None:
-             lut = dict()

-         if "sparsity_layout_output" not in lut:
-             sparsity_layout_output = (sparsity_layout
-                                       .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
-                                                sparsity_layout.size(2) // partitions)
-                                       .permute(0, 2, 1, 3)
-                                       .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
-                                                sparsity_layout.size(2) // partitions).contiguous())
-             lut["sparsity_layout_output"] = sparsity_layout_output
+ def split_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
+     if lut is None:
+         lut = dict()

-         if "sparsity_lut" not in lut:
-             sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
-             lut["sparsity_lut"] = sparsity_lut
+     if "sparsity_layout_output" not in lut:
+         sparsity_layout_output = (sparsity_layout
+                                   .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                            sparsity_layout.size(2) // partitions)
+                                   .permute(0, 2, 1, 3)
+                                   .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
+                                            sparsity_layout.size(2) // partitions).contiguous())
+         lut["sparsity_layout_output"] = sparsity_layout_output

-         if "sparsity_reverse_lut" not in lut:
-             sparsity_layout_flat = sparsity_layout.reshape(-1)
-             sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                                      (sparsity_layout_flat == 1) -
-                                      (1 * (sparsity_layout_flat == 0)))
-                                     .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
-                                              sparsity_layout.size(2) // partitions)
-                                     .permute(0, 2, 1, 3).reshape(-1).contiguous())
-             lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+     if "sparsity_lut" not in lut:
+         sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
+         lut["sparsity_lut"] = sparsity_lut

-         if "n_sparse_blocks" not in lut:
-             n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
-             lut["n_sparse_blocks"] = n_sparse_blocks
+     if "sparsity_reverse_lut" not in lut:
+         sparsity_layout_flat = sparsity_layout.reshape(-1)
+         sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                  (sparsity_layout_flat == 1) -
+                                  (1 * (sparsity_layout_flat == 0)))
+                                 .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                          sparsity_layout.size(2) // partitions)
+                                 .permute(0, 2, 1, 3).reshape(-1).contiguous())
+         lut["sparsity_reverse_lut"] = sparsity_reverse_lut

-         validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+     if "n_sparse_blocks" not in lut:
+         n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
+         lut["n_sparse_blocks"] = n_sparse_blocks

-         return lut
+     validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])

-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                 num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
-                 triton_block_size: int) -> Tensor:
-         ctx.save_for_backward(sparsity_layout_o)
-         ctx.num_partitions = num_partitions
-         ctx.dim = dim
+     return lut

-         return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                                  n_sparse_blocks, triton_block_size)

-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout = ctx.saved_tensors[0]
-         num_partitions = ctx.num_partitions
-         dim = ctx.dim
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
+ # noinspection PyUnusedLocal
+ def split_setup_context(ctx, inputs, output):
+     (_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs

-         return merge(grad_output, sparsity_layout, num_partitions, dim,
-                      sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
+     ctx.save_for_backward(sparsity_layout_o)
+     ctx.num_partitions = num_partitions
+     ctx.dim = dim
+     ctx.sparsity_block_size = sparsity_block_size


+ split_forward.register_autograd(split_wrapper_backward, setup_context=split_setup_context)
+
+
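The structural change above replaces the torch.autograd.Function subclass with a module-level op declared via @triton_op and wired to its backward with register_autograd, the documented route for making Triton-backed custom ops composable with torch.compile. The same pattern in isolation, as a minimal sketch (assumes a recent PyTorch where torch.library exposes triton_op and custom ops support register_autograd; the "mylib::scale" op and its plain-tensor body are hypothetical):

    import torch
    from torch.library import triton_op  # the diff uses the private torch._library path

    @triton_op("mylib::scale", mutates_args={})
    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        # A real implementation would invoke a Triton kernel here; plain
        # tensor math keeps the sketch short.
        return x * factor

    def scale_backward(ctx, grad_output):
        # One gradient per op input; non-tensor inputs receive None.
        return grad_output * ctx.factor, None

    def scale_setup_context(ctx, inputs, output):
        # Mirrors split_setup_context above: stash what backward needs.
        _, factor = inputs
        ctx.factor = factor

    scale.register_autograd(scale_backward, setup_context=scale_setup_context)

Gradients then flow through the registered backward rather than through tracing of the op body, which is presumably why split_forward runs its body under torch.no_grad().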
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-           dim: int, sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> (
+           dim: int, sparsity_block_size: int, lut: dict = None) -> (
          BlksprsTensor, Tensor):
      """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.

@@ -120,7 +125,6 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
          partitions (int): The number of partitions to be merged.
          dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
          lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
@@ -135,79 +139,83 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

      adjusted_dim = dim % 3
      if adjusted_dim != 2:
          raise NotImplementedError("Currently only supports dim=2")

-     lut = _BlocksparseMerge.build_lut(lut, sparsity_layout, partitions)
-
-     return BlksprsTensor(
-         _BlocksparseMerge.apply(x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
-                                 partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"],
-                                 triton_block_size)), lut["sparsity_layout_output"]
-
-
- class _BlocksparseMerge(torch.autograd.Function):
-
-     @staticmethod
-     def build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
-         if lut is None:
-             lut = dict()
-
-         if "sparsity_layout_output" not in lut:
-             sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
-                                                               sparsity_layout.size(1), sparsity_layout.size(2))
-                                       .permute(0, 2, 1, 3)
-                                       .reshape(sparsity_layout.size(0) // partitions,
-                                                sparsity_layout.size(1),
-                                                sparsity_layout.size(2) * partitions).contiguous())
-             lut["sparsity_layout_output"] = sparsity_layout_output
-
-         if "sparsity_lut" not in lut:
-             sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
-             lut["sparsity_lut"] = sparsity_lut
-
-         if "sparsity_reverse_lut" not in lut:
-             sparsity_layout_flat = sparsity_layout.reshape(-1)
-             sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                                      (sparsity_layout_flat == 1) -
-                                      (1 * (sparsity_layout_flat == 0)))
-                                     .reshape(sparsity_layout.size(0) // partitions, partitions,
-                                              sparsity_layout.size(1), sparsity_layout.size(2))
-                                     .permute(0, 2, 1, 3)
-                                     .reshape(sparsity_layout.size(0) // partitions,
-                                              sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
-                                     .reshape(-1).contiguous())
-             lut["sparsity_reverse_lut"] = sparsity_reverse_lut
-
-         if "n_sparse_blocks" not in lut:
-             n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
-             lut["n_sparse_blocks"] = n_sparse_blocks
-
-         validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
-
-         return lut
-
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                 num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
-                 triton_block_size: int) -> Tensor:
-         ctx.save_for_backward(sparsity_layout_o)
-         ctx.num_partitions = num_partitions
-         ctx.dim = dim
-
-         return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                                  n_sparse_blocks, triton_block_size)
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout = ctx.saved_tensors[0]
-         num_partitions = ctx.num_partitions
-         dim = ctx.dim
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         return split(grad_output, sparsity_layout, num_partitions, dim,
-                      sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
+     lut = merge_build_lut(lut, sparsity_layout, partitions)
+
+     return BlksprsTensor(merge_forward(
+         x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+         partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
+
+
+ @triton_op("blksprs::merge_forward", mutates_args={})
+ def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                   _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     with torch.no_grad():
+         return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                  n_sparse_blocks)
+
+
+ def merge_wrapper_backward(ctx, grad_output):
+     sparsity_layout = ctx.saved_tensors[0]
+     num_partitions = ctx.num_partitions
+     dim = ctx.dim
+     sparsity_block_size = ctx.sparsity_block_size
+
+     return split(grad_output, sparsity_layout, num_partitions, dim,
+                  sparsity_block_size)[0], None, None, None, None, None, None, None
+
+
+ def merge_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_layout_output" not in lut:
+         sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
+                                                           sparsity_layout.size(1), sparsity_layout.size(2))
+                                   .permute(0, 2, 1, 3)
+                                   .reshape(sparsity_layout.size(0) // partitions,
+                                            sparsity_layout.size(1),
+                                            sparsity_layout.size(2) * partitions).contiguous())
+         lut["sparsity_layout_output"] = sparsity_layout_output
+
+     if "sparsity_lut" not in lut:
+         sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
+         lut["sparsity_lut"] = sparsity_lut
+
+     if "sparsity_reverse_lut" not in lut:
+         sparsity_layout_flat = sparsity_layout.reshape(-1)
+         sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                  (sparsity_layout_flat == 1) -
+                                  (1 * (sparsity_layout_flat == 0)))
+                                 .reshape(sparsity_layout.size(0) // partitions, partitions,
+                                          sparsity_layout.size(1), sparsity_layout.size(2))
+                                 .permute(0, 2, 1, 3)
+                                 .reshape(sparsity_layout.size(0) // partitions,
+                                          sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
+                                 .reshape(-1).contiguous())
+         lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+     if "n_sparse_blocks" not in lut:
+         n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
+         lut["n_sparse_blocks"] = n_sparse_blocks
+
+     validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def merge_setup_context(ctx, inputs, output):
+     (_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
+
+     ctx.save_for_backward(sparsity_layout_o)
+     ctx.num_partitions = num_partitions
+     ctx.dim = dim
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ merge_forward.register_autograd(merge_wrapper_backward, setup_context=merge_setup_context)
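
Both public entry points in this file now carry @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16). Under an active autocast region this casts floating-point CUDA tensor arguments to float16 and runs the body with autocast locally disabled; outside autocast it is a no-op. A self-contained sketch of that behavior (the decorated function here is purely illustrative):

    import torch

    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
    def double(x: torch.Tensor) -> torch.Tensor:
        # Inside an autocast region, x has already been cast to float16 here.
        return x * 2

    x = torch.randn(8, device="cuda")    # float32
    y_plain = double(x)                  # decorator inactive: stays float32
    with torch.autocast(device_type="cuda"):
        y_amp = double(x)                # body sees a float16 input
    print(y_plain.dtype, y_amp.dtype)    # torch.float32 torch.float16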
blksprs/ops/repeat.py CHANGED
@@ -1,17 +1,16 @@
  import torch
- import triton
  from torch import Tensor
+ from torch._library import triton_op

- from blksprs.ops.flow import kernel_blocksparse_flow_push, flow_forward_pull, flow_forward_push
+ from blksprs.ops.flow import flow_pull_forward, flow_push_forward
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity, validate_sparsity_block_size


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
-            sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None,
-            lut: dict = None) -> (
+            sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
          BlksprsTensor, Tensor):
      """Repeats a block-spare tensor in compressed form according to the given repeats.

@@ -30,7 +29,6 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
              third dimension respectively.
          sparsity_block_size (int): The size of the sparsity blocks.
          sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
          lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
@@ -45,19 +43,17 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

-     lut = _BlocksparseRepeat.build_lut_repeat(lut, sparsity_layout_x, repeats, sparsity_layout_output)
+     lut = repeat_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)

-     return BlksprsTensor(
-         _BlocksparseRepeat.apply(x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
-                                  lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"],
-                                  triton_block_size)), lut["sparsity_layout_o"]
+     return BlksprsTensor(repeat_forward(
+         x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
+         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
-                       sparsity_block_size: int, sparsity_layout_output: Tensor = None,
-                       triton_block_size: int = None, lut: dict = None) -> (
+                       sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
          BlksprsTensor, Tensor):
      """Repeats and interleaves the block-sparse tensor in compressed form.

@@ -74,7 +70,6 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
          repeats (int): The number of times to repeat the matrices.
          sparsity_block_size (int): The size of the sparsity blocks.
          sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
          lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
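
repeat and repeat_interleave differ only in how the batch dimension is expanded, and the LUT builders in the hunk below mirror that difference one-to-one on the sparsity layout (layout.repeat(...) versus torch.repeat_interleave(..., dim=0)). A layout-only sketch of the two orderings, with a made-up layout:

    import torch

    # Illustrative (B=2, 1, 2) sparsity layout; entries mark which blocks exist.
    layout = torch.tensor([[[1, 0]], [[0, 1]]], dtype=torch.int)

    # repeat with repeats=(2, 1, 1) tiles whole batches: b0, b1, b0, b1.
    tiled = layout.repeat(2, 1, 1)

    # repeat_interleave with repeats=2 duplicates each batch in place: b0, b0, b1, b1.
    interleaved = torch.repeat_interleave(layout, 2, dim=0)

    print(tiled.size(), interleaved.size())  # both torch.Size([4, 1, 2]), rows ordered differently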
@@ -89,108 +84,111 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

-     lut = _BlocksparseRepeat.build_lut_repeat_interleave(lut, sparsity_layout_x, repeats, sparsity_layout_output)
-
-     return BlksprsTensor(
-         _BlocksparseRepeat.apply(x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
-                                  lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"],
-                                  triton_block_size)), lut["sparsity_layout_o"]
-
-
- class _BlocksparseRepeat(torch.autograd.Function):
-
-     @staticmethod
-     def build_lut_repeat(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
-                          sparsity_layout_output: Tensor):
-         if lut is None:
-             lut = dict()
-
-         if "sparsity_layout_o" not in lut:
-             sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
-             lut["sparsity_layout_o"] = sparsity_layout_o
-
-         if sparsity_layout_output is not None:
-             sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
-             lut["sparsity_layout_o"] = sparsity_layout_o
-
-         if "sparsity_lut" not in lut:
-             sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
-             lut["sparsity_lut"] = sparsity_lut
-
-         if "sparsity_reverse_lut" not in lut:
-             sparsity_layout_flat = sparsity_layout_x.reshape(-1)
-             sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                                      (sparsity_layout_flat == 1) -
-                                      (1 * (sparsity_layout_flat == 0)))
-                                     .reshape(sparsity_layout_x.size())
-                                     .repeat(repeats[0], repeats[1], repeats[2])
-                                     .reshape(-1).contiguous())
-             lut["sparsity_reverse_lut"] = sparsity_reverse_lut
-
-         if "n_sparse_blocks" not in lut:
-             n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
-             lut["n_sparse_blocks"] = n_sparse_blocks
-
-         validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
-
-         return lut
-
-     @staticmethod
-     def build_lut_repeat_interleave(lut: dict, sparsity_layout_x: Tensor, repeats: int,
-                                     sparsity_layout_output: Tensor):
-         if lut is None:
-             lut = dict()
-
-         if "sparsity_layout_o" not in lut:
-             sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
-             lut["sparsity_layout_o"] = sparsity_layout_o
-
-         if sparsity_layout_output is not None:
-             sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
-             lut["sparsity_layout_o"] = sparsity_layout_o
-
-         if "sparsity_lut" not in lut:
-             sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
-             lut["sparsity_lut"] = sparsity_lut
-
-         if "sparsity_reverse_lut" not in lut:
-             sparsity_layout_flat = sparsity_layout_x.reshape(-1)
-             sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                                      (sparsity_layout_flat == 1) -
-                                      (1 * (sparsity_layout_flat == 0)))
-                                     .reshape(sparsity_layout_x.size())
-                                     .repeat_interleave(repeats, dim=0)
-                                     .reshape(-1).contiguous())
-             lut["sparsity_reverse_lut"] = sparsity_reverse_lut
-
-         if "n_sparse_blocks" not in lut:
-             n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
-             lut["n_sparse_blocks"] = n_sparse_blocks
-
-         validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
-
-         return lut
-
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
-                 sparsity_reverse_lut: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int,
-                 triton_block_size: int) -> Tensor:
-         ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
-
-         return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                                  n_sparse_blocks, triton_block_size)
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
-
-         return flow_forward_push(None, grad_output, sparsity_layout_o, sparsity_lut,
-                                  sparsity_reverse_lut, sparsity_block_size, n_sparse_blocks,
-                                  triton_block_size), None, None, None, None, None, None, None
+     lut = repeat_interleave_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
+
+     return BlksprsTensor(repeat_forward(
+         x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
+         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
+
+
+ @triton_op("blksprs::repeat_forward", mutates_args={})
+ def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
+                    sparsity_reverse_lut: Tensor,
+                    sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     with torch.no_grad():
+         return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                  n_sparse_blocks)
+
+
+ def repeat_wrapper_backward(ctx, grad_output):
+     sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
+     sparsity_block_size = ctx.sparsity_block_size
+     n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
+
+     return flow_push_forward(grad_output, sparsity_layout_o, sparsity_lut,
+                              sparsity_reverse_lut, sparsity_block_size,
+                              n_sparse_blocks), None, None, None, None, None, None
+
+
+ def repeat_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
+                      sparsity_layout_output: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_layout_o" not in lut:
+         sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
+         lut["sparsity_layout_o"] = sparsity_layout_o
+
+     if sparsity_layout_output is not None:
+         sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
+         lut["sparsity_layout_o"] = sparsity_layout_o
+
+     if "sparsity_lut" not in lut:
+         sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
+         lut["sparsity_lut"] = sparsity_lut
+
+     if "sparsity_reverse_lut" not in lut:
+         sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+         sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                  (sparsity_layout_flat == 1) -
+                                  (1 * (sparsity_layout_flat == 0)))
+                                 .reshape(sparsity_layout_x.size())
+                                 .repeat(repeats[0], repeats[1], repeats[2])
+                                 .reshape(-1).contiguous())
+         lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+     if "n_sparse_blocks" not in lut:
+         n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
+         lut["n_sparse_blocks"] = n_sparse_blocks
+
+     validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+     return lut
+
+
+ def repeat_interleave_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: int,
+                                 sparsity_layout_output: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_layout_o" not in lut:
+         sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
+         lut["sparsity_layout_o"] = sparsity_layout_o
+
+     if sparsity_layout_output is not None:
+         sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
+         lut["sparsity_layout_o"] = sparsity_layout_o
+
+     if "sparsity_lut" not in lut:
+         sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
+         lut["sparsity_lut"] = sparsity_lut
+
+     if "sparsity_reverse_lut" not in lut:
+         sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+         sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                  (sparsity_layout_flat == 1) -
+                                  (1 * (sparsity_layout_flat == 0)))
+                                 .reshape(sparsity_layout_x.size())
+                                 .repeat_interleave(repeats, dim=0)
+                                 .reshape(-1).contiguous())
+         lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+     if "n_sparse_blocks" not in lut:
+         n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
+         lut["n_sparse_blocks"] = n_sparse_blocks
+
+     validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def repeat_setup_context(ctx, inputs, output):
+     (_, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size, _) = inputs
+
+     ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ repeat_forward.register_autograd(repeat_wrapper_backward, setup_context=repeat_setup_context)
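
One detail that holds across both files: every *_build_lut helper fills only the keys that are missing and mutates the passed dictionary in place, so a caller can hand the same lut dict to repeated invocations over an unchanged layout and pay the look-up-table construction cost once. A hypothetical call site (batches, layout, and block size are made up for illustration):

    lut = {}  # populated in place on the first call, reused afterwards
    for x in batches:  # compressed block-sparse tensors sharing sparsity_layout_x
        out, layout_o = repeat(x, sparsity_layout_x, (2, 1, 1),
                               sparsity_block_size=32, lut=lut)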