blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -13,25 +13,21 @@ class ops:
13
13
  class misc:
14
14
  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
15
15
  from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
16
- from blksprs.ops.misc.exp import exp
17
16
 
18
17
 
19
18
  class layouting:
20
19
  from blksprs.layouting.distribution_layout import build_distribution_layout
21
20
  from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
22
- build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
23
- from blksprs.utils.layout_utils import build_full_sparsity_layout
21
+ build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast, build_sparsity_layout_full
24
22
 
25
23
 
26
24
  class utils:
27
25
  from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
28
26
  apply_function_applicable_row_wise
29
27
  from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
30
- from blksprs.utils.validation import disable_validation
31
28
 
32
29
  class validation:
33
30
  from blksprs.utils.validation import disable_validation
34
31
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
35
32
  validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
36
- validate_sparsity_block_size, \
37
- validate_triton_block_size
33
+ validate_sparsity_block_size
blksprs/layouting/distribution_layout.py CHANGED
@@ -4,14 +4,14 @@ from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
6
  from blksprs.utils.blksprs_tensor import BlksprsTensor
7
- from blksprs.utils.tools import get_triton_block_size, stride
8
- from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
7
+ from blksprs.utils.tools import stride, get_autotune_configs
8
+ from blksprs.utils.validation import validate_dimensions, validate_device, \
9
9
  validate_contiguous
10
10
 
11
11
 
12
12
  def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
13
13
  dim: int, size_target: torch.Size,
14
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
14
+ sparsity_block_size: int) -> Tensor:
15
15
  """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
16
16
 
17
17
  Args:
@@ -20,7 +20,6 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
20
20
  dim (int): The dimension along which the operation is conducted.
21
21
  size_target (torch.Size): The size of the block-sparse target tensor in regular form.
22
22
  sparsity_block_size (int): The size of the sparsity blocks.
23
- triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
24
23
 
25
24
  Returns:
26
25
  Tensor: The sparsity layout of the source or target tensor.
@@ -44,16 +43,11 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
44
43
  o_b, o_r, o_c = output.size()
45
44
  o_b_s, o_r_s, o_c_s = stride(output)
46
45
 
47
- if triton_block_size is None:
48
- triton_block_size = get_triton_block_size(sparsity_block_size)
49
-
50
- validate_triton_block_size(triton_block_size, sparsity_block_size)
51
-
52
46
  triton_grid = lambda meta: [i_b,
53
47
  triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
54
48
  triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
55
49
 
56
- (kernel_distribution_layout[triton_grid]
50
+ (build_distribution_layout_kernel[triton_grid]
57
51
  (indices,
58
52
  i_b, i_b_s, i_r_s, i_c_s,
59
53
  sparsity_lut_i,
@@ -61,27 +55,34 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
61
55
  adjusted_dim,
62
56
  output,
63
57
  o_b, o_b_s, o_r_s, o_c_s,
64
- sparsity_block_size,
65
- triton_block_size))
58
+ sparsity_block_size))
66
59
 
67
60
  return output
68
61
 
69
62
 
63
+ @triton.autotune(
64
+ configs=get_autotune_configs(),
65
+ key=[],
66
+ reset_to_zero=["o"]
67
+ )
70
68
  @triton.jit
71
- def kernel_distribution_layout(i,
72
- i_b, i_b_s, i_r_s, i_c_s,
73
- s_lut_i,
74
- s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
75
- dim,
76
- o,
77
- o_b, o_b_s, o_r_s, o_c_s,
78
- sparsity_block_size,
79
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
69
+ def build_distribution_layout_kernel(i,
70
+ i_b, i_b_s, i_r_s, i_c_s,
71
+ s_lut_i,
72
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
73
+ dim,
74
+ o,
75
+ o_b, o_b_s, o_r_s, o_c_s,
76
+ sparsity_block_size,
77
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
80
78
  # Get triton block indices
81
79
  pid_blk = tl.program_id(axis=0)
82
80
  pid_row = tl.program_id(axis=1)
83
81
  pid_col = tl.program_id(axis=2)
84
82
 
83
+ # Get valid triton block size
84
+ val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
85
+
85
86
  # Get position of current sparsity block consisting of its batch, row, and column index
86
87
  spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
87
88
  spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
@@ -96,9 +97,12 @@ def kernel_distribution_layout(i,
96
97
  spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
97
98
 
98
99
  blk_i_idx = (pid_blk * i_b_s +
99
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
100
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
101
- blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
100
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
101
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
102
+ blk_i_msk = ((blk_i_idx >= 0 and
103
+ blk_i_idx < i_b * i_b_s) and
104
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
105
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
102
106
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
103
107
 
104
108
  dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -116,5 +120,8 @@ def kernel_distribution_layout(i,
116
120
  blk_o_idx = ((dst_bat_idx * o_b_s) +
117
121
  (dst_row_idx * o_r_s) +
118
122
  (dst_col_idx * o_c_s))
119
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
123
+ blk_o_msk = ((blk_o_idx >= 0 and
124
+ blk_o_idx < o_b * o_b_s) and
125
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
126
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
120
127
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
blksprs/layouting/sparsity_layout.py CHANGED
@@ -3,21 +3,21 @@ import math
3
3
  import torch
4
4
  import triton
5
5
  from torch import Tensor
6
+ from torch._library.triton import wrap_triton
6
7
  from triton import language as tl
7
8
 
8
9
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import get_triton_block_size, stride
10
- from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
10
+ from blksprs.utils.tools import stride, get_autotune_configs
11
+ from blksprs.utils.validation import validate_dimensions, validate_device, \
11
12
  validate_contiguous, validate_sparsity, validate_sparsity_block_size
12
13
 
13
14
 
14
- def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
15
+ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
15
16
  """Builds the sparsity layout of a dense tensor in regular form covering its sparse blocks.
16
17
 
17
18
  Args:
18
19
  x (Tensor): A block-sparse (or dense) tensor in regular form.
19
20
  sparsity_block_size (int): The size of the sparsity blocks.
20
- triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
21
21
 
22
22
  Returns:
23
23
  Tensor: The sparsity layout of the input block-sparse (or dense) tensor.
@@ -35,57 +35,61 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
35
35
  o_b, o_r, o_c = output.size()
36
36
  o_b_s, o_r_s, o_c_s = stride(output)
37
37
 
38
- if triton_block_size is None:
39
- triton_block_size = get_triton_block_size(sparsity_block_size)
40
-
41
- validate_triton_block_size(triton_block_size, sparsity_block_size)
42
-
43
38
  triton_grid = lambda meta: [x_b,
44
- triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
45
- triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
39
+ triton.cdiv(x_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
40
+ triton.cdiv(x_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
46
41
 
47
- (kernel_sparsity_layout[triton_grid]
42
+ (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
48
43
  (x,
49
44
  x_b, x_b_s, x_r_s, x_c_s,
50
45
  output,
51
46
  o_b, o_b_s, o_r_s, o_c_s,
52
- sparsity_block_size,
53
- triton_block_size))
47
+ sparsity_block_size))
54
48
 
55
49
  return output
56
50
 
57
51
 
52
+ @triton.autotune(
53
+ configs=get_autotune_configs(),
54
+ key=[],
55
+ reset_to_zero=["o"]
56
+ )
58
57
  @triton.jit
59
- def kernel_sparsity_layout(x,
60
- x_b, x_b_s, x_r_s, x_c_s,
61
- o,
62
- o_b, o_b_s, o_r_s, o_c_s,
63
- sparsity_block_size,
64
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
58
+ def build_sparsity_layout_kernel(x,
59
+ x_b, x_b_s, x_r_s, x_c_s,
60
+ o,
61
+ o_b, o_b_s, o_r_s, o_c_s,
62
+ sparsity_block_size,
63
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
65
64
  # Get triton block indices
66
65
  pid_bat = tl.program_id(axis=0)
67
66
  pid_row = tl.program_id(axis=1)
68
67
  pid_col = tl.program_id(axis=2)
69
68
 
69
+ # Get valid triton block size
70
+ val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
71
+
70
72
  # Load x values
71
73
  blk_x_idx = (pid_bat * x_b_s +
72
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
73
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
74
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
74
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
75
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
76
+ blk_x_msk = ((blk_x_idx >= 0 and
77
+ blk_x_idx < x_b * x_b_s) and
78
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
79
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
75
80
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
76
81
 
77
82
  # Store sparsity layout value
78
83
  if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
79
84
  blk_o_idx = (pid_bat * o_b_s +
80
- (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
81
- ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
85
+ (((pid_row * val_tbs) // sparsity_block_size) * o_r_s +
86
+ ((pid_col * val_tbs) // sparsity_block_size) * o_c_s))
82
87
  blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
83
88
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
84
89
 
85
90
 
86
91
  def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
87
- sparsity_block_size_from: int, sparsity_block_size_to: int,
88
- triton_block_size: int = None) -> Tensor:
92
+ sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
89
93
  """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
90
94
  used.
91
95
 
@@ -94,7 +98,6 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
94
98
  sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
95
99
  sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
96
100
  sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
97
- triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
98
101
 
99
102
  Returns:
100
103
  Tensor: The sparsity layout in regular form using the new sparsity block size of the input block-sparse tensor
@@ -107,8 +110,6 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
107
110
  validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
108
111
  validate_sparsity_block_size(sparsity_block_size_from, x)
109
112
  validate_sparsity_block_size(sparsity_block_size_to)
110
- min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
111
- validate_triton_block_size(triton_block_size, min_sparsity_block_size)
112
113
 
113
114
  sparsity_lut = torch.nonzero(sparsity_layout_from).contiguous()
114
115
 
@@ -126,40 +127,44 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
126
127
  s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
127
128
  o_b_s, o_r_s, o_c_s = stride(output)
128
129
 
129
- if triton_block_size is None:
130
- triton_block_size = get_triton_block_size(sparsity_block_size_from)
131
-
132
130
  triton_grid = lambda meta: [x_b,
133
- triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
134
- triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
131
+ triton.cdiv(x_r, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"])),
132
+ triton.cdiv(x_c, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"]))]
135
133
 
136
- (kernel_sparsity_layout_adaption[triton_grid]
134
+ (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
137
135
  (x,
138
136
  x_b, x_b_s, x_r_s, x_c_s,
139
137
  sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
140
138
  output,
141
139
  o_b, o_b_s, o_r_s, o_c_s,
142
140
  sparsity_block_size_from,
143
- sparsity_block_size_to,
144
- triton_block_size))
141
+ sparsity_block_size_to))
145
142
 
146
143
  return output
147
144
 
148
145
 
146
+ @triton.autotune(
147
+ configs=get_autotune_configs(),
148
+ key=[],
149
+ reset_to_zero=["o"]
150
+ )
149
151
  @triton.jit
150
- def kernel_sparsity_layout_adaption(x,
151
- x_b, x_b_s, x_r_s, x_c_s,
152
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
153
- o,
154
- o_b, o_b_s, o_r_s, o_c_s,
155
- sparsity_block_size_from,
156
- sparsity_block_size_to,
157
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
152
+ def build_sparsity_layout_adaption_kernel(x,
153
+ x_b, x_b_s, x_r_s, x_c_s,
154
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
155
+ o,
156
+ o_b, o_b_s, o_r_s, o_c_s,
157
+ sparsity_block_size_from,
158
+ sparsity_block_size_to,
159
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
158
160
  # Get triton block indices
159
161
  pid_blk = tl.program_id(axis=0)
160
162
  pid_row = tl.program_id(axis=1)
161
163
  pid_col = tl.program_id(axis=2)
162
164
 
165
+ # Get valid triton block size
166
+ val_tbs = min(sparsity_block_size_to, TRITON_BLOCK_SIZE)
167
+
163
168
  # Get sparsity index of current output block consisting of its batch, row, and column index
164
169
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
165
170
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -175,23 +180,26 @@ def kernel_sparsity_layout_adaption(x,
175
180
 
176
181
  # Load x values
177
182
  blk_x_idx = ((pid_blk * x_b_s) +
178
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
179
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
180
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
183
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
184
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
185
+ blk_x_msk = ((blk_x_idx >= 0 and
186
+ blk_x_idx < x_b * x_b_s) and
187
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < sparsity_block_size_from and
188
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < sparsity_block_size_from))
181
189
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
182
190
 
183
191
  # Store sparsity layout value
184
192
  if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
185
193
  blk_o_idx = ((spa_bat * o_b_s) +
186
- (((spa_row * sparsity_block_size_from + pid_row * TRITON_BLOCK_SIZE)
194
+ (((pid_row * val_tbs + spa_row * sparsity_block_size_from)
187
195
  // sparsity_block_size_to) * o_r_s) +
188
- (((spa_col * sparsity_block_size_from + pid_col * TRITON_BLOCK_SIZE)
196
+ (((pid_col * val_tbs + spa_col * sparsity_block_size_from)
189
197
  // sparsity_block_size_to) * o_c_s))
190
198
  blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
191
199
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
192
200
 
193
201
 
194
- def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
202
+ def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor) -> Tensor:
195
203
  """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
196
204
 
197
205
  Args:
@@ -225,3 +233,8 @@ def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout
225
233
  sparsity_layout_y_slice = torch.max(sparsity_layout_y, dim=-2).values.unsqueeze(1)
226
234
 
227
235
  return torch.logical_or(sparsity_layout_x_slice, sparsity_layout_y_slice)
236
+
237
+
238
+ def build_sparsity_layout_full(x: Tensor, sparsity_block_size: int) -> Tensor:
239
+ return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
240
+ dtype=torch.bool, device=x.device)