blksprs 1.11__py3-none-any.whl → 2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -1,5 +1,7 @@
  from blksprs.utils.blksprs_tensor import BlksprsTensor
 
+ __version__ = "2.0"
+
 
  class ops:
      from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
@@ -18,19 +20,16 @@ class ops:
  class layouting:
      from blksprs.layouting.distribution_layout import build_distribution_layout
      from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
-         build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
-     from blksprs.utils.layout_utils import build_full_sparsity_layout
+         build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast, build_sparsity_layout_full
 
 
  class utils:
      from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
          apply_function_applicable_row_wise
      from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
-     from blksprs.utils.validation import disable_validation
 
  class validation:
      from blksprs.utils.validation import disable_validation
      from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
          validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
-         validate_sparsity_block_size, \
-         validate_triton_block_size
+         validate_sparsity_block_size
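The `__init__.py` changes above add a `__version__` attribute, replace the separate `blksprs.utils.layout_utils.build_full_sparsity_layout` re-export with `build_sparsity_layout_full`, and drop `validate_triton_block_size` from the public surface. A minimal sketch of how the reorganized namespace might be used, assuming a CUDA device and illustrative tensor shapes (neither is part of the diff):

```python
import torch
import blksprs as bs

print(bs.__version__)  # "2.0"

# Illustrative dense input; blksprs works on batched 3D tensors on a CUDA device.
x = torch.randn(2, 128, 128, device="cuda")
sparsity_block_size = 32

# Layout helpers are grouped under bs.layouting; no triton_block_size argument in 2.0.
sparsity_layout = bs.layouting.build_sparsity_layout(x, sparsity_block_size)

# Replaces the former bs.layouting.build_full_sparsity_layout re-export.
sparsity_layout_full = bs.layouting.build_sparsity_layout_full(x, sparsity_block_size)

# Validation toggles stay under bs.validation.
bs.validation.disable_validation()
```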
blksprs/layouting/distribution_layout.py CHANGED
@@ -1,17 +1,23 @@
+ import typing
+
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl
 
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
- from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+ from blksprs.utils.validation import validate_dimensions, validate_device, \
      validate_contiguous
 
 
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
                                dim: int, size_target: torch.Size,
-                               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+                               sparsity_block_size: int) -> Tensor:
      """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
 
      Args:
@@ -20,7 +26,6 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
          dim (int): The dimension along which the operation is conducted.
          size_target (torch.Size): The size of the block-sparse target tensor in regular form.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
 
      Returns:
          Tensor: The sparsity layout of the source or target tensor.
@@ -34,49 +39,58 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
 
      adjusted_dim = dim % 3
 
-     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
-                          dtype=torch.bool, device=indices.device)
-
-     i_b, i_r, i_c = indices.size()
-     i_b_s, i_r_s, i_c_s = stride(indices)
-     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-     s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
-
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     triton_grid = lambda meta: [i_b,
-                                 triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (kernel_distribution_layout[triton_grid]
-      (indices,
-       i_b, i_b_s, i_r_s, i_c_s,
-       sparsity_lut_i,
-       s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-       adjusted_dim,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size,
-       triton_block_size))
-
-     return output
-
-
+     return build_distribution_layout_operation(indices, sparsity_lut_i, adjusted_dim, size_target, sparsity_block_size)
+
+
+ @triton_op("blksprs::build_distribution_layout", mutates_args={})
+ def build_distribution_layout_operation(indices: Tensor, sparsity_lut_i: Tensor,
+                                         adjusted_dim: int, size_target: typing.List[int],
+                                         sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size,
+                              size_target[2] // sparsity_block_size,
+                              dtype=torch.bool, device=indices.device)
+
+         i_b, i_r, i_c = indices.size()
+         i_b_s, i_r_s, i_c_s = stride(indices)
+         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+         s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+
+         triton_grid = lambda meta: [i_b,
+                                     triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(build_distribution_layout_kernel)[triton_grid]
+          (indices,
+           i_b, i_b_s, i_r_s, i_c_s,
+           sparsity_lut_i,
+           s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+           adjusted_dim,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return output
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_distribution_layout(i,
-                                i_b, i_b_s, i_r_s, i_c_s,
-                                s_lut_i,
-                                s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-                                dim,
-                                o,
-                                o_b, o_b_s, o_r_s, o_c_s,
-                                sparsity_block_size,
-                                TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def build_distribution_layout_kernel(i,
+                                      i_b, i_b_s, i_r_s, i_c_s,
+                                      s_lut_i,
+                                      s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                                      dim,
+                                      o,
+                                      o_b, o_b_s, o_r_s, o_c_s,
+                                      sparsity_block_size,
+                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
      # Get triton block indices
      pid_blk = tl.program_id(axis=0)
      pid_row = tl.program_id(axis=1)
@@ -98,7 +112,8 @@ def kernel_distribution_layout(i,
      blk_i_idx = (pid_blk * i_b_s +
                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-     blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
+     blk_i_msk = (blk_i_idx >= 0 and
+                  blk_i_idx < i_b * i_b_s)
      blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
 
      dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -116,5 +131,6 @@ def kernel_distribution_layout(i,
      blk_o_idx = ((dst_bat_idx * o_b_s) +
                   (dst_row_idx * o_r_s) +
                   (dst_col_idx * o_c_s))
-     blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+     blk_o_msk = (blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s)
      tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
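The distribution-layout diff shows the kernel-launch pattern that recurs throughout 2.0: the public function keeps its signature minus `triton_block_size`, the eager body is registered as a torch custom op via `triton_op`, the kernel is launched through `wrap_triton`, and the block size is picked by `@triton.autotune` over a `TRITON_BLOCK_SIZE` meta-parameter instead of being passed in. Below is a self-contained sketch of that pattern; the `example::fill_ones` op, its arguments, and the candidate block sizes are illustrative stand-ins (in particular for `blksprs.utils.autotuning.get_autotune_configs`, whose contents are not part of this diff), not blksprs code.

```python
import torch
import triton
from torch._library import triton_op
from torch._library.triton import wrap_triton
from triton import language as tl


@triton.autotune(
    # Stand-in for get_autotune_configs(): a few candidate block sizes to benchmark.
    configs=[triton.Config({"TRITON_BLOCK_SIZE": size}) for size in (16, 32, 64)],
    key=["sparsity_block_size"],
    reset_to_zero=["o"],  # autotuning re-runs the kernel, so zero the output between runs
)
@triton.jit
def fill_ones_kernel(o, o_numel, sparsity_block_size, TRITON_BLOCK_SIZE: tl.constexpr):
    # Each program writes one TRITON_BLOCK_SIZE chunk of ones into o.
    pid = tl.program_id(axis=0)
    idx = pid * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    msk = idx < o_numel
    tl.store(o + idx, 1, mask=msk)


@triton_op("example::fill_ones", mutates_args={})
def fill_ones(numel: int, sparsity_block_size: int) -> torch.Tensor:
    o = torch.zeros(numel, dtype=torch.int32, device="cuda")
    # The grid depends on the autotuned meta-parameter, mirroring the launches above.
    grid = lambda meta: (triton.cdiv(numel, meta["TRITON_BLOCK_SIZE"]),)
    # wrap_triton keeps the raw kernel launch traceable (e.g. under torch.compile).
    wrap_triton(fill_ones_kernel)[grid](o, numel, sparsity_block_size)
    return o
```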
blksprs/layouting/sparsity_layout.py CHANGED
@@ -3,21 +3,23 @@ import math
  import torch
  import triton
  from torch import Tensor
+ from torch._library.triton import wrap_triton, triton_op
  from triton import language as tl
 
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
- from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
+ from blksprs.utils.validation import validate_dimensions, validate_device, \
      validate_contiguous, validate_sparsity, validate_sparsity_block_size
 
 
- def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
      """Builds the sparsity layout of a dense tensor in regular form covering its sparse blocks.
 
      Args:
          x (Tensor): A block-sparse (or dense) tensor in regular form.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
 
      Returns:
          Tensor: The sparsity layout of the input block-sparse (or dense) tensor.
@@ -27,41 +29,47 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
      validate_contiguous(x)
      validate_device(x)
 
-     output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
-                          dtype=torch.bool, device=x.device)
+     return build_sparsity_layout_operation(x, sparsity_block_size)
 
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
 
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
+ @triton_op("blksprs::build_sparsity_layout", mutates_args={})
+ def build_sparsity_layout_operation(x: Tensor, sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
+                              dtype=torch.bool, device=x.device)
 
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
 
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
 
-     (kernel_sparsity_layout[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size,
-       triton_block_size))
+         (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
 
-     return output
+         return output
 
 
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_sparsity_layout(x,
-                            x_b, x_b_s, x_r_s, x_c_s,
-                            o,
-                            o_b, o_b_s, o_r_s, o_c_s,
-                            sparsity_block_size,
-                            TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def build_sparsity_layout_kernel(x,
+                                  x_b, x_b_s, x_r_s, x_c_s,
+                                  o,
+                                  o_b, o_b_s, o_r_s, o_c_s,
+                                  sparsity_block_size,
+                                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
      # Get triton block indices
      pid_bat = tl.program_id(axis=0)
      pid_row = tl.program_id(axis=1)
@@ -71,7 +79,8 @@ def kernel_sparsity_layout(x,
      blk_x_idx = (pid_bat * x_b_s +
                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+     blk_x_msk = (blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s)
      blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
      # Store sparsity layout value
@@ -83,9 +92,9 @@ def kernel_sparsity_layout(x,
          tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
-                                    sparsity_block_size_from: int, sparsity_block_size_to: int,
-                                    triton_block_size: int = None) -> Tensor:
+                                    sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
      """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
      used.
 
@@ -94,7 +103,6 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
          sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
          sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
          sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
-         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
 
      Returns:
          Tensor: The sparsity layout in regular form using the new sparsity block size of the input block-sparse tensor
@@ -107,54 +115,62 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
      validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
      validate_sparsity_block_size(sparsity_block_size_from, x)
      validate_sparsity_block_size(sparsity_block_size_to)
-     min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
-     validate_triton_block_size(triton_block_size, min_sparsity_block_size)
 
      sparsity_lut = torch.nonzero(sparsity_layout_from).contiguous()
 
      validate_contiguous(sparsity_layout_from, sparsity_lut)
 
-     o_b = sparsity_layout_from.size(0)
-     o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
-     o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
+     return build_sparsity_layout_adaption_operation(x, sparsity_layout_from, sparsity_lut,
+                                                     sparsity_block_size_from, sparsity_block_size_to)
 
-     output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
 
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     o_b_s, o_r_s, o_c_s = stride(output)
+ @triton_op("blksprs::build_sparsity_layout_adaption", mutates_args={})
+ def build_sparsity_layout_adaption_operation(x: Tensor, sparsity_layout_from: Tensor, sparsity_lut: Tensor,
+                                              sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
+     with torch.no_grad():
+         o_b = sparsity_layout_from.size(0)
+         o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
+         o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
 
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size_from)
+         output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
 
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+         o_b_s, o_r_s, o_c_s = stride(output)
 
-     (kernel_sparsity_layout_adaption[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size_from,
-       sparsity_block_size_to,
-       triton_block_size))
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
 
-     return output
+         (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size_from,
+           sparsity_block_size_to))
 
+         return output
 
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size_from", "sparsity_block_size_to"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
+     reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_sparsity_layout_adaption(x,
-                                     x_b, x_b_s, x_r_s, x_c_s,
-                                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                     o,
-                                     o_b, o_b_s, o_r_s, o_c_s,
-                                     sparsity_block_size_from,
-                                     sparsity_block_size_to,
-                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def build_sparsity_layout_adaption_kernel(x,
+                                           x_b, x_b_s, x_r_s, x_c_s,
+                                           s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                           o,
+                                           o_b, o_b_s, o_r_s, o_c_s,
+                                           sparsity_block_size_from,
+                                           sparsity_block_size_to,
+                                           TRITON_BLOCK_SIZE: tl.constexpr) -> None:
      # Get triton block indices
      pid_blk = tl.program_id(axis=0)
      pid_row = tl.program_id(axis=1)
@@ -177,21 +193,23 @@ def kernel_sparsity_layout_adaption(x,
      blk_x_idx = ((pid_blk * x_b_s) +
                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+     blk_x_msk = (blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s)
      blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
      # Store sparsity layout value
      if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
          blk_o_idx = ((spa_bat * o_b_s) +
-                      (((spa_row * sparsity_block_size_from + pid_row * TRITON_BLOCK_SIZE)
+                      (((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size_from)
                        // sparsity_block_size_to) * o_r_s) +
-                      (((spa_col * sparsity_block_size_from + pid_col * TRITON_BLOCK_SIZE)
+                      (((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
                        // sparsity_block_size_to) * o_c_s))
          blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
          tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
- def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+ def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor) -> Tensor:
      """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
 
      Args:
@@ -205,6 +223,7 @@ def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: T
      return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)
 
 
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
      """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
 
@@ -225,3 +244,8 @@ def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout
      sparsity_layout_y_slice = torch.max(sparsity_layout_y, dim=-2).values.unsqueeze(1)
 
      return torch.logical_or(sparsity_layout_x_slice, sparsity_layout_y_slice)
+
+
+ def build_sparsity_layout_full(x: Tensor, sparsity_block_size: int) -> Tensor:
+     return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
+                       dtype=torch.bool, device=x.device)
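Taken together, the layout builders in this file can be combined along these lines; the shapes and the CUDA device below are illustrative assumptions, not part of the diff:

```python
import torch
from blksprs.layouting.sparsity_layout import (build_sparsity_layout, build_sparsity_layout_full,
                                               build_sparsity_layout_matmul,
                                               build_sparsity_layout_matmul_fast)

sparsity_block_size = 32
x = torch.randn(2, 128, 64, device="cuda")
y = torch.randn(2, 64, 128, device="cuda")

# One bool entry per sparsity block, True where the block contains any non-zero value.
sparsity_layout_x = build_sparsity_layout(x, sparsity_block_size)   # shape (2, 4, 2)
sparsity_layout_y = build_sparsity_layout(y, sparsity_block_size)   # shape (2, 2, 4)

# Added in 2.0 (replacing build_full_sparsity_layout): an all-True layout for the blocked shape of x.
sparsity_layout_full = build_sparsity_layout_full(x, sparsity_block_size)

# Layout of x @ y: exact (bool matmul) or approximate but cheaper (row/column maxima).
sparsity_layout_xy = build_sparsity_layout_matmul(sparsity_layout_x, sparsity_layout_y)
sparsity_layout_xy_fast = build_sparsity_layout_matmul_fast(sparsity_layout_x, sparsity_layout_y)
```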