liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/dyt.py (added)
@@ -0,0 +1,160 @@
+ import operator
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import infer_device
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available
+
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+     try:
+         # typical import path with dispatch available
+         from triton.language.extra.libdevice import tanh
+     except ModuleNotFoundError:
+         # for working with NGC containers
+         from triton.language.extra.cuda.libdevice import tanh
+ else:
+     from triton.language.math import tanh
+
+
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+ #                   for bn in [1024, 2048, 4096]
+ #                   for ns in [1,2,4]
+ #                   for nw in [4, 8, 16, 32]
+ #                   ],
+ #                  key=['N'])
+ @triton.jit
+ def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
+     col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+     mask = col < N
+     row_id = tl.cast(tl.program_id(1), tl.int64)
+
+     X += row_id * N
+     Y += row_id * N
+     alpha = tl.load(Alpha).to(tl.float32)
+
+     gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
+
+     x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
+
+     tanh_x = tanh(alpha * x)
+     y = tanh_x * gamma
+     if HAVE_BETA:
+         beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
+         y += beta
+     tl.store(Y + col, y, mask=mask)
+
+
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+ #                   for bn in [1024, 2048, 4096]
+ #                   for ns in [1,2,4]
+ #                   for nw in [4, 8, 16]
+ #                   ],
+ #                  key=['N'])
+ @triton.jit
+ def _dyt_bwd_kernel(
+     DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
+ ):
+     col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+     mask = col < N
+     start_row_id = tl.cast(tl.program_id(1), tl.int64)
+
+     alpha = tl.load(Alpha).to(tl.float32)
+     da = 0.0
+     gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
+     dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
+     if HAVE_BETA:
+         db = tl.zeros((BLOCK_N,), dtype=tl.float32)
+     for row_id in range(start_row_id, M, tl.num_programs(1)):
+         x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+         dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+         tanh_x = tanh(alpha * x)
+         if HAVE_BETA:
+             db += dy
+         dg += dy * tanh_x
+         tmp = (1 - tanh_x * tanh_x) * dy * gamma
+         da += tl.sum(x * tmp, 0)
+         dx = alpha * tmp
+         tl.store(DX + row_id * N + col, dx, mask=mask)
+
+     tl.store(DG + start_row_id * N + col, dg, mask=mask)
+     if HAVE_BETA:
+         tl.store(DB + start_row_id * N + col, db, mask=mask)
+     tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
+
+
+ def liger_dyt_fwd(x, alpha, gamma, beta):
+     assert x.is_contiguous()
+     HAVE_BETA = True if beta is not None else False
+     input_shape = x.shape
+     x = x.view(-1, input_shape[-1])
+     M, N = x.shape
+
+     y = torch.empty_like(x)
+
+     if N >= 4096:
+         kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
+     else:
+         kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
+
+     grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
+     _dyt_fwd_kernel[(grid)](
+         x,
+         y,
+         alpha,
+         gamma,
+         beta,
+         HAVE_BETA,
+         N,
+         **kwargs,
+     )
+     return y.view(input_shape)
+
+
+ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
+     assert dy.is_contiguous()
+     input_shape = x.shape
+     x = x.view(-1, input_shape[-1])
+     M, N = x.shape
+     HAVE_BETA = True if beta is not None else False
+
+     device = infer_device()
+     if device == "cuda":
+         NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
+     elif device == "xpu":
+         NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+     elif device == "npu":
+         NUM_SMS = get_npu_multi_processor_count()
+     da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
+     dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
+     db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
+     dx = torch.empty_like(dy)
+
+     kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
+     grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
+     _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
+     if HAVE_BETA:
+         db = db.sum(0).to(x.dtype)
+     dg = dg.sum(0).to(gamma.dtype)
+     da = da.sum().to(x.dtype).unsqueeze(0)
+     return dx.view(input_shape), da, dg, db
+
+
+ class LigerDyTFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, x, alpha, gamma, beta):
+         y = liger_dyt_fwd(x, alpha, gamma, beta)
+         ctx.save_for_backward(x, alpha, gamma, beta)
+         return y
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, dy):
+         x, alpha, gamma, beta = ctx.saved_tensors
+         dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
+         return dx, dalpha, dgamma, dbeta
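
The file above (liger_kernel/ops/dyt.py) implements DyT, y = gamma * tanh(alpha * x) + beta, as a fused Triton forward/backward pair exposed through LigerDyTFunction; liger_kernel/transformers/dyt.py in the file list presumably provides the module-level wrapper. A minimal sketch of driving the autograd Function directly might look like the following; the DyTSketch class, its initialization values, and the shapes are illustrative assumptions, not the packaged API.

import torch

from liger_kernel.ops.dyt import LigerDyTFunction


class DyTSketch(torch.nn.Module):
    """Hypothetical wrapper: y = gamma * tanh(alpha * x) + beta via the fused kernel."""

    def __init__(self, hidden_size: int, init_alpha: float = 0.5, bias: bool = True):
        super().__init__()
        self.alpha = torch.nn.Parameter(torch.full((1,), init_alpha))
        self.gamma = torch.nn.Parameter(torch.ones(hidden_size))
        self.beta = torch.nn.Parameter(torch.zeros(hidden_size)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # beta may be None; the kernel then skips the bias add (HAVE_BETA=False)
        return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)


# Usage sketch (a CUDA/XPU/NPU device is required by the Triton kernels):
module = DyTSketch(hidden_size=4096).cuda()
x = torch.randn(2, 128, 4096, device="cuda", requires_grad=True)
out = module(x)       # runs _dyt_fwd_kernel
out.sum().backward()  # runs _dyt_bwd_kernel; fills x.grad and the parameter grads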
liger_kernel/ops/experimental/embedding.py (added)
@@ -0,0 +1,141 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def embedding_forward_kernel(
+     embeddings_ptr,
+     indices_ptr,
+     output_ptr,
+     n_elements,
+     embedding_dim: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+ ):
+     pid_m = tl.program_id(0)
+     pid_n = tl.program_id(1)
+
+     start_m = pid_m * BLOCK_SIZE_M
+     start_n = pid_n * BLOCK_SIZE_N
+     offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
+     mask_m = offsets_m < n_elements
+     indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
+     offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
+     mask_n = offsets_n < embedding_dim
+
+     embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
+     embeddings = tl.load(
+         embeddings_ptr + embedding_offsets,
+         mask=mask_m[:, None] & mask_n[None, :],
+         other=0.0,
+     )
+
+     output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
+     tl.store(output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :])
+
+
+ @triton.jit
+ def embedding_backward_kernel(
+     grad_output_ptr,
+     grad_weight_ptr,
+     indices_ptr,
+     n_elements,
+     embedding_dim: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+ ):
+     pid_m = tl.program_id(0)
+     pid_n = tl.program_id(1)
+
+     start_m = pid_m * BLOCK_SIZE_M
+     start_n = pid_n * BLOCK_SIZE_N
+     offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
+     mask_m = offsets_m < n_elements
+     indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
+     offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
+     mask_n = offsets_n < embedding_dim
+
+     grad_output = tl.load(
+         grad_output_ptr + offsets_m[:, None] * embedding_dim + offsets_n[None, :],
+         mask=mask_m[:, None] & mask_n[None, :],
+         other=0.0,
+     )
+
+     grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
+
+     tl.atomic_add(
+         grad_weight_ptr + grad_weight_offsets,
+         grad_output,
+         mask=mask_m[:, None] & mask_n[None, :],
+     )
+
+
+ class LigerEmbeddingFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
+         ori_shape = indices.shape
+         indices = indices.view(-1)
+         output = torch.empty(
+             indices.shape[0],
+             embeddings.shape[1],
+             device=indices.device,
+             dtype=embeddings.dtype,
+         )
+
+         n_elements = indices.numel()
+         embedding_dim = embeddings.shape[1]
+
+         BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
+         BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
+         grid = (
+             triton.cdiv(n_elements, BLOCK_SIZE_M),
+             triton.cdiv(embedding_dim, BLOCK_SIZE_N),
+         )
+
+         embedding_forward_kernel[grid](
+             embeddings,
+             indices,
+             output,
+             n_elements,
+             embedding_dim=embedding_dim,
+             BLOCK_SIZE_M=BLOCK_SIZE_M,
+             BLOCK_SIZE_N=BLOCK_SIZE_N,
+         )
+
+         ctx.save_for_backward(indices, embeddings)
+
+         return output.view(*ori_shape, -1)
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor):
+         indices, embedding_table = ctx.saved_tensors
+         grad_output = grad_output.contiguous().view(-1, embedding_table.shape[1])
+
+         grad_weight = torch.zeros_like(embedding_table)
+
+         n_elements = indices.numel()
+         embedding_dim = embedding_table.shape[1]
+
+         BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
+         BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
+         grid = (
+             triton.cdiv(n_elements, BLOCK_SIZE_M),
+             triton.cdiv(embedding_dim, BLOCK_SIZE_N),
+         )
+
+         embedding_backward_kernel[grid](
+             grad_output,
+             grad_weight,
+             indices,
+             n_elements,
+             embedding_dim=embedding_dim,
+             BLOCK_SIZE_M=BLOCK_SIZE_M,
+             BLOCK_SIZE_N=BLOCK_SIZE_N,
+         )
+
+         return grad_weight, None
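
The kernel above (liger_kernel/ops/experimental/embedding.py) is a Triton embedding lookup: the forward gathers rows of the weight table into the output, and the backward scatters grad_output back into a zero-initialized grad_weight with tl.atomic_add. A minimal sketch of calling the autograd Function directly on a CUDA device follows; the shapes and names are illustrative assumptions, not the packaged liger_kernel.transformers API.

import torch

from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction

vocab_size, hidden_size = 32000, 1024
weight = torch.nn.Parameter(torch.randn(vocab_size, hidden_size, device="cuda"))
token_ids = torch.randint(0, vocab_size, (2, 16), device="cuda")

out = LigerEmbeddingFunction.apply(weight, token_ids)  # shape (2, 16, 1024), gathered by embedding_forward_kernel
out.sum().backward()                                   # embedding_backward_kernel atomic-adds into weight.grad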
liger_kernel/ops/experimental/mm_int8int2.py (added)
@@ -0,0 +1,349 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor:
+     values_per_item = 8 // bits
+     packed_shape = packed.shape
+
+     if len(packed_shape) == 1:
+         original_row_dim = packed_shape[0] * values_per_item
+         unpacked_shape = (original_row_dim,)
+     else:
+         original_row_dim = packed_shape[0] * values_per_item
+         unpacked_shape = (original_row_dim, *packed_shape[1:])
+
+     unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)
+
+     for i in range(values_per_item):
+         start = i * packed_shape[0]
+         end = start + packed_shape[0]
+         mask = 3 << (2 * i)
+         unpacked[start:end] = (packed & mask) >> (2 * i)
+
+     unpacked = unpacked.to(torch.int32) - 1
+     return unpacked
+
+
+ def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
+     intweights += 1
+     original_shape = intweights.shape
+     values_per_item = 8 // bits
+     row_dim = (original_shape[0] + values_per_item - 1) // values_per_item
+
+     if len(original_shape) == 1:
+         packed_tensor_shape = (row_dim,)
+     else:
+         packed_tensor_shape = (row_dim, *original_shape[1:])
+
+     packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
+     unpacked = intweights.to(torch.uint8)
+
+     def lshift(t: torch.Tensor, bits: int):
+         return t << bits
+
+     it = min(values_per_item, (original_shape[0] // row_dim) + 1)
+     for i in range(it):
+         start = i * row_dim
+         end = min(start + row_dim, original_shape[0])
+         packed[: (end - start)] |= lshift(unpacked[start:end], bits * i)
+
+     return packed
+
+
+ def get_autotune_config():
+     return [
+         triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8),
+         triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8),
+         triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8),
+         triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
+         triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4}, num_stages=4, num_warps=4),
+     ]
+
+
+ @triton.autotune(
+     configs=get_autotune_config(),
+     key=["M", "N", "K"],
+ )
+ @triton.jit
+ def matmul_kernel(
+     a_ptr,
+     b_ptr,
+     c_ptr,
+     M,
+     N,
+     K: tl.constexpr,
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+ ):
+     # We want K / 4 to be divisible by BLOCK_SIZE_K so that the multiplication can be aligned
+     tl.static_assert(
+         K % (4 * BLOCK_SIZE_K) == 0,
+         "K / 4 must be divisible by BLOCK_SIZE_K => K divisible by 4*BLOCK_SIZE_K",
+     )
+     # determine the block id in the 1D grid; pid <=> blockId in CUDA
+     pid = tl.program_id(axis=0)
+     # number of blocks we would need in the M dimension
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     # number of blocks we would need in the N dimension
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     # blocks are grouped along the M dimension. num_pid_in_group computes how many blocks are grouped together,
+     # and group_id calculates the group to which the current block (pid) belongs.
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+
+     # pid of the first block in the group that the current block belongs to
+     first_pid_m = group_id * GROUP_SIZE_M
+
+     # pid_m: pid of the block along the M dimension of the output matrix; pid_n: pid of the block along the N dimension.
+     # Remember that the grid of blocks is 1D, but we calculate pid_m and pid_n to locate the block's place in the output matrix.
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     # offs_am are the indices of elements within the block of matrix A with respect to the M dimension;
+     # offs_bn are the indices of elements within the block of matrix B with respect to the N dimension.
+     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+     """
+     This part of the code generates pointers to the specific blocks of matrices A and B that the current thread block will process.
+
+     As described in the PyTorch documentation, a stride refers to the step size needed to move from one element to the next along a given dimension:
+
+     For matrix A: stride_am = A.stride(0) = K (stride along the rows), and stride_ak = A.stride(1) = 1 (stride along the columns).
+     For matrix B: stride_bk = B.stride(0) = N (stride along the rows), and stride_bn = B.stride(1) = 1 (stride along the columns).
+     Now, let's break down the pointer generation:
+
+     offs_am[:, None] creates a column of shape [BLOCK_SIZE_M, 1], which represents the row indices of matrix A that this block is processing. It is multiplied by K (the number of columns in matrix A) since A is stored in row-major order, so the element at position (i, j) in A is located at index i*K + j in memory.
+     offs_k[None, :] creates a row vector representing the column indices of the block, i.e., a range from 0 to BLOCK_SIZE_K. It is used to compute the positions of the columns within the block.
+     When combined, the result has the shape [BLOCK_SIZE_M, BLOCK_SIZE_K], where each entry (i, j) points to the element in matrix A at position (i, j) for the current block.
+
+     The same logic is applied to matrix B, but the resulting shape is [BLOCK_SIZE_K, BLOCK_SIZE_N], representing the block of matrix B that the thread block will work on.
+     """
+     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+     # An accumulator matrix is initialized with zeros. It stores the intermediate results of the block matrix multiplication.
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
+     """
+     We split the loop into two layers. The outer loop runs 4 times, and each iteration focuses on a specific portion of matrix A.
+
+     For example, when i = 0, we're only concerned with the blocks of matrix A that cover the range from 0 to K // (4 * BLOCK_SIZE_K).
+     Since matrix B is packed, its first dimension is effectively divided by 4. So, while we process the first segment of matrix A,
+     we still iterate over the entire first dimension of matrix B.
+
+     In each of the 4 iterations of the outer loop, we go through the full blocks of matrix B, but what changes is the data we extract.
+     Each element of matrix B contains 4 weights packed into one int8, and during each iteration of the outer loop
+     we extract a different weight by using bitwise shifts. This way, we access a unique weight on each pass.
+     """
+     for i in range(4):
+         b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+         for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):
+             k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j
+             # load the block of matrix A
+             a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
+             # load the block of matrix B
+             b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0)
+             # when i = 0, for example, we only care about the lowest 2 bits of each element of matrix B, so we use the mask 00000011 to mask out the other bits
+             mask = 3 << (2 * i)
+             # shift the masked bits back down
+             b = (b_uint8 & mask) >> (2 * i)
+             # During packing it is easier to store 0, 1, 2 than -1, 0, 1, so 1 was added to the weight tensor; we subtract it here.
+             tensor_full = tl.full((1,), 1, dtype=tl.int8)
+             # We accumulate the result of the block multiplications along the K dimension in int32 to avoid any overflows or underflows.
+             accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)
+             # Move the pointers: a_ptrs moves horizontally along the second dimension (stride_ak = 1);
+             # b_ptrs moves vertically along the rows (stride_bk = N).
+             a_ptrs += BLOCK_SIZE_K * stride_ak
+             b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     c = accumulator
+     # These lines compute the offsets into matrix C where the result of this block's computation should be stored.
+     # stride_cm = N and stride_cn = 1
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+     # boundary check to ensure only elements within matrix bounds are stored
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ def matmul(a, b):
+     assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix needs to be packed"
+     assert a.is_contiguous(), "Matrix A must be contiguous"
+     M, K = a.shape
+     _, N = b.shape
+     # c is int32 to avoid any overflows or underflows
+     c = torch.empty((M, N), device=a.device, dtype=torch.int32)
+     grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
+     matmul_kernel[grid](
+         a,
+         b,
+         c,
+         M,
+         N,
+         K,
+         a.stride(0),
+         a.stride(1),
+         b.stride(0),
+         b.stride(1),
+         c.stride(0),
+         c.stride(1),
+     )
+     return c
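
matmul() above multiplies an int8 activation matrix A of shape (M, K) by a ternary weight matrix whose {-1, 0, 1} values pack_weights has packed four-per-byte along the first dimension (giving a uint8 tensor of shape (K // 4, N)), accumulating in int32. A rough end-to-end sketch under assumed shapes follows; K is chosen divisible by 4 * BLOCK_SIZE_K for every autotune config (the largest BLOCK_SIZE_K is 128), and the reference check is an illustrative assumption about intended usage, not a packaged test.

import torch

from liger_kernel.ops.experimental.mm_int8int2 import matmul, pack_weights

M, K, N = 256, 1024, 512
a = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)  # int8 activations
w = torch.randint(-1, 2, (K, N), device="cuda", dtype=torch.int8)      # ternary weights in {-1, 0, 1}

# pack_weights mutates its input in place (+= 1), so pass a copy; .int() already copies here.
packed = pack_weights(w.int())  # uint8, shape (K // 4, N): four 2-bit values per byte
c = matmul(a, packed)           # int32, shape (M, N)

# Unpacked integer reference, computed on CPU where integer matmul is supported.
reference = (a.cpu().long() @ w.cpu().long()).to(c.device)
assert torch.equal(c.long(), reference)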