blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,20 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import get_triton_block_size, stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_device, \
-    validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                  sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+                  sparsity_block_size: int) -> BlksprsTensor:
     """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
     compressed form.
 
@@ -19,7 +23,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
         y (Tensor): A dense input tensor.
         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
 
     Returns:
         BlksprsTensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
@@ -34,7 +37,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
     if x.size(-1) != y.size(-1):
         raise ValueError("Dimensions of tensors must match")
     validate_sparsity_block_size(sparsity_block_size)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
 
@@ -42,56 +44,66 @@
 
     validate_contiguous(sparsity_layout_output, sparsity_lut_o)
 
-    output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
-
-    x_b, x_c = x.size()
-    x_b_s, x_c_s = stride(x)
-    y_b, y_c = y.size()
-    y_b_s, y_c_s = stride(y)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
-    s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-    s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
-
-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size)
-
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-    (kernel_broadcast_addition[triton_grid]
-     (x,
-      x_b, x_b_s, x_c_s,
-      y,
-      y_b, y_b_s, y_c_s,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-      sparsity_block_size,
-      triton_block_size))
-
-    return BlksprsTensor(output)
+    return BlksprsTensor(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
 
 
 def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                  sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+                  sparsity_block_size: int) -> BlksprsTensor:
     """Wrapper for ``broadcast_add`` with negated y.
 
     """
-    return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
-
-
+    return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size)
+
+
+@triton_op("blksprs::broadcast_add_forward", mutates_args={})
+def broadcast_add_forward(x: Tensor, y: Tensor,
+                          sparsity_lut_o: Tensor,
+                          sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
+
+        x_b, x_c = x.size()
+        x_b_s, x_c_s = stride(x)
+        y_b, y_c = y.size()
+        y_b_s, y_c_s = stride(y)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+        s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
+        s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(broadcast_add_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_c_s,
+          y,
+          y_b, y_b_s, y_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+          sparsity_block_size))
+
+        return output
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
 @triton.jit
-def kernel_broadcast_addition(x,
-                              x_b, x_b_s, x_c_s,
-                              y,
-                              y_b, y_b_s, y_c_s,
-                              o,
-                              o_b, o_b_s, o_r_s, o_c_s,
-                              s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                              sparsity_block_size,
-                              TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+def broadcast_add_kernel(x,
+                         x_b, x_b_s, x_c_s,
+                         y,
+                         y_b, y_b_s, y_c_s,
+                         o,
+                         o_b, o_b_s, o_r_s, o_c_s,
+                         s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                         sparsity_block_size,
+                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
@@ -112,16 +124,18 @@ def kernel_broadcast_addition(x,
 
     # Load x block
     blk_x_idx = (spa_bat_o * x_b_s +
-                 ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
+                 ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load y block
     blk_y_idx = (spa_bat_o * y_b_s +
-                 ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
+                 ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-    blk_y_msk = (blk_y_idx >= 0 and blk_y_idx < y_b * y_b_s)
+    blk_y_msk = (blk_y_idx >= 0 and
+                 blk_y_idx < y_b * y_b_s)
     blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
     # Compute sum
@@ -132,5 +146,6 @@ def kernel_broadcast_addition(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
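
For callers, the visible change in this file is that `broadcast_add` and `broadcast_sub` no longer accept a `triton_block_size` argument: the forward pass is now registered as a PyTorch custom op via `triton_op`/`wrap_triton`, and `TRITON_BLOCK_SIZE` is chosen by `triton.autotune` keyed on `sparsity_block_size`. A minimal before/after call-site sketch follows, assuming `broadcast_add` is re-exported at the package root; the import alias, tensor shapes, and layout construction are illustrative and not taken from the diff.

    import torch
    import blksprs as bs  # assumed top-level re-export of broadcast_add

    sparsity_block_size = 16
    x = torch.randn(2, 64, device="cuda")  # dense inputs; x broadcast over rows,
    y = torch.randn(2, 64, device="cuda")  # y over columns of the 64x64 output
    # all-ones sparsity layout: 2 batches of 4x4 blocks, block size 16 (dtype assumed)
    sparsity_layout_o = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")

    # blksprs 1.10.x allowed pinning the kernel block size explicitly:
    # out = bs.broadcast_add(x, y, sparsity_layout_o, sparsity_block_size,
    #                        triton_block_size=16)

    # blksprs 2.0 drops the argument; TRITON_BLOCK_SIZE is autotuned:
    out = bs.broadcast_add(x, y, sparsity_layout_o, sparsity_block_size)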