blksprs 1.9.2__tar.gz → 1.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {blksprs-1.9.2 → blksprs-1.10}/PKG-INFO +18 -14
  2. {blksprs-1.9.2 → blksprs-1.10}/README.md +17 -13
  3. {blksprs-1.9.2 → blksprs-1.10}/blksprs/__init__.py +0 -6
  4. {blksprs-1.9.2 → blksprs-1.10}/blksprs/layouting/distribution_layout.py +6 -6
  5. {blksprs-1.9.2 → blksprs-1.10}/blksprs/layouting/sparsity_layout.py +7 -7
  6. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/conversion.py +14 -16
  7. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/distribution.py +14 -14
  8. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/flow.py +12 -12
  9. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/matmul.py +8 -8
  10. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/misc/broadcast_ops.py +6 -6
  11. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/misc/exp.py +2 -2
  12. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/misc/row_wise.py +16 -19
  13. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/partitioning.py +24 -10
  14. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/softmax.py +17 -16
  15. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/transpose.py +9 -8
  16. {blksprs-1.9.2 → blksprs-1.10}/blksprs/utils/blksprs_tensor.py +3 -1
  17. {blksprs-1.9.2 → blksprs-1.10}/blksprs.egg-info/PKG-INFO +18 -14
  18. {blksprs-1.9.2 → blksprs-1.10}/blksprs.egg-info/SOURCES.txt +0 -1
  19. {blksprs-1.9.2 → blksprs-1.10}/pyproject.toml +1 -1
  20. blksprs-1.9.2/blksprs/ops/experimental/distribution_mdi.py +0 -447
  21. {blksprs-1.9.2 → blksprs-1.10}/blksprs/ops/repeat.py +0 -0
  22. {blksprs-1.9.2 → blksprs-1.10}/blksprs/utils/benchmarking.py +0 -0
  23. {blksprs-1.9.2 → blksprs-1.10}/blksprs/utils/layout_utils.py +0 -0
  24. {blksprs-1.9.2 → blksprs-1.10}/blksprs/utils/processing.py +0 -0
  25. {blksprs-1.9.2 → blksprs-1.10}/blksprs/utils/tools.py +0 -0
  26. {blksprs-1.9.2 → blksprs-1.10}/blksprs/utils/validation.py +0 -0
  27. {blksprs-1.9.2 → blksprs-1.10}/blksprs.egg-info/dependency_links.txt +0 -0
  28. {blksprs-1.9.2 → blksprs-1.10}/blksprs.egg-info/requires.txt +0 -0
  29. {blksprs-1.9.2 → blksprs-1.10}/blksprs.egg-info/top_level.txt +0 -0
  30. {blksprs-1.9.2 → blksprs-1.10}/setup.cfg +0 -0
PKG-INFO:

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.9.2
+ Version: 1.10
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -23,14 +23,6 @@ Requires-Dist: build; extra == "build"
  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
  [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)

- ## Important Notice
-
- 🚨 **Non-Final API** 🚨
-
- Although it already supports a wide variety of functions, this library is still under active development and the API is
- subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
- We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
-
  ## Overview

  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -44,7 +36,7 @@ Currently supported operations (includes gradient calculation):
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
  - Repeat (_supports target sparsity layout_)
  - Repeat Interleave (_supports target sparsity layout_)
- - Splitting and merging of matrices along the last dimension
+ - Splitting and merging of matrices (_currently* only supports splitting and merging along the last dimension_)
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes

@@ -70,13 +62,15 @@ Furthermore, the library provides a set of utility functions
  - for the creation of sparsity layouts based on existing
  dense tensors and for the scatter operation (module ``bs.layouting``),
  - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
- - as well as utility functions to apply linear layers,
- ensure correct input dimensionality, and validate input (module ``bs.utils``).
+ - as well as utility functions to ensure correct input dimensionality and validate input (module ``bs.utils``).
+
+ _* see the [Roadmap](#roadmap) section for more information_

  ## Installation

- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
- the Linux platform.
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton), this library is **only compatible with
+ the Linux platform**.
+ Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

  We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:

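The install command itself falls just outside the hunks shown here; for reference, the standard install from the linked PyPI page is `pip install blksprs`.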
@@ -92,6 +86,16 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
  See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.

+ ## Roadmap
+
+ Note that since this library covers all of our current needs, it is in a **bugfix-only** state.
+ This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+ We will continue to maintain the library and fix any issues that arise.
+ Should you find any bugs, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
+ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
+
+ This may change with future projects, but as of December 2024, we are content with the current state of the library.
+
  ## Usage

  We provide an example below to demonstrate the usage of the library.
README.md:

@@ -3,14 +3,6 @@
  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
  [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)

- ## Important Notice
-
- 🚨 **Non-Final API** 🚨
-
- Although it already supports a wide variety of functions, this library is still under active development and the API is
- subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
- We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
-
  ## Overview

  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -24,7 +16,7 @@ Currently supported operations (includes gradient calculation):
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
  - Repeat (_supports target sparsity layout_)
  - Repeat Interleave (_supports target sparsity layout_)
- - Splitting and merging of matrices along the last dimension
+ - Splitting and merging of matrices (_currently* only supports splitting and merging along the last dimension_)
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes

@@ -50,13 +42,15 @@ Furthermore, the library provides a set of utility functions
  - for the creation of sparsity layouts based on existing
  dense tensors and for the scatter operation (module ``bs.layouting``),
  - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
- - as well as utility functions to apply linear layers,
- ensure correct input dimensionality, and validate input (module ``bs.utils``).
+ - as well as utility functions to ensure correct input dimensionality and validate input (module ``bs.utils``).
+
+ _* see the [Roadmap](#roadmap) section for more information_

  ## Installation

- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
- the Linux platform.
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton), this library is **only compatible with
+ the Linux platform**.
+ Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

  We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:

@@ -72,6 +66,16 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
  See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.

+ ## Roadmap
+
+ Note that since this library covers all of our current needs, it is in a **bugfix-only** state.
+ This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+ We will continue to maintain the library and fix any issues that arise.
+ Should you find any bugs, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
+ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
+
+ This may change with future projects, but as of December 2024, we are content with the current state of the library.
+
  ## Usage

  We provide an example below to demonstrate the usage of the library.
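The diff cuts off before the README's actual usage example. Purely as a hypothetical sketch of the flow the surrounding text describes — the layouting call uses a name visible in this diff, while `to_sparse`/`to_dense` and all signatures are guesses inferred from the private classes `_BlocksparseToSparse`/`_BlocksparseToDense` below, not the documented API:

```python
import torch
import blksprs as bs

# Assumed values; the real example lives in the project README.
sparsity_block_size = 32
x = torch.randn(2, 64, 64, device="cuda")

# Derive a sparsity layout from the dense tensor (module bs.layouting).
sparsity_layout = bs.layouting.build_sparsity_layout(x, sparsity_block_size)

# Convert to block-sparse form and back (module bs.ops).
x_sparse = bs.ops.to_sparse(x, sparsity_layout, sparsity_block_size)
x_dense = bs.ops.to_dense(x_sparse, sparsity_layout, sparsity_block_size)
```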
blksprs/__init__.py:

@@ -15,9 +15,6 @@ class ops:
  from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
  from blksprs.ops.misc.exp import exp

- class experimental:
- from blksprs.ops.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
-

  class layouting:
  from blksprs.layouting.distribution_layout import build_distribution_layout
@@ -25,9 +22,6 @@ class layouting:
  build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
  from blksprs.utils.layout_utils import build_full_sparsity_layout

- class experimental:
- from blksprs.ops.experimental.distribution_mdi import build_distribution_layout_mdi
-

  class utils:
  from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
blksprs/layouting/distribution_layout.py:

@@ -84,21 +84,21 @@ def kernel_distribution_layout(i,

  # Get position of current sparsity block consisting of its batch, row, and column index
  spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
- spa_bat_i_msk = (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
+ spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
  spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)

  spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
- spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
+ spa_row_i_msk = (spa_row_i_idx >= 0 and spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
  spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)

  spa_col_i_idx = (pid_blk * s_lut_i_r_s + 2 * s_lut_i_c_s)
- spa_col_i_msk = (spa_col_i_idx < s_lut_i_r * s_lut_i_r_s)
+ spa_col_i_msk = (spa_col_i_idx >= 0 and spa_col_i_idx < s_lut_i_r * s_lut_i_r_s)
  spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)

  blk_i_idx = (pid_blk * i_b_s +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
- blk_i_msk = (blk_i_idx < i_b * i_b_s)
+ blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

  dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -111,10 +111,10 @@ def kernel_distribution_layout(i,
  elif dim == 2:
  dst_col_idx = blk_i // sparsity_block_size

- blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
+ blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int1)

  blk_o_idx = ((dst_bat_idx * o_b_s) +
  (dst_row_idx * o_r_s) +
  (dst_col_idx * o_c_s))
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
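This hunk shows the two changes that repeat through every kernel in the release: each load/store mask gains a `>= 0` lower bound next to the existing upper bound, and the stored layout marker is narrowed from `tl.int32` to `tl.int1`, matching the boolean nature of a sparsity-layout entry. A minimal, self-contained sketch of the two-sided mask pattern (a hypothetical copy kernel, not blksprs code; written with `&` where the blksprs kernels spell the conjunction as Python `and`):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, o_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Two-sided mask: guard against negative offsets as well as overflow.
    msk = (offs >= 0) & (offs < n_elements)
    blk = tl.load(x_ptr + offs, mask=msk)
    tl.store(o_ptr + offs, blk, mask=msk)


def copy(x: torch.Tensor) -> torch.Tensor:
    o = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    copy_kernel[grid](x, o, x.numel(), BLOCK_SIZE=1024)
    return o
```

The upper bound alone already prevents reads past the end of the tensor; the added lower bound additionally rejects indices that went negative upstream, for instance via a `-1` sentinel from a reverse look-up table.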
blksprs/layouting/sparsity_layout.py:

@@ -71,7 +71,7 @@ def kernel_sparsity_layout(x,
  blk_x_idx = (pid_bat * x_b_s +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Store sparsity layout value
@@ -79,7 +79,7 @@ def kernel_sparsity_layout(x,
  blk_o_idx = (pid_bat * o_b_s +
  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
  ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)


@@ -162,22 +162,22 @@ def kernel_sparsity_layout_adaption(x,

  # Get sparsity index of current output block consisting of its batch, row, and column index
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+ spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
  spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

  spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+ spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
  spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

  spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+ spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
  spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

  # Load x values
  blk_x_idx = ((pid_blk * x_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Store sparsity layout value
@@ -187,7 +187,7 @@ def kernel_sparsity_layout_adaption(x,
  // sparsity_block_size_to) * o_r_s) +
  (((spa_col * sparsity_block_size_from + pid_col * TRITON_BLOCK_SIZE)
  // sparsity_block_size_to) * o_c_s))
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)

blksprs/ops/conversion.py:

@@ -1,5 +1,3 @@
- from typing import Any
-
  import torch
  import triton
  from torch import Tensor
@@ -133,7 +131,7 @@ class _BlocksparseToDense(torch.autograd.Function):

  # Get reverse sparsity index for current block
  rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
- rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
+ rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
  rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

  # If block is present commence operations
@@ -143,13 +141,13 @@ class _BlocksparseToDense(torch.autograd.Function):
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_msk = (blk_idx < x_b * x_b_s)
+ blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
  blk = tl.load(x + blk_idx, mask=blk_msk)

  o_idx = (pid_blk * o_b_s +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- o_msk = (o_idx < o_b * o_b_s)
+ o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
  tl.store(o + o_idx, blk, o_msk)


@@ -260,15 +258,15 @@ class _BlocksparseToSparse(torch.autograd.Function):

  # Get sparsity index of current output block consisting of its batch, row, and column index
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+ spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
  spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

  spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+ spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
  spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

  spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+ spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
  spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

  # Load block from dense tensor
@@ -277,14 +275,14 @@ class _BlocksparseToSparse(torch.autograd.Function):
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_d_msk = (blk_d_idx < x_b * x_b_s)
+ blk_d_msk = (blk_d_idx >= 0 and blk_d_idx < x_b * x_b_s)
  blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)

  # Store block in sparse tensor
  blk_o_idx = ((pid_blk * o_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
- blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
+ blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < (pid_blk + 1) * o_b_s)
  tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)


@@ -424,15 +422,15 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):

  # Get position of current sparsity block consisting of its batch, row, and column index
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

  spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

  spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

  # Get equivalent sparsity block in from layout
@@ -444,7 +442,7 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
  rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
  spa_row_x * s_l_x_r_s +
  spa_col_x * s_l_x_c_s)
- rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+ rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
  rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

  # If block is present commence operations
@@ -459,12 +457,12 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
  ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Store output
  blk_o_idx = ((pid_blk * o_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
blksprs/ops/distribution.py:

@@ -138,22 +138,22 @@ class _BlocksparseGather(torch.autograd.Function):

  # Get position of current sparsity block consisting of its batch, row, and column index
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

  spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

  spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

  # Load index values
  blk_i_idx = ((pid_blk * i_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
- blk_i_msk = (blk_i_idx < i_b * i_b_s)
+ blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)

  # Get indices of sparsity blocks and positions within the blocks
@@ -180,21 +180,21 @@ class _BlocksparseGather(torch.autograd.Function):
  rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
  (rev_dst_row_x * s_l_x_r_s) +
  (rev_dst_col_x * s_l_x_c_s))
- rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+ rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
  rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

  # Load x values
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
  dst_row_x +
  dst_col_x)
- blk_x_msk = ((blk_x_idx < x_b * x_b_s) & rev_idx_spa_x_msk != -1)
+ blk_x_msk = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s) and rev_idx_spa_x_msk != -1)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Store output
  blk_o_idx = ((pid_blk * o_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = ((blk_o_idx < o_b * o_b_s) & rev_idx_spa_x_msk != -1)
+ blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_x_msk != -1)
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)

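Gather adds one more condition on top of the bounds check: the reverse look-up table marks absent blocks with `-1`, and the mask has to veto those lanes before any memory is touched. A simplified, hypothetical 1-D analogue of that guard (not the blksprs kernel itself):

```python
import triton
import triton.language as tl


@triton.jit
def guarded_gather_kernel(lut_ptr, x_ptr, o_ptr, n_lut, n_x, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    lut_msk = (offs >= 0) & (offs < n_lut)
    # Reverse LUT: -1 marks entries with no backing block in the layout.
    idx = tl.load(lut_ptr + offs, mask=lut_msk, other=-1)
    # Bounds check and sentinel check combined before the indirect load.
    x_msk = lut_msk & (idx >= 0) & (idx < n_x)
    val = tl.load(x_ptr + idx, mask=x_msk)
    tl.store(o_ptr + offs, val, mask=x_msk)
```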
@@ -364,29 +364,29 @@ class _BlocksparseScatterReduce(torch.autograd.Function):

  # Get position of current sparsity block consisting of its batch, row, and column index
  spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
- spa_bat_x_msk = (spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+ spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
  spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)

  spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
- spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+ spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
  spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)

  spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
- spa_col_x_msk = (spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+ spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
  spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)

  # Load x values
  blk_x_idx = ((pid_blk * x_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Load index values
  blk_i_idx = ((pid_blk * i_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
- blk_i_msk = (blk_i_idx < i_b * i_b_s)
+ blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)

  # Get indices of sparsity blocks and positions within the blocks
@@ -413,14 +413,14 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
  rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
  (rev_dst_row_o * s_l_o_r_s) +
  (rev_dst_col_o * s_l_o_c_s))
- rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+ rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
  rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)

  # Store output
  blk_o_idx = ((rev_idx_spa_o * o_b_s) +
  dst_row_o +
  dst_col_o)
- blk_o_msk = ((blk_o_idx < o_b * o_b_s) & rev_idx_spa_o_msk != -1)
+ blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_o_msk != -1)

  if reduce_op_ind == 0:
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
blksprs/ops/flow.py:

@@ -22,22 +22,22 @@ def kernel_blocksparse_flow_pull(x,

  # Get sparsity index of current output block consisting of its batch, row, and column index
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+ spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
  spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

  spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+ spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
  spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

  spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+ spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
  spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

  # Get reverse sparsity index
  rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
  spa_row * s_l_o_r_s +
  spa_col * s_l_o_c_s)
- rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+ rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
  rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

  if rev_idx_spa == -1:
@@ -47,13 +47,13 @@ def kernel_blocksparse_flow_pull(x,
  blk_x_idx = (rev_idx_spa * x_b_s +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  blk_o_idx = (pid_blk * o_b_s +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


@@ -73,22 +73,22 @@ def kernel_blocksparse_flow_push(x,

  # Get sparsity index of current input block consisting of its batch, row, and column index
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+ spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
  spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

  spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+ spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
  spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

  spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+ spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
  spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

  # Get reverse sparsity index
  rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
  spa_row * s_l_x_r_s +
  spa_col * s_l_x_c_s)
- rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
+ rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
  rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

  if rev_idx_spa == -1:
@@ -98,13 +98,13 @@ def kernel_blocksparse_flow_push(x,
  blk_x_idx = (pid_blk * x_b_s +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  blk_o_idx = (rev_idx_spa * o_b_s +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
  tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)

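A detail worth calling out in the two kernels above: the pull variant writes its result with a plain `tl.store`, while the push variant accumulates with `tl.atomic_add`, since several source blocks may be pushed into the same destination block. A stripped-down, hypothetical 1-D illustration of the push side (not the blksprs kernel):

```python
import triton
import triton.language as tl


@triton.jit
def push_kernel(x_ptr, o_ptr, r_lut_ptr, n_x, n_o, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    blk_x_msk = (offs >= 0) & (offs < n_x)
    blk_x = tl.load(x_ptr + offs, mask=blk_x_msk)
    # Destination block index; -1 means the block is absent from the target layout.
    rev_idx = tl.load(r_lut_ptr + pid)
    if rev_idx == -1:
        return
    o_offs = rev_idx * BLOCK + tl.arange(0, BLOCK)
    blk_o_msk = (o_offs >= 0) & (o_offs < n_o)
    # Accumulate rather than overwrite: multiple sources may target this block.
    tl.atomic_add(o_ptr + o_offs, blk_x, mask=blk_o_msk)
```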
blksprs/ops/matmul.py:

@@ -164,15 +164,15 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):

  # Get position of current sparsity block consisting of its batch, row, and column index
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

  spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

  spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

  # Setup buffer
@@ -192,12 +192,12 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
  rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
  spa_row_o * s_l_x_r_s +
  i_seg_spa * s_l_x_c_s)
- rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+ rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
  rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

  # Get reverse sparsity indices for y
  rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
- rev_idx_spa_y_msk = (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
+ rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
  rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)

  # If both blocks are present commence calculation
@@ -206,14 +206,14 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  blk_y_idx = ((rev_idx_spa_y * y_b_s) +
  ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
- blk_y_msk = (blk_y_idx < y_b * y_b_s)
+ blk_y_msk = (blk_y_idx >= 0 and blk_y_idx < y_b * y_b_s)
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

  # Perform matrix multiplication
@@ -223,5 +223,5 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
  blk_o_idx = ((pid_blk * o_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
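For context on the `buf` the final hunk stores: `_BlocksparseMatmulSSS` walks the shared dimension, multiplies each pair of present blocks with `tl.dot`, and accumulates the partial products into a buffer that is written out once at the end. A dense, hypothetical skeleton of that accumulate-then-store pattern (standard Triton matmul shape, not the block-sparse kernel):

```python
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(x_ptr, y_ptr, o_ptr, M, N, K, BLOCK: tl.constexpr):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK + tl.arange(0, BLOCK)
    offs_n = pid_n * BLOCK + tl.arange(0, BLOCK)
    # Buffer for the partial products, analogous to `buf` above.
    buf = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)
    for k in range(0, K, BLOCK):
        offs_k = k + tl.arange(0, BLOCK)
        x_msk = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        y_msk = (offs_k[:, None] < K) & (offs_n[None, :] < N)
        blk_x = tl.load(x_ptr + offs_m[:, None] * K + offs_k[None, :], mask=x_msk, other=0.0)
        blk_y = tl.load(y_ptr + offs_k[:, None] * N + offs_n[None, :], mask=y_msk, other=0.0)
        buf += tl.dot(blk_x, blk_y)
    o_msk = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(o_ptr + offs_m[:, None] * N + offs_n[None, :], buf, mask=o_msk)
```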
blksprs/ops/misc/broadcast_ops.py:

@@ -99,29 +99,29 @@ def kernel_broadcast_addition(x,

  # Get position of current sparsity block consisting of its batch, row, and column index
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

  spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

  spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

  # Load x block
  blk_x_idx = (spa_bat_o * x_b_s +
  ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Load y block
  blk_y_idx = (spa_bat_o * y_b_s +
  ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
- blk_y_msk = (blk_y_idx < y_b * y_b_s)
+ blk_y_msk = (blk_y_idx >= 0 and blk_y_idx < y_b * y_b_s)
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

  # Compute sum
@@ -132,5 +132,5 @@ def kernel_broadcast_addition(x,
  blk_o_idx = ((pid_blk * o_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
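The broadcast addition above assembles each output block as an outer sum: a slice of `x` indexed by the block's rows and a slice of `y` indexed by its columns, expanded against each other via `[:, None]` / `[None, :]`. A compact, hypothetical dense analogue of that expansion:

```python
import triton
import triton.language as tl


@triton.jit
def outer_add_kernel(x_ptr, y_ptr, o_ptr, n, BLOCK: tl.constexpr):
    pid_r = tl.program_id(axis=0)
    pid_c = tl.program_id(axis=1)
    rows = pid_r * BLOCK + tl.arange(0, BLOCK)
    cols = pid_c * BLOCK + tl.arange(0, BLOCK)
    r_msk = (rows >= 0) & (rows < n)
    c_msk = (cols >= 0) & (cols < n)
    blk_x = tl.load(x_ptr + rows, mask=r_msk)  # (BLOCK,)
    blk_y = tl.load(y_ptr + cols, mask=c_msk)  # (BLOCK,)
    # Outer sum: o[i, j] = x[i] + y[j].
    buf = blk_x[:, None] + blk_y[None, :]
    o_idx = rows[:, None] * n + cols[None, :]
    tl.store(o_ptr + o_idx, buf, mask=r_msk[:, None] & c_msk[None, :])
```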