blksprs 1.11__py3-none-any.whl → 2.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,15 @@
1
1
  import torch
2
2
  from torch import Tensor
3
+ from torch._library import triton_op
3
4
 
4
- from blksprs.ops.flow import flow_forward_pull
5
+ from blksprs.ops.flow import flow_pull_forward
5
6
  from blksprs.utils.blksprs_tensor import BlksprsTensor
6
-
7
7
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
8
- validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
8
+ validate_sparsity, validate_sparsity_block_size
9
9
 
10
10
 
11
11
  def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
12
- dim: int, sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> (
12
+ dim: int, sparsity_block_size: int, lut: dict = None) -> (
13
13
  BlksprsTensor, Tensor):
14
14
  """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
15
15
 
@@ -19,7 +19,6 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
19
19
  partitions (int): The number of partitions to split the block-sparse tensor into.
20
20
  dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
21
21
  sparsity_block_size (int): The size of the sparsity blocks.
22
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
23
22
  lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
24
23
 
25
24
  Returns:
@@ -34,83 +33,86 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
34
33
  validate_device(x)
35
34
  validate_sparsity(sparsity_block_size, (x, sparsity_layout))
36
35
  validate_sparsity_block_size(sparsity_block_size, x)
37
- validate_triton_block_size(triton_block_size, sparsity_block_size)
38
36
 
39
37
  adjusted_dim = dim % 3
40
38
  if adjusted_dim != 2:
41
39
  raise NotImplementedError("Currently only supports dim=2")
42
40
 
43
- lut = _BlocksparseSplit.build_lut(lut, sparsity_layout, partitions)
41
+ lut = split_build_lut(lut, sparsity_layout, partitions)
42
+
43
+ return BlksprsTensor(split_forward(
44
+ x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
45
+ partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
46
+
47
+
48
+ @triton_op("blksprs::split", mutates_args={})
49
+ def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
50
+ _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
51
+ return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
52
+ n_sparse_blocks)
44
53
 
45
- return BlksprsTensor(
46
- _BlocksparseSplit.apply(x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
47
- partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"],
48
- triton_block_size)), lut["sparsity_layout_output"]
49
54
 
55
+ def split_backward(ctx, grad_output):
56
+ sparsity_layout = ctx.saved_tensors[0]
57
+ num_partitions = ctx.num_partitions
58
+ dim = ctx.dim
59
+ sparsity_block_size = ctx.sparsity_block_size
50
60
 
51
- class _BlocksparseSplit(torch.autograd.Function):
61
+ return merge(grad_output, sparsity_layout, num_partitions, dim,
62
+ sparsity_block_size)[0], None, None, None, None, None, None, None
52
63
 
53
- @staticmethod
54
- def build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
55
- if lut is None:
56
- lut = dict()
57
64
 
58
- if "sparsity_layout_output" not in lut:
59
- sparsity_layout_output = (sparsity_layout
60
- .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
61
- sparsity_layout.size(2) // partitions)
62
- .permute(0, 2, 1, 3)
63
- .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
64
- sparsity_layout.size(2) // partitions).contiguous())
65
- lut["sparsity_layout_output"] = sparsity_layout_output
65
+ def split_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
66
+ if lut is None:
67
+ lut = dict()
66
68
 
67
- if "sparsity_lut" not in lut:
68
- sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
69
- lut["sparsity_lut"] = sparsity_lut
69
+ if "sparsity_layout_output" not in lut:
70
+ sparsity_layout_output = (sparsity_layout
71
+ .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
72
+ sparsity_layout.size(2) // partitions)
73
+ .permute(0, 2, 1, 3)
74
+ .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
75
+ sparsity_layout.size(2) // partitions).contiguous())
76
+ lut["sparsity_layout_output"] = sparsity_layout_output
70
77
 
71
- if "sparsity_reverse_lut" not in lut:
72
- sparsity_layout_flat = sparsity_layout.reshape(-1)
73
- sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
74
- (sparsity_layout_flat == 1) -
75
- (1 * (sparsity_layout_flat == 0)))
76
- .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
77
- sparsity_layout.size(2) // partitions)
78
- .permute(0, 2, 1, 3).reshape(-1).contiguous())
79
- lut["sparsity_reverse_lut"] = sparsity_reverse_lut
78
+ if "sparsity_lut" not in lut:
79
+ sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
80
+ lut["sparsity_lut"] = sparsity_lut
80
81
 
81
- if "n_sparse_blocks" not in lut:
82
- n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
83
- lut["n_sparse_blocks"] = n_sparse_blocks
82
+ if "sparsity_reverse_lut" not in lut:
83
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
84
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
85
+ (sparsity_layout_flat == 1) -
86
+ (1 * (sparsity_layout_flat == 0)))
87
+ .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
88
+ sparsity_layout.size(2) // partitions)
89
+ .permute(0, 2, 1, 3).reshape(-1).contiguous())
90
+ lut["sparsity_reverse_lut"] = sparsity_reverse_lut
84
91
 
85
- validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
92
+ if "n_sparse_blocks" not in lut:
93
+ n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
94
+ lut["n_sparse_blocks"] = n_sparse_blocks
86
95
 
87
- return lut
96
+ validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
88
97
 
89
- @staticmethod
90
- def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
91
- num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
92
- triton_block_size: int) -> Tensor:
93
- ctx.save_for_backward(sparsity_layout_o)
94
- ctx.num_partitions = num_partitions
95
- ctx.dim = dim
98
+ return lut
96
99
 
97
- return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
98
- n_sparse_blocks, triton_block_size)
99
100
 
100
- @staticmethod
101
- def backward(ctx, grad_output):
102
- sparsity_layout = ctx.saved_tensors[0]
103
- num_partitions = ctx.num_partitions
104
- dim = ctx.dim
105
- sparsity_block_size = ctx.sparsity_block_size
106
- triton_block_size = ctx.triton_block_size
101
+ # noinspection PyUnusedLocal
102
+ def split_setup_context(ctx, inputs, output):
103
+ (_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
107
104
 
108
- return merge(grad_output, sparsity_layout, num_partitions, dim,
109
- sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
105
+ ctx.save_for_backward(sparsity_layout_o)
106
+ ctx.num_partitions = num_partitions
107
+ ctx.dim = dim
108
+ ctx.sparsity_block_size = sparsity_block_size
109
+
110
+
111
+ split_forward.register_autograd(split_backward, setup_context=split_setup_context)
110
112
 
111
113
 
112
114
  def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
113
- dim: int, sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> (
115
+ dim: int, sparsity_block_size: int, lut: dict = None) -> (
114
116
  BlksprsTensor, Tensor):
115
117
  """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
116
118
 
@@ -120,7 +122,6 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
120
122
  partitions (int): The number of partitions to be merged.
121
123
  dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
122
124
  sparsity_block_size (int): The size of the sparsity blocks.
123
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
124
125
  lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
125
126
 
126
127
  Returns:
@@ -135,79 +136,82 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
135
136
  validate_device(x)
136
137
  validate_sparsity(sparsity_block_size, (x, sparsity_layout))
137
138
  validate_sparsity_block_size(sparsity_block_size, x)
138
- validate_triton_block_size(triton_block_size, sparsity_block_size)
139
139
 
140
140
  adjusted_dim = dim % 3
141
141
  if adjusted_dim != 2:
142
142
  raise NotImplementedError("Currently only supports dim=2")
143
143
 
144
- lut = _BlocksparseMerge.build_lut(lut, sparsity_layout, partitions)
145
-
146
- return BlksprsTensor(
147
- _BlocksparseMerge.apply(x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
148
- partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"],
149
- triton_block_size)), lut["sparsity_layout_output"]
150
-
151
-
152
- class _BlocksparseMerge(torch.autograd.Function):
153
-
154
- @staticmethod
155
- def build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
156
- if lut is None:
157
- lut = dict()
158
-
159
- if "sparsity_layout_output" not in lut:
160
- sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
161
- sparsity_layout.size(1), sparsity_layout.size(2))
162
- .permute(0, 2, 1, 3)
163
- .reshape(sparsity_layout.size(0) // partitions,
164
- sparsity_layout.size(1),
165
- sparsity_layout.size(2) * partitions).contiguous())
166
- lut["sparsity_layout_output"] = sparsity_layout_output
167
-
168
- if "sparsity_lut" not in lut:
169
- sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
170
- lut["sparsity_lut"] = sparsity_lut
171
-
172
- if "sparsity_reverse_lut" not in lut:
173
- sparsity_layout_flat = sparsity_layout.reshape(-1)
174
- sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
175
- (sparsity_layout_flat == 1) -
176
- (1 * (sparsity_layout_flat == 0)))
177
- .reshape(sparsity_layout.size(0) // partitions, partitions,
178
- sparsity_layout.size(1), sparsity_layout.size(2))
179
- .permute(0, 2, 1, 3)
180
- .reshape(sparsity_layout.size(0) // partitions,
181
- sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
182
- .reshape(-1).contiguous())
183
- lut["sparsity_reverse_lut"] = sparsity_reverse_lut
184
-
185
- if "n_sparse_blocks" not in lut:
186
- n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
187
- lut["n_sparse_blocks"] = n_sparse_blocks
188
-
189
- validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
190
-
191
- return lut
192
-
193
- @staticmethod
194
- def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
195
- num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
196
- triton_block_size: int) -> Tensor:
197
- ctx.save_for_backward(sparsity_layout_o)
198
- ctx.num_partitions = num_partitions
199
- ctx.dim = dim
200
-
201
- return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
202
- n_sparse_blocks, triton_block_size)
203
-
204
- @staticmethod
205
- def backward(ctx, grad_output):
206
- sparsity_layout = ctx.saved_tensors[0]
207
- num_partitions = ctx.num_partitions
208
- dim = ctx.dim
209
- sparsity_block_size = ctx.sparsity_block_size
210
- triton_block_size = ctx.triton_block_size
211
-
212
- return split(grad_output, sparsity_layout, num_partitions, dim,
213
- sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
144
+ lut = merge_build_lut(lut, sparsity_layout, partitions)
145
+
146
+ return BlksprsTensor(merge_forward(
147
+ x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
148
+ partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
149
+
150
+
151
+ @triton_op("blksprs::merge", mutates_args={})
152
+ def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
153
+ _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
154
+ return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
155
+ n_sparse_blocks)
156
+
157
+
158
+ def merge_backward(ctx, grad_output):
159
+ sparsity_layout = ctx.saved_tensors[0]
160
+ num_partitions = ctx.num_partitions
161
+ dim = ctx.dim
162
+ sparsity_block_size = ctx.sparsity_block_size
163
+
164
+ return split(grad_output, sparsity_layout, num_partitions, dim,
165
+ sparsity_block_size)[0], None, None, None, None, None, None, None
166
+
167
+
168
+ def merge_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
169
+ if lut is None:
170
+ lut = dict()
171
+
172
+ if "sparsity_layout_output" not in lut:
173
+ sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
174
+ sparsity_layout.size(1), sparsity_layout.size(2))
175
+ .permute(0, 2, 1, 3)
176
+ .reshape(sparsity_layout.size(0) // partitions,
177
+ sparsity_layout.size(1),
178
+ sparsity_layout.size(2) * partitions).contiguous())
179
+ lut["sparsity_layout_output"] = sparsity_layout_output
180
+
181
+ if "sparsity_lut" not in lut:
182
+ sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
183
+ lut["sparsity_lut"] = sparsity_lut
184
+
185
+ if "sparsity_reverse_lut" not in lut:
186
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
187
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
188
+ (sparsity_layout_flat == 1) -
189
+ (1 * (sparsity_layout_flat == 0)))
190
+ .reshape(sparsity_layout.size(0) // partitions, partitions,
191
+ sparsity_layout.size(1), sparsity_layout.size(2))
192
+ .permute(0, 2, 1, 3)
193
+ .reshape(sparsity_layout.size(0) // partitions,
194
+ sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
195
+ .reshape(-1).contiguous())
196
+ lut["sparsity_reverse_lut"] = sparsity_reverse_lut
197
+
198
+ if "n_sparse_blocks" not in lut:
199
+ n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
200
+ lut["n_sparse_blocks"] = n_sparse_blocks
201
+
202
+ validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
203
+
204
+ return lut
205
+
206
+
207
+ # noinspection PyUnusedLocal
208
+ def merge_setup_context(ctx, inputs, output):
209
+ (_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
210
+
211
+ ctx.save_for_backward(sparsity_layout_o)
212
+ ctx.num_partitions = num_partitions
213
+ ctx.dim = dim
214
+ ctx.sparsity_block_size = sparsity_block_size
215
+
216
+
217
+ merge_forward.register_autograd(merge_backward, setup_context=merge_setup_context)
blksprs/ops/repeat.py CHANGED
@@ -1,17 +1,15 @@
1
1
  import torch
2
- import triton
3
2
  from torch import Tensor
3
+ from torch._library import triton_op
4
4
 
5
- from blksprs.ops.flow import kernel_blocksparse_flow_push, flow_forward_pull, flow_forward_push
5
+ from blksprs.ops.flow import flow_pull_forward, flow_push_forward
6
6
  from blksprs.utils.blksprs_tensor import BlksprsTensor
7
- from blksprs.utils.tools import get_triton_block_size, stride
8
7
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
9
- validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
8
+ validate_sparsity, validate_sparsity_block_size
10
9
 
11
10
 
12
11
  def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
13
- sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None,
14
- lut: dict = None) -> (
12
+ sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
15
13
  BlksprsTensor, Tensor):
16
14
  """Repeats a block-sparse tensor in compressed form according to the given repeats.
17
15
 
@@ -30,7 +28,6 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
30
28
  third dimension respectively.
31
29
  sparsity_block_size (int): The size of the sparsity blocks.
32
30
  sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
33
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
34
31
  lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
35
32
 
36
33
  Returns:
@@ -45,19 +42,16 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
45
42
  validate_device(x)
46
43
  validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
47
44
  validate_sparsity_block_size(sparsity_block_size, x)
48
- validate_triton_block_size(triton_block_size, sparsity_block_size)
49
45
 
50
- lut = _BlocksparseRepeat.build_lut_repeat(lut, sparsity_layout_x, repeats, sparsity_layout_output)
46
+ lut = repeat_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
51
47
 
52
- return BlksprsTensor(
53
- _BlocksparseRepeat.apply(x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
54
- lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"],
55
- triton_block_size)), lut["sparsity_layout_o"]
48
+ return BlksprsTensor(repeat_forward(
49
+ x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
50
+ lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
56
51
 
57
52
 
58
53
  def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
59
- sparsity_block_size: int, sparsity_layout_output: Tensor = None,
60
- triton_block_size: int = None, lut: dict = None) -> (
54
+ sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
61
55
  BlksprsTensor, Tensor):
62
56
  """Repeats and interleaves the block-sparse tensor in compressed form.
63
57
 
@@ -74,7 +68,6 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
74
68
  repeats (int): The number of times to repeat the matrices.
75
69
  sparsity_block_size (int): The size of the sparsity blocks.
76
70
  sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
77
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
78
71
  lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
79
72
 
80
73
  Returns:
@@ -89,108 +82,110 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
89
82
  validate_device(x)
90
83
  validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
91
84
  validate_sparsity_block_size(sparsity_block_size, x)
92
- validate_triton_block_size(triton_block_size, sparsity_block_size)
93
85
 
94
- lut = _BlocksparseRepeat.build_lut_repeat_interleave(lut, sparsity_layout_x, repeats, sparsity_layout_output)
95
-
96
- return BlksprsTensor(
97
- _BlocksparseRepeat.apply(x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
98
- lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"],
99
- triton_block_size)), lut["sparsity_layout_o"]
100
-
101
-
102
- class _BlocksparseRepeat(torch.autograd.Function):
103
-
104
- @staticmethod
105
- def build_lut_repeat(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
106
- sparsity_layout_output: Tensor):
107
- if lut is None:
108
- lut = dict()
109
-
110
- if "sparsity_layout_o" not in lut:
111
- sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
112
- lut["sparsity_layout_o"] = sparsity_layout_o
113
-
114
- if sparsity_layout_output is not None:
115
- sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
116
- lut["sparsity_layout_o"] = sparsity_layout_o
117
-
118
- if "sparsity_lut" not in lut:
119
- sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
120
- lut["sparsity_lut"] = sparsity_lut
121
-
122
- if "sparsity_reverse_lut" not in lut:
123
- sparsity_layout_flat = sparsity_layout_x.reshape(-1)
124
- sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
125
- (sparsity_layout_flat == 1) -
126
- (1 * (sparsity_layout_flat == 0)))
127
- .reshape(sparsity_layout_x.size())
128
- .repeat(repeats[0], repeats[1], repeats[2])
129
- .reshape(-1).contiguous())
130
- lut["sparsity_reverse_lut"] = sparsity_reverse_lut
131
-
132
- if "n_sparse_blocks" not in lut:
133
- n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
134
- lut["n_sparse_blocks"] = n_sparse_blocks
135
-
136
- validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
137
-
138
- return lut
139
-
140
- @staticmethod
141
- def build_lut_repeat_interleave(lut: dict, sparsity_layout_x: Tensor, repeats: int,
142
- sparsity_layout_output: Tensor):
143
- if lut is None:
144
- lut = dict()
145
-
146
- if "sparsity_layout_o" not in lut:
147
- sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
148
- lut["sparsity_layout_o"] = sparsity_layout_o
149
-
150
- if sparsity_layout_output is not None:
151
- sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
152
- lut["sparsity_layout_o"] = sparsity_layout_o
153
-
154
- if "sparsity_lut" not in lut:
155
- sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
156
- lut["sparsity_lut"] = sparsity_lut
157
-
158
- if "sparsity_reverse_lut" not in lut:
159
- sparsity_layout_flat = sparsity_layout_x.reshape(-1)
160
- sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
161
- (sparsity_layout_flat == 1) -
162
- (1 * (sparsity_layout_flat == 0)))
163
- .reshape(sparsity_layout_x.size())
164
- .repeat_interleave(repeats, dim=0)
165
- .reshape(-1).contiguous())
166
- lut["sparsity_reverse_lut"] = sparsity_reverse_lut
167
-
168
- if "n_sparse_blocks" not in lut:
169
- n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
170
- lut["n_sparse_blocks"] = n_sparse_blocks
171
-
172
- validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
173
-
174
- return lut
175
-
176
- @staticmethod
177
- def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
178
- sparsity_reverse_lut: Tensor,
179
- sparsity_block_size: int, n_sparse_blocks: int,
180
- triton_block_size: int) -> Tensor:
181
- ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
182
-
183
- return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
184
- n_sparse_blocks, triton_block_size)
185
-
186
- @staticmethod
187
- def backward(ctx, grad_output):
188
- sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
189
- sparsity_block_size = ctx.sparsity_block_size
190
- triton_block_size = ctx.triton_block_size
191
-
192
- n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
193
-
194
- return flow_forward_push(None, grad_output, sparsity_layout_o, sparsity_lut,
195
- sparsity_reverse_lut, sparsity_block_size, n_sparse_blocks,
196
- triton_block_size), None, None, None, None, None, None, None
86
+ lut = repeat_interleave_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
87
+
88
+ return BlksprsTensor(repeat_forward(
89
+ x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
90
+ lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
91
+
92
+
93
+ @triton_op("blksprs::repeat", mutates_args={})
94
+ def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
95
+ sparsity_reverse_lut: Tensor,
96
+ sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
97
+ return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
98
+ n_sparse_blocks)
99
+
100
+
101
+ def repeat_backward(ctx, grad_output):
102
+ sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
103
+ sparsity_block_size = ctx.sparsity_block_size
104
+ n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
105
+
106
+ return flow_push_forward(grad_output, sparsity_layout_o, sparsity_lut,
107
+ sparsity_reverse_lut, sparsity_block_size,
108
+ n_sparse_blocks), None, None, None, None, None, None
109
+
110
+
111
+ def repeat_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
112
+ sparsity_layout_output: Tensor):
113
+ if lut is None:
114
+ lut = dict()
115
+
116
+ if "sparsity_layout_o" not in lut:
117
+ sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
118
+ lut["sparsity_layout_o"] = sparsity_layout_o
119
+
120
+ if sparsity_layout_output is not None:
121
+ sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
122
+ lut["sparsity_layout_o"] = sparsity_layout_o
123
+
124
+ if "sparsity_lut" not in lut:
125
+ sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
126
+ lut["sparsity_lut"] = sparsity_lut
127
+
128
+ if "sparsity_reverse_lut" not in lut:
129
+ sparsity_layout_flat = sparsity_layout_x.reshape(-1)
130
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
131
+ (sparsity_layout_flat == 1) -
132
+ (1 * (sparsity_layout_flat == 0)))
133
+ .reshape(sparsity_layout_x.size())
134
+ .repeat(repeats[0], repeats[1], repeats[2])
135
+ .reshape(-1).contiguous())
136
+ lut["sparsity_reverse_lut"] = sparsity_reverse_lut
137
+
138
+ if "n_sparse_blocks" not in lut:
139
+ n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
140
+ lut["n_sparse_blocks"] = n_sparse_blocks
141
+
142
+ validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
143
+
144
+ return lut
145
+
146
+
147
+ def repeat_interleave_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: int,
148
+ sparsity_layout_output: Tensor):
149
+ if lut is None:
150
+ lut = dict()
151
+
152
+ if "sparsity_layout_o" not in lut:
153
+ sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
154
+ lut["sparsity_layout_o"] = sparsity_layout_o
155
+
156
+ if sparsity_layout_output is not None:
157
+ sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
158
+ lut["sparsity_layout_o"] = sparsity_layout_o
159
+
160
+ if "sparsity_lut" not in lut:
161
+ sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
162
+ lut["sparsity_lut"] = sparsity_lut
163
+
164
+ if "sparsity_reverse_lut" not in lut:
165
+ sparsity_layout_flat = sparsity_layout_x.reshape(-1)
166
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
167
+ (sparsity_layout_flat == 1) -
168
+ (1 * (sparsity_layout_flat == 0)))
169
+ .reshape(sparsity_layout_x.size())
170
+ .repeat_interleave(repeats, dim=0)
171
+ .reshape(-1).contiguous())
172
+ lut["sparsity_reverse_lut"] = sparsity_reverse_lut
173
+
174
+ if "n_sparse_blocks" not in lut:
175
+ n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
176
+ lut["n_sparse_blocks"] = n_sparse_blocks
177
+
178
+ validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
179
+
180
+ return lut
181
+
182
+
183
+ # noinspection PyUnusedLocal
184
+ def repeat_setup_context(ctx, inputs, output):
185
+ (_, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size, _) = inputs
186
+
187
+ ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
188
+ ctx.sparsity_block_size = sparsity_block_size
189
+
190
+
191
+ repeat_forward.register_autograd(repeat_backward, setup_context=repeat_setup_context)