tilelang-rocm 0.1.4.post9__cp310-cp310-manylinux1_x86_64.whl → 0.1.4.post10__cp310-cp310-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tilelang/VERSION CHANGED
@@ -1 +1 @@
1
- 0.1.4.post9
1
+ 0.1.4.post10
@@ -64,6 +64,12 @@ def shared_16x16_to_local_64x4_layout_B(i, j):
64
64
  return thread_id, local
65
65
 
66
66
 
67
+ shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
68
+ shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A
69
+ shared_16x16_to_local_64x4_layout_n_m = shared_16x16_to_local_64x4_layout_B
70
+ shared_16x16_to_local_64x4_layout_k_n = shared_16x16_to_local_64x4_layout_B
71
+
72
+
67
73
  def thread_id_shared_access_64x4_to_16x16_layout_C_m_n(thread_id, local_id):
68
74
  i = local_id + (thread_id // 16) * 4
69
75
  j = thread_id % 16
@@ -100,6 +106,30 @@ def shared_16x32_to_local_64x8_layout_B(i, j):
100
106
  return thread_id, local
101
107
 
102
108
 
109
+ def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id):
110
+ i = thread_id % 16
111
+ j = local_id + (thread_id // 16) * 16
112
+ return i, j
113
+
114
+
115
+ def shared_16x64_to_local_64x16_layout_A(i, j):
116
+ thread_id = i + 16 * (j // 16)
117
+ local = (j % 16)
118
+ return thread_id, local
119
+
120
+
121
+ def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id):
122
+ i = local_id + (thread_id // 16) * 16
123
+ j = thread_id % 16
124
+ return i, j
125
+
126
+
127
+ def shared_16x64_to_local_64x16_layout_B(i, j):
128
+ thread_id = i + 16 * (j // 16)
129
+ local = (j % 16)
130
+ return thread_id, local
131
+
132
+
103
133
  def make_mfma_swizzle_layout(shared_buf, vecSize=8):
104
134
  dtype = shared_buf.dtype
105
135
  shape = shared_buf.shape
@@ -150,12 +150,16 @@ class MatrixCoreIntrinEmitter(object):
150
150
  shared_16x16_to_local_64x4_layout_B,
151
151
  shared_16x32_to_local_64x8_layout_A,
152
152
  shared_16x32_to_local_64x8_layout_B,
153
+ shared_16x64_to_local_64x16_layout_A,
154
+ shared_16x64_to_local_64x16_layout_B,
153
155
  thread_id_shared_access_64x1_to_16x4_layout_A,
154
156
  thread_id_shared_access_64x1_to_4x16_layout_B,
155
157
  thread_id_shared_access_64x4_to_16x16_layout_A,
156
158
  thread_id_shared_access_64x4_to_16x16_layout_B,
157
159
  thread_id_shared_access_64x8_to_16x32_layout_A,
158
160
  thread_id_shared_access_64x8_to_16x32_layout_B,
161
+ thread_id_shared_access_64x16_to_16x64_layout_A,
162
+ thread_id_shared_access_64x16_to_16x64_layout_B,
159
163
  )
160
164
 
161
165
  k_dim = self.k_dim * self.k_pack
@@ -180,8 +184,15 @@ class MatrixCoreIntrinEmitter(object):
180
184
  if is_b:
181
185
  index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
182
186
  reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
187
+ elif k_dim == 64:
188
+ index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A
189
+ reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
190
+
191
+ if is_b:
192
+ index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B
193
+ reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
183
194
  else:
184
- raise ValueError("k_dim must be 4 or 16 currently")
195
+ raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
185
196
 
186
197
  return index_map, reverse_index_map
187
198
 
@@ -240,7 +251,7 @@ class MatrixCoreIntrinEmitter(object):
240
251
  for i in T.serial(warp_rows):
241
252
  for local_id in T.vectorized(k_pack * local_size_a):
242
253
  row, col = T.meta_var(reverse_index_map(tx, local_id))
243
- l, r = (rk * chunk + ki * micro_size_k,
254
+ l, r = (rk * chunk + ki * (k_pack * micro_size_k),
244
255
  warp_m * warp_row_tiles + i * micro_size_x)
245
256
  A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
246
257
  r + col]
@@ -249,7 +260,7 @@ class MatrixCoreIntrinEmitter(object):
249
260
  for local_id in T.vectorized(k_pack * local_size_a):
250
261
  row, col = T.meta_var(reverse_index_map(tx, local_id))
251
262
  l, r = (warp_m * warp_row_tiles + i * micro_size_x,
252
- rk * chunk + ki * micro_size_k)
263
+ rk * chunk + ki * (k_pack * micro_size_k))
253
264
  A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
254
265
  r + col]
255
266
 
@@ -284,7 +295,7 @@ class MatrixCoreIntrinEmitter(object):
284
295
  row, col = T.meta_var(reverse_index_map(tx, local_id))
285
296
  l, r = (
286
297
  warp_n * warp_col_tiles + j * micro_size_y,
287
- rk * chunk + ki * micro_size_k,
298
+ rk * chunk + ki * (k_pack * micro_size_k),
288
299
  )
289
300
  B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
290
301
  r + col]
@@ -293,7 +304,7 @@ class MatrixCoreIntrinEmitter(object):
293
304
  for local_id in T.vectorized(k_pack * local_size_b):
294
305
  row, col = T.meta_var(reverse_index_map(tx, local_id))
295
306
  l, r = (
296
- rk * chunk + ki * micro_size_k,
307
+ rk * chunk + ki * (k_pack * micro_size_k),
297
308
  warp_n * warp_col_tiles + j * micro_size_y,
298
309
  )
299
310
  B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
@@ -356,7 +367,7 @@ class MatrixCoreIntrinEmitter(object):
356
367
  def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
357
368
  tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
358
369
  for i, j in T.grid(warp_rows, warp_cols):
359
- for local_id in T.serial(local_size_out):
370
+ for local_id in T.vectorized(local_size_out):
360
371
  row, col = T.meta_var(mfma_store_index_map(tx, local_id))
361
372
  C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
362
373
  col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
@@ -366,7 +377,7 @@ class MatrixCoreIntrinEmitter(object):
366
377
  def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
367
378
  tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
368
379
  for i, j in T.grid(warp_rows, warp_cols):
369
- for local_id in T.serial(local_size_out):
380
+ for local_id in T.vectorized(local_size_out):
370
381
  row, col = T.meta_var(mfma_store_index_map(tx, local_id))
371
382
  C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
372
383
  (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM +
tilelang/jit/__init__.py CHANGED
@@ -181,6 +181,9 @@ class _JitImplementation:
181
181
  else:
182
182
  raise ValueError(f"Invalid function type: {type(program_result_source)}")
183
183
 
184
+ if self.verbose:
185
+ logger.info(f"Verbose: Compiling for program \n {program_result.script()}")
186
+
184
187
  kernel_result = compile(
185
188
  program_result,
186
189
  out_idx=self.out_idx,
Binary file
Binary file
tilelang/lib/libtvm.so CHANGED
Binary file
Binary file
@@ -66,7 +66,7 @@ using float16x16 =
66
66
 
67
67
  using half_t = float16_t;
68
68
 
69
- using bfloat16_t = __hip_bfloat16;
69
+ using bfloat16_t = hip_bfloat16;
70
70
 
71
71
  struct bfloat16x2 {
72
72
  bfloat16_t data[2];
@@ -19,16 +19,16 @@ template <> struct MfmaTraits<half> {
19
19
  }
20
20
  };
21
21
 
22
- // Specialization for __hip_bfloat16
23
- template <> struct MfmaTraits<__hip_bfloat16> {
22
+ // Specialization for bfloat16_t
23
+ template <> struct MfmaTraits<bfloat16_t> {
24
24
  template <typename AccType>
25
- static TL_DEVICE void mfma_op(const __hip_bfloat16 *b,
26
- const __hip_bfloat16 *a, AccType *c) {
25
+ static TL_DEVICE void mfma_op(const bfloat16_t *b, const bfloat16_t *a,
26
+ AccType *c) {
27
27
  bfloat16x4_vec b_vec, a_vec;
28
28
 
29
29
  // Reinterpret the pointers
30
- short *b_short = reinterpret_cast<short *>(const_cast<__hip_bfloat16 *>(b));
31
- short *a_short = reinterpret_cast<short *>(const_cast<__hip_bfloat16 *>(a));
30
+ short *b_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(b));
31
+ short *a_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(a));
32
32
 
33
33
  // Copy the data
34
34
  for (int i = 0; i < 4; ++i) {
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tilelang-rocm
3
- Version: 0.1.4.post9
3
+ Version: 0.1.4.post10
4
4
  Summary: A tile level programming language to generate high performance code.
5
5
  Home-page: https://github.com/tile-ai/tilelang
6
6
  Author: Microsoft Research
@@ -1,7 +1,7 @@
1
1
  tilelang/CMakeLists.txt,sha256=xJhnusYZI4UhD_fzseGH3Tn2BeovUzz3aWUwPq-WU0Y,7010
2
2
  tilelang/LICENSE,sha256=v9fVeAgRKQXc5ySwTns767gj0-dHN9XYPpGURkAVAXs,1127
3
3
  tilelang/README.md,sha256=1RC_2IUBY-p0BR-d2xkNXC8zrva8-U3AVkmCozkssbY,11924
4
- tilelang/VERSION,sha256=jHyORphLiyy1Y7K0gWps7qDNJc-NPepcM6_C-_a_2bI,12
4
+ tilelang/VERSION,sha256=Il_uPi-DubG_kcKlTV5qP75Dq4zOcrugFwCAygQ4SQU,13
5
5
  tilelang/__init__.py,sha256=yH0BknCRnFQN-E7d6p1HPNbeY4o3COqG7XzR_EJpbTo,3215
6
6
  tilelang/_ffi_api.py,sha256=D-HfDxx8EZq6qItftg-ejOhpC_smIZLN-pWPVCNX_UM,243
7
7
  tilelang/config.cmake,sha256=370i6N3wwi7-LPGZDBtiiu54UWp39ndD-9lCurLhHwI,14330
@@ -6331,12 +6331,12 @@ tilelang/engine/lower.py,sha256=OJX7d_qk6WXA9VDF-fTTKw1q-ZnBkn4tuD9lA4TPdxk,8923
6331
6331
  tilelang/engine/param.py,sha256=5eWc48aao84WIrbtaLuAYGrb3RE1SyiG6hIy8DwlryI,3943
6332
6332
  tilelang/engine/phase.py,sha256=SPruc1tsC_yFD9q93pilkFmzY3UZ2YJTCRH06YPDI1Q,7154
6333
6333
  tilelang/intrinsics/__init__.py,sha256=ymvtsKjVY0f_9k-QIMtO4CEh6hEnG7H4NiW3buNLVQg,501
6334
- tilelang/intrinsics/mfma_layout.py,sha256=KSA2TFy9XB9lC-mxFgqJX7ZuHNrsJLLJriVYbf8-wSc,3168
6335
- tilelang/intrinsics/mfma_macro_generator.py,sha256=ws3Tyd0Q7laSbIcv-0Ff1YF80P_G4mooM4ir_aqiPTs,16783
6334
+ tilelang/intrinsics/mfma_layout.py,sha256=O5jh8gfILH-ASDss0pvTZhJq2jgmcyHwRPzigxGoGW8,4056
6335
+ tilelang/intrinsics/mfma_macro_generator.py,sha256=axkPrTN-Lplt0xFMk-ckQu8Qynve-lNl6thQCIwY1Q0,17653
6336
6336
  tilelang/intrinsics/mma_layout.py,sha256=eHFiNKd3zKzNFuRrpZdEQx0apbHWj8Ak6Q3e9_CeDiM,5090
6337
6337
  tilelang/intrinsics/mma_macro_generator.py,sha256=BVx3Bt3K67XSTp-Op425OrPxmtD43jyLre0wY8AeW-w,44478
6338
6338
  tilelang/intrinsics/utils.py,sha256=dbQpWOy0F4rg3WotzHQToPtJgY2BLtRy1CKkSnrs--k,4243
6339
- tilelang/jit/__init__.py,sha256=KYrkfRFdQOyyF55iJKSo_IGOwXZNegLMV3Bi7WXOhLY,11901
6339
+ tilelang/jit/__init__.py,sha256=kUDaEwa4Xmr2lvaEI6SvFH4R8wWZWkTLbfgb0iPaEoE,12031
6340
6340
  tilelang/jit/env.py,sha256=Ih2qp1fTLcA0zuAJfkq4-yAQSXu3VCil_XjL-ndKYtc,2604
6341
6341
  tilelang/jit/kernel.py,sha256=B-j3H8kFYalZ3SENKQ3SKbN16Grlr0CKFB2RBZlawA4,15284
6342
6342
  tilelang/jit/param.py,sha256=GIIvPUyXSwi8mY_TMrZGOB9pcg7-hycz8PQNjNQ-JpU,1218
@@ -6383,10 +6383,10 @@ tilelang/layout/__init__.py,sha256=F1wr9yBG9GW84h8KWXz-hRJFfqyZuY0EKSrG08KyrWQ,2
6383
6383
  tilelang/layout/fragment.py,sha256=zTv9P96lsYi9BWc5pxR4PA2Z5RSDGP7D5uJCiNw7_oc,8445
6384
6384
  tilelang/layout/layout.py,sha256=20CWxz_S8k_WNvWiR4gdIrEsQ36e5bsnOEqmu4zGk_c,4311
6385
6385
  tilelang/layout/swizzle.py,sha256=PMqu_s1sNCh9uo8eDs5qmLKXnDqZwv34GT3H9D4YDO0,438
6386
- tilelang/lib/libtilelang.so,sha256=nz9QoBmoqWOZAn68IDvMhMBk_c_etYJf9IathdtijJE,5040880
6387
- tilelang/lib/libtilelang_module.so,sha256=FzKrdpHtsQGwLGD4pQKQp9-jhROAkrNh3JtslXs1EnY,5040880
6388
- tilelang/lib/libtvm.so,sha256=OVNSaJ79WYV-AyY4NNqW4PAnQjS_4uCh6ZibrlzTN_c,83980800
6389
- tilelang/lib/libtvm_runtime.so,sha256=77cdY-xJBotDchdDrqJEIT0Bhs5qw4XkxqGhORhpa3A,4794616
6386
+ tilelang/lib/libtilelang.so,sha256=gWKpADoERYHrDu-eOXjYtG8ak-dvv_2cs8WDiFYnc8o,5040880
6387
+ tilelang/lib/libtilelang_module.so,sha256=eStBjzb0e9y3lZVc21EefE1_nryt0Cs0wlYyz32QifI,5040880
6388
+ tilelang/lib/libtvm.so,sha256=7qy1XeUSGm4YXJC22LUmLZMfaGEPdTCXVxq3pkqS_ZY,83982112
6389
+ tilelang/lib/libtvm_runtime.so,sha256=BgUg1x-f2q3urmCbmBrbUtVFmS3SUyI6dBSAVP_yi8k,4794680
6390
6390
  tilelang/math/__init__.py,sha256=JC4fqrU_LV_wDErti-wHNr4j6_mqP1PsK0qqkhaSzRU,209
6391
6391
  tilelang/primitives/__init__.py,sha256=10gQN3QWUFM1nkGXY46QFcWUXxwsKMsVn23JdyFHil4,167
6392
6392
  tilelang/primitives/gemm/__init__.py,sha256=j62ObmbL5Q6m3lSouNBQDk1hZZRnSp4UNNCCaSlKYXU,1658
@@ -6416,10 +6416,10 @@ tilelang/src/tl_templates/cuda/gemm_sm90.h,sha256=S3v63snxR_3TEg8LLCoR6cqGVgii8k
6416
6416
  tilelang/src/tl_templates/cuda/ldsm.h,sha256=TxCxYVzUK4tvUNVqULCL5HEaAuW9vOv0_-QYmoRFUkM,5053
6417
6417
  tilelang/src/tl_templates/cuda/reduce.h,sha256=U9mKcHSttin1FQ0BohpaP0bHvgPvb3t-czwTuDeK5-8,4394
6418
6418
  tilelang/src/tl_templates/cuda/threadblock_swizzle.h,sha256=GIXQwC1gzwUhnq4CzORHh5hA_QHVfMrOcUeGTy1Fon8,1945
6419
- tilelang/src/tl_templates/hip/common.h,sha256=_JyiapMcqlUuwYaqNbgO52_mZgQxWz6cqW9au5aKzN0,3593
6419
+ tilelang/src/tl_templates/hip/common.h,sha256=KzMntW2OlABMwYjlxNqKiBDZQXQNrYApeVf8F9548s4,3591
6420
6420
  tilelang/src/tl_templates/hip/copy.h,sha256=fGHkbe4ReXoEtIWrgQ-mlCycaIL65SvNGWK1OJZdUQo,3324
6421
6421
  tilelang/src/tl_templates/hip/debug.h,sha256=9xGr4ka5x_nvY55XwbgTJFFwEnd09ta9jAZwjHyQau0,8231
6422
- tilelang/src/tl_templates/hip/gemm.h,sha256=opIRbjhIXRYmzdi4vw8vqoVCZYEQXPBM0ckakgiezNg,11634
6422
+ tilelang/src/tl_templates/hip/gemm.h,sha256=lYeOjV8OG2oZbcS7ByzOudE7i0FQJ71mrUcImkfhTrg,11610
6423
6423
  tilelang/src/tl_templates/hip/hip_fp8.h,sha256=JYGiuuroLQH7CXT7IdKcpNUECmGOTe8DjIjcS9eLc0U,377
6424
6424
  tilelang/src/tl_templates/hip/ldsm.h,sha256=gRx_bSdsCsgcVumwUJwOnv4HuHruU2kC9TE9x_jo8k0,106
6425
6425
  tilelang/src/tl_templates/hip/reduce.h,sha256=-VKpG-TNbzPHIqsSReYpqZoM-oXFzIx6fMeBieV26Kc,1372
@@ -6437,8 +6437,8 @@ tilelang/utils/deprecated.py,sha256=CiZ9y_76_dZ24SFDdasDiLmibwi6xO2Gdj6WzTWU0Qg,
6437
6437
  tilelang/utils/language.py,sha256=KUzUZ8Z2x1np0Hu_MrjWOIcRrVAZHX90li1Xw9fYZXY,3291
6438
6438
  tilelang/utils/target.py,sha256=P-74pdCLWcp2MZMQUoPIFwKF1NZ1QT-L0VroIL8m2to,2486
6439
6439
  tilelang/utils/tensor.py,sha256=SZ4ewoJ-Mq3zg8zIHS7-XLUmYDdlNwh841yUkjnQtNU,12573
6440
- tilelang_rocm-0.1.4.post9.dist-info/licenses/LICENSE,sha256=v9fVeAgRKQXc5ySwTns767gj0-dHN9XYPpGURkAVAXs,1127
6441
- tilelang_rocm-0.1.4.post9.dist-info/METADATA,sha256=0JNlrV6aGKHvKKiSYB1qgu8R4cnfceXp3lxj70wX5Ng,13075
6442
- tilelang_rocm-0.1.4.post9.dist-info/WHEEL,sha256=0-G7woG4LgutcYzUGJCOYFgoh749-FtfhSMeIPLVGS0,104
6443
- tilelang_rocm-0.1.4.post9.dist-info/top_level.txt,sha256=qvMq-AYkDVggI-9VIAzCe5CXHl66IEWj7J29-JbuFsI,21
6444
- tilelang_rocm-0.1.4.post9.dist-info/RECORD,,
6440
+ tilelang_rocm-0.1.4.post10.dist-info/licenses/LICENSE,sha256=v9fVeAgRKQXc5ySwTns767gj0-dHN9XYPpGURkAVAXs,1127
6441
+ tilelang_rocm-0.1.4.post10.dist-info/METADATA,sha256=aBSwulXbZ0_bU2cIYKNCuRsUM2qlAsJ6O14B7fPJUoQ,13076
6442
+ tilelang_rocm-0.1.4.post10.dist-info/WHEEL,sha256=0-G7woG4LgutcYzUGJCOYFgoh749-FtfhSMeIPLVGS0,104
6443
+ tilelang_rocm-0.1.4.post10.dist-info/top_level.txt,sha256=qvMq-AYkDVggI-9VIAzCe5CXHl66IEWj7J29-JbuFsI,21
6444
+ tilelang_rocm-0.1.4.post10.dist-info/RECORD,,