blksprs 1.9.3.tar.gz → 1.10.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.9.3 → blksprs-1.10.1}/PKG-INFO +18 -14
- {blksprs-1.9.3 → blksprs-1.10.1}/README.md +17 -13
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/__init__.py +0 -6
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/layouting/distribution_layout.py +6 -6
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/layouting/sparsity_layout.py +7 -7
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/conversion.py +19 -21
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/distribution.py +14 -14
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/flow.py +12 -12
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/matmul.py +8 -8
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/misc/broadcast_ops.py +6 -6
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/misc/exp.py +2 -2
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/misc/row_wise.py +16 -19
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/partitioning.py +24 -10
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/softmax.py +17 -16
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/transpose.py +9 -8
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs.egg-info/PKG-INFO +18 -14
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs.egg-info/SOURCES.txt +0 -1
- {blksprs-1.9.3 → blksprs-1.10.1}/pyproject.toml +1 -1
- blksprs-1.9.3/blksprs/ops/experimental/distribution_mdi.py +0 -447
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/repeat.py +0 -0
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/utils/blksprs_tensor.py +0 -0
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/utils/layout_utils.py +0 -0
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/utils/processing.py +0 -0
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/utils/tools.py +0 -0
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs/utils/validation.py +0 -0
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.9.3 → blksprs-1.10.1}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.9.3 → blksprs-1.10.1}/setup.cfg +0 -0
{blksprs-1.9.3 → blksprs-1.10.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.9.3
+Version: 1.10.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -23,14 +23,6 @@ Requires-Dist: build; extra == "build"
 [](https://github.com/FelixSchoen/blksprs/releases)
 [](https://www.python.org/downloads/release/python-3119/)
 
-## Important Notice
-
-🚨 **Non-Final API** 🚨
-
-Although it already supports a wide variety of functions, this library is still under active development and the API is
-subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
-We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
-
 ## Overview
 
 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -44,7 +36,7 @@ Currently supported operations (includes gradient calculation):
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
 - Repeat (_supports target sparsity layout_)
 - Repeat Interleave (_supports target sparsity layout_)
-- Splitting and merging of matrices along the last dimension
+- Splitting and merging of matrices (_currently* only supports splitting and merging along the last dimension_)
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
 
@@ -70,13 +62,15 @@ Furthermore, the library provides a set of utility functions
 - for the creation of sparsity layouts based on existing
 dense tensors and for the scatter operation (module ``bs.layouting``),
 - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
-- as well as utility functions to
-
+- as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).
+
+_* see the [Roadmap](#roadmap) section for more information_
 
 ## Installation
 
-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
-the Linux platform
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
+the Linux platform**.
+Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
 
 We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
@@ -92,6 +86,16 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
 See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
 
+## Roadmap
+
+Note that since this library covers all our current needs it is in a **bugfix-only** state.
+This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+We will continue to maintain the library and fix any issues that arise.
+Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
+We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
+
+It might be that this changes with future projects, but as of December 2024, we are content with the current state of the library.
+
 ## Usage
 
 We provide an example below to demonstrate the usage of the library.
{blksprs-1.9.3 → blksprs-1.10.1}/README.md

@@ -3,14 +3,6 @@
 [](https://github.com/FelixSchoen/blksprs/releases)
 [](https://www.python.org/downloads/release/python-3119/)
 
-## Important Notice
-
-🚨 **Non-Final API** 🚨
-
-Although it already supports a wide variety of functions, this library is still under active development and the API is
-subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
-We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
-
 ## Overview
 
 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -24,7 +16,7 @@ Currently supported operations (includes gradient calculation):
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
 - Repeat (_supports target sparsity layout_)
 - Repeat Interleave (_supports target sparsity layout_)
-- Splitting and merging of matrices along the last dimension
+- Splitting and merging of matrices (_currently* only supports splitting and merging along the last dimension_)
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
 
@@ -50,13 +42,15 @@ Furthermore, the library provides a set of utility functions
 - for the creation of sparsity layouts based on existing
 dense tensors and for the scatter operation (module ``bs.layouting``),
 - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
-- as well as utility functions to
-
+- as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).
+
+_* see the [Roadmap](#roadmap) section for more information_
 
 ## Installation
 
-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
-the Linux platform
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
+the Linux platform**.
+Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
 
 We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
@@ -72,6 +66,16 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
 See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
 
+## Roadmap
+
+Note that since this library covers all our current needs it is in a **bugfix-only** state.
+This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+We will continue to maintain the library and fix any issues that arise.
+Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
+We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
+
+It might be that this changes with future projects, but as of December 2024, we are content with the current state of the library.
+
 ## Usage
 
 We provide an example below to demonstrate the usage of the library.
{blksprs-1.9.3 → blksprs-1.10.1}/blksprs/__init__.py

@@ -15,9 +15,6 @@ class ops:
     from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
     from blksprs.ops.misc.exp import exp
 
-    class experimental:
-        from blksprs.ops.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
-
 
 class layouting:
     from blksprs.layouting.distribution_layout import build_distribution_layout
@@ -25,9 +22,6 @@ class layouting:
         build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
     from blksprs.utils.layout_utils import build_full_sparsity_layout
 
-    class experimental:
-        from blksprs.ops.experimental.distribution_mdi import build_distribution_layout_mdi
-
 
 class utils:
     from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
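Together with the deletion of blksprs-1.9.3/blksprs/ops/experimental/distribution_mdi.py in the file list above, these hunks remove the experimental MDI entry points (gather_mdi, scatter_reduce_mdi, build_distribution_layout_mdi) from the public namespaces. Since the README's usage example falls outside the diff context, the following is a rough sketch of the workflow these namespace classes imply; the helper names `build_sparsity_layout` and `to_sparse` are assumptions inferred from the `blksprs.layouting.sparsity_layout` import and the `_BlocksparseToSparse` class further below, not confirmed public API:

```python
import torch
import blksprs as bs

# blksprs exposes its API through namespace classes, so functions are
# reached as attributes, e.g. bs.ops.exp and bs.ops.to_dense.
x = torch.randn(2, 64, 64, device="cuda")
sparsity_block_size = 32

# Assumed helper: derive a layout of the non-zero blocks of a dense tensor.
sparsity_layout = bs.layouting.build_sparsity_layout(x, sparsity_block_size)

# Assumed counterpart of to_dense (cf. _BlocksparseToSparse in conversion.py).
x_sparse = bs.ops.to_sparse(x, sparsity_layout, sparsity_block_size)
x_exp = bs.ops.exp(x_sparse, sparsity_layout, sparsity_block_size)
x_back = bs.ops.to_dense(x_exp, sparsity_layout, sparsity_block_size)
```

Call sites that used `bs.ops.experimental` or `bs.layouting.experimental` in 1.9.3 must migrate to the stable namespaces as of 1.10.1.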
{blksprs-1.9.3 → blksprs-1.10.1}/blksprs/layouting/distribution_layout.py

@@ -84,21 +84,21 @@ def kernel_distribution_layout(i,
 
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
-    spa_bat_i_msk = (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
     spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)
 
     spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
-    spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_row_i_msk = (spa_row_i_idx >= 0 and spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
 
     spa_col_i_idx = (pid_blk * s_lut_i_r_s + 2 * s_lut_i_c_s)
-    spa_col_i_msk = (spa_col_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_col_i_msk = (spa_col_i_idx >= 0 and spa_col_i_idx < s_lut_i_r * s_lut_i_r_s)
     spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
 
     blk_i_idx = (pid_blk * i_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx < i_b * i_b_s)
+    blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
 
     dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -111,10 +111,10 @@ def kernel_distribution_layout(i,
     elif dim == 2:
         dst_col_idx = blk_i // sparsity_block_size
 
-    blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.
+    blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int1)
 
     blk_o_idx = ((dst_bat_idx * o_b_s) +
                  (dst_row_idx * o_r_s) +
                  (dst_col_idx * o_c_s))
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
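The dominant change in this and all following kernel files is mechanical: every load/store mask gains a lower bound, turning `idx < bound` into `idx >= 0 and idx < bound`. A standalone Triton sketch of that pattern (not blksprs code; `&` is used here for the element-wise tensor case):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def bounds_checked_copy(x_ptr, o_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The 1.9.3-style mask checked only the upper bound. In this toy kernel
    # idx is never negative, but in the blksprs kernels block offsets come
    # from reverse LUTs that use -1 as a "block absent" sentinel, so a
    # negative index could dereference memory below the buffer. The
    # two-sided mask rejects such lanes.
    msk = (idx >= 0) & (idx < n_elements)
    blk = tl.load(x_ptr + idx, mask=msk)
    tl.store(o_ptr + idx, blk, mask=msk)


x = torch.arange(1024, dtype=torch.float32, device="cuda")
o = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
bounds_checked_copy[grid](x, o, x.numel(), BLOCK_SIZE=256)
```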
{blksprs-1.9.3 → blksprs-1.10.1}/blksprs/layouting/sparsity_layout.py

@@ -71,7 +71,7 @@ def kernel_sparsity_layout(x,
     blk_x_idx = (pid_bat * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store sparsity layout value
@@ -79,7 +79,7 @@ def kernel_sparsity_layout(x,
     blk_o_idx = (pid_bat * o_b_s +
                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
@@ -162,22 +162,22 @@ def kernel_sparsity_layout_adaption(x,
 
     # Get sparsity index of current output block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
 
     # Load x values
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store sparsity layout value
@@ -187,7 +187,7 @@ def kernel_sparsity_layout_adaption(x,
                    // sparsity_block_size_to) * o_r_s) +
                  (((spa_col * sparsity_block_size_from + pid_col * TRITON_BLOCK_SIZE)
                    // sparsity_block_size_to) * o_c_s))
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
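`kernel_sparsity_layout` flags which `sparsity_block_size`-sized blocks of a dense tensor carry data. A plain-PyTorch reference of that layout computation, as a sketch for intuition only (it assumes "present" means "contains a non-zero entry", which the visible hunk context does not fully confirm):

```python
import torch

def sparsity_layout_reference(x: torch.Tensor, sparsity_block_size: int) -> torch.Tensor:
    # x: (batch, rows, cols), with rows and cols divisible by the block size.
    b, r, c = x.shape
    s = sparsity_block_size
    blocks = x.reshape(b, r // s, s, c // s, s)
    # A block is flagged if any of its entries is non-zero; the result is a
    # (batch, rows // s, cols // s) grid of 0/1 flags.
    return blocks.ne(0).sum(dim=(2, 4)).gt(0).to(torch.int32)
```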
{blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/conversion.py

@@ -1,5 +1,3 @@
-from typing import Any
-
 import torch
 import triton
 from torch import Tensor
@@ -54,12 +52,12 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
     validate_contiguous(sparsity_reverse_lut)
 
     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-        return
+        return x
 
-    return
-
-
-
+    return _BlocksparseToDense.apply(x,
+                                     sparsity_layout, sparsity_reverse_lut,
+                                     sparsity_block_size, fill_value,
+                                     triton_block_size)
 
 
 class _BlocksparseToDense(torch.autograd.Function):
@@ -133,7 +131,7 @@ class _BlocksparseToDense(torch.autograd.Function):
 
     # Get reverse sparsity index for current block
     rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
     rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     # If block is present commence operations
@@ -143,13 +141,13 @@ class _BlocksparseToDense(torch.autograd.Function):
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_msk = (blk_idx < x_b * x_b_s)
+    blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
     blk = tl.load(x + blk_idx, mask=blk_msk)
 
     o_idx = (pid_blk * o_b_s +
              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
              ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    o_msk = (o_idx < o_b * o_b_s)
+    o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
     tl.store(o + o_idx, blk, o_msk)
 
 
@@ -260,15 +258,15 @@ class _BlocksparseToSparse(torch.autograd.Function):
 
     # Get sparsity index of current output block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
 
     # Load block from dense tensor
@@ -277,14 +275,14 @@ class _BlocksparseToSparse(torch.autograd.Function):
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_d_msk = (blk_d_idx < x_b * x_b_s)
+    blk_d_msk = (blk_d_idx >= 0 and blk_d_idx < x_b * x_b_s)
     blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
 
     # Store block in sparse tensor
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
-    blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
+    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < (pid_blk + 1) * o_b_s)
     tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
 
 
@@ -424,15 +422,15 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
 
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-    spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
     spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
 
     spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-    spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
     spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
 
     spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-    spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
     spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
 
     # Get equivalent sparsity block in from layout
@@ -444,7 +442,7 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
     rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
                          spa_row_x * s_l_x_r_s +
                          spa_col_x * s_l_x_c_s)
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # If block is present commence operations
@@ -459,12 +457,12 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
     blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                  ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store output
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
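Besides the mask hardening, `to_dense` itself is fixed: the full-layout fast path returned nothing in 1.9.3 and now returns `x`, and the dangling `return` is restored to the actual `_BlocksparseToDense.apply(...)` call. For orientation, a dense-space reference of what the forward pass computes from the sparsity LUT; this is a sketch, with the LUT row layout `(batch, block_row, block_col)` read off the `0/1/2 * s_lut_c_s` offsets in the kernels above:

```python
import torch

def to_dense_reference(x_sparse, sparsity_lut, sparsity_block_size,
                       dense_shape, fill_value=0.0):
    # x_sparse: (num_blocks, s, s) stack of sparse blocks;
    # sparsity_lut: (num_blocks, 3) rows of (batch, block_row, block_col).
    out = torch.full(dense_shape, fill_value,
                     dtype=x_sparse.dtype, device=x_sparse.device)
    s = sparsity_block_size
    for blk, (bat, row, col) in enumerate(sparsity_lut.tolist()):
        # Copy each stored block back to its dense position.
        out[bat, row * s:(row + 1) * s, col * s:(col + 1) * s] = x_sparse[blk]
    return out
```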
{blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/distribution.py

@@ -138,22 +138,22 @@ class _BlocksparseGather(torch.autograd.Function):
 
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-    spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
     spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
 
     spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-    spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
     spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
 
     spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-    spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
     spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
 
     # Load index values
     blk_i_idx = ((pid_blk * i_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx < i_b * i_b_s)
+    blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
     # Get indices of sparsity blocks and positions within the blocks
@@ -180,21 +180,21 @@ class _BlocksparseGather(torch.autograd.Function):
     rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
                         (rev_dst_row_x * s_l_x_r_s) +
                         (rev_dst_col_x * s_l_x_c_s))
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # Load x values
     blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                 dst_row_x +
                 dst_col_x)
-    blk_x_msk = ((blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s) and rev_idx_spa_x_msk != -1)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store output
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = ((blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_x_msk != -1)
     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
@@ -364,29 +364,29 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
 
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-    spa_bat_x_msk = (spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
     spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
 
     spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-    spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
     spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
 
     spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
-    spa_col_x_msk = (spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
     spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
 
     # Load x values
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load index values
     blk_i_idx = ((pid_blk * i_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx < i_b * i_b_s)
+    blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
     # Get indices of sparsity blocks and positions within the blocks
@@ -413,14 +413,14 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
     rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
                         (rev_dst_row_o * s_l_o_r_s) +
                         (rev_dst_col_o * s_l_o_c_s))
-    rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
 
     # Store output
     blk_o_idx = ((rev_idx_spa_o * o_b_s) +
                 dst_row_o +
                 dst_col_o)
-    blk_o_msk = ((blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_o_msk != -1)
 
     if reduce_op_ind == 0:
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
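In `_BlocksparseGather` and `_BlocksparseScatterReduce` the rewritten masks additionally gain a `rev_idx_spa_..._msk != -1` term alongside the two-sided bounds check. For orientation, the dense analogues of the two operations (a sketch; per the README, scatter supports either no reduction or summation):

```python
import torch

x = torch.randn(2, 4, 8)
i = torch.randint(0, 8, (2, 4, 8))

# Gather: each output element pulls x[..., i[...]] along the last dimension.
gathered = torch.gather(x, dim=-1, index=i)

# Scatter with summation: elements mapping to the same index accumulate.
scattered = torch.zeros_like(x).scatter_add(-1, i, x)
```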
{blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/flow.py

@@ -22,22 +22,22 @@ def kernel_blocksparse_flow_pull(x,
 
     # Get sparsity index of current output block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
 
     # Get reverse sparsity index
     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
                        spa_row * s_l_o_r_s +
                        spa_col * s_l_o_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa == -1:
@@ -47,13 +47,13 @@ def kernel_blocksparse_flow_pull(x,
     blk_x_idx = (rev_idx_spa * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     blk_o_idx = (pid_blk * o_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
@@ -73,22 +73,22 @@ def kernel_blocksparse_flow_push(x,
 
     # Get sparsity index of current input block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
 
     # Get reverse sparsity index
     rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
                        spa_row * s_l_x_r_s +
                        spa_col * s_l_x_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa == -1:
@@ -98,13 +98,13 @@ def kernel_blocksparse_flow_push(x,
     blk_x_idx = (pid_blk * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     blk_o_idx = (rev_idx_spa * o_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
     tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
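The two flow kernels differ in where the reverse LUT is applied: `flow_pull` reads `x` at `rev_idx_spa * x_b_s` and writes its own block with a plain `tl.store`, while `flow_push` reads its own block and accumulates into `rev_idx_spa * o_b_s` with `tl.atomic_add`, since several source blocks may land on one destination. A dense PyTorch sketch of the distinction:

```python
import torch

x = torch.randn(6, 32, 32)               # stack of sparse blocks
r_lut = torch.tensor([3, 0, 0, 5, 1, 2])  # reverse lookup table

# Pull: each output block selects its source block; one writer per output.
pulled = x[r_lut]

# Push: each input block adds into its destination block; collisions must
# sum, hence the atomic add in the kernel.
pushed = torch.zeros_like(x)
pushed.index_add_(0, r_lut, x)
```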
{blksprs-1.9.3 → blksprs-1.10.1}/blksprs/ops/matmul.py

@@ -164,15 +164,15 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
 
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-    spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
     spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
 
     spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-    spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
     spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
 
     spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-    spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
     spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
 
     # Setup buffer
@@ -192,12 +192,12 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
     rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
                          spa_row_o * s_l_x_r_s +
                          i_seg_spa * s_l_x_c_s)
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # Get reverse sparsity indices for y
     rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
-    rev_idx_spa_y_msk = (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
+    rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
     rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
 
     # If both blocks are present commence calculation
@@ -206,14 +206,14 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     blk_y_idx = ((rev_idx_spa_y * y_b_s) +
                  ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-    blk_y_msk = (blk_y_idx < y_b * y_b_s)
+    blk_y_msk = (blk_y_idx >= 0 and blk_y_idx < y_b * y_b_s)
     blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
     # Perform matrix multiplication
@@ -223,5 +223,5 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
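The `_BlocksparseMatmulSSS` kernel computes each output block by iterating over the shared inner dimension (`i_seg_spa`), looking up the matching `x` and `y` blocks through reverse LUTs, and accumulating their product into `buf`. For the SSS (sparse × sparse → sparse) product, a dense reference under the assumption that the sparsity layouts act as 0/1 block masks; this is a testing aid, not how the kernel computes:

```python
import torch

def matmul_sss_reference(xd, yd, layout_x, layout_y, sparsity_block_size):
    # xd: dense (batch, m, k); yd: dense (batch, k, n);
    # layout_x / layout_y: 0/1 block grids of shape (batch, m // s, k // s)
    # and (batch, k // s, n // s).
    s = sparsity_block_size
    ones = torch.ones(1, s, s, dtype=xd.dtype, device=xd.device)
    # Expand each block flag to a full s-by-s patch of the dense mask.
    mask_x = torch.kron(layout_x.to(xd.dtype), ones)
    mask_y = torch.kron(layout_y.to(yd.dtype), ones)
    return torch.bmm(xd * mask_x, yd * mask_y)
```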