quack-kernels 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
quack/gemm_sm100.py CHANGED
@@ -13,6 +13,7 @@ import cutlass.cute as cute
 from cutlass.cute.nvgpu import cpasync, tcgen05
 import cutlass.torch as cutlass_torch
 import cutlass.pipeline as pipeline
+from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
 from cutlass.cute.nvgpu.warp import (
@@ -479,10 +480,12 @@ class GemmSm100(GemmSm90):
         assert (varlen_args.mAIdx is not None) == self.gather_A
 
         # Assume all strides are divisible by 128 bits except the last stride
-        new_stride = lambda t: tuple(
-            cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
-            for s in t.stride
-        )
+        def new_stride(t: cute.Tensor):
+            return tuple(
+                cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
+                for s in t.stride
+            )
+
         mA, mD = [
             cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
             if t is not None
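Note: the rewrite from lambda to def above is behavior-preserving; only the stride-divisibility hint matters. A plain-Python illustration of what that hint asserts (values invented here):

    # Illustrative only: what cute.assume(s, divby=128 // width) promises for a
    # dynamic stride s of a 16-bit (e.g. bf16) tensor.
    element_bits = 16
    divby = 128 // element_bits   # stride must be a multiple of 8 elements
    stride = 4096                 # some runtime-determined leading stride
    assert stride % divby == 0    # keeps 128-bit alignment for vectorized copies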
@@ -631,7 +634,7 @@ class GemmSm100(GemmSm90):
             a_prefetch_pipeline_array_ptr: cute.struct.MemRange[
                 cutlass.Int64, self.a_prefetch_stage * 2
             ]
-            tile_count: cute.struct.MemRange[Int32, self.sched_stage]
+            scheduler_data: cute.struct.MemRange[Int32, self.sched_stage * 4]
             tmem_dealloc_mbar_ptr: cutlass.Int64
             tmem_holding_buf: Int32
             sAIdx: cute.struct.Align[cute.struct.MemRange[Int32, a_idx_smem_size], 16]
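Note: the scheduler's shared-memory slot grows from one Int32 per stage to four, read back later in the kernel as a (4, self.sched_stage) tensor. The size arithmetic, with an invented stage count:

    sched_stage = 2
    ints_per_stage = 4                        # was 1 (a bare tile count) in 0.2.5
    assert sched_stage * ints_per_stage == 8  # MemRange[Int32, self.sched_stage * 4]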
@@ -758,9 +761,7 @@ class GemmSm100(GemmSm90):
 
         warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
 
-        # /////////////////////////////////////////////////////////////////////////////
-        # Prefetch Tma desc
-        # /////////////////////////////////////////////////////////////////////////////
+        # Prefetch Tma desc
         if warp_idx == self.ab_load_warp_id:
             for tma_atom in (
                 tma_atom_a,
@@ -775,9 +776,7 @@ class GemmSm100(GemmSm90):
 
         use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
 
-        #
         # Setup cta/thread coordinates
-        #
         # Coords inside cluster
         bidx, _, _ = cute.arch.block_idx()
         mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
@@ -786,21 +785,10 @@ class GemmSm100(GemmSm90):
         # Coord inside cta
         tidx, _, _ = cute.arch.thread_idx()
 
-        #
         # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
-        #
         smem = cutlass.utils.SmemAllocator()
         storage = smem.allocate(self.shared_storage)
 
-        tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr
-        tmem_holding_buf = storage.tmem_holding_buf
-
-        # Tensor memory dealloc barrier init
-        if use_2cta_instrs:
-            if warp_idx == self.ab_load_warp_id:
-                num_tmem_dealloc_threads = 32
-                cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads)
-
         # Initialize pipelines and states
         ab_pipeline = self.make_ab_pipeline(
             tiled_mma=tiled_mma,
@@ -819,21 +807,36 @@ class GemmSm100(GemmSm90):
             acc_pipeline_mbar_ptr=storage.acc_pipeline_array_ptr.data_ptr(),
         )
         sched_pipeline = None
-        tile_count = None
-        if const_expr(tile_sched_params.tile_count_semaphore is not None):
-            # Dynamic persistent scheduler
+        scheduler_data = None
+        if const_expr(self.is_persistent):
             sched_pipeline = self.make_sched_pipeline(
                 self.cluster_shape_mnk,
                 sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
                 has_C=has_C,
             )
-            tile_count = storage.tile_count.get_tensor((self.sched_stage,))
+            scheduler_data = storage.scheduler_data.get_tensor((4, self.sched_stage))
         a_prefetch_pipeline = None
         if const_expr(self.gather_A):
             a_prefetch_pipeline = self.make_a_prefetch_pipeline(
                 storage.a_prefetch_pipeline_array_ptr.data_ptr(),
             )
 
+        tmem_alloc_barrier = pipeline.NamedBarrier(
+            barrier_id=int(NamedBarrierGemm.TmemPtr),
+            num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)),
+        )
+        # Tensor memory dealloc barrier init
+        tmem = cutlass.utils.TmemAllocator(
+            storage.tmem_holding_buf,
+            barrier_for_retrieve=tmem_alloc_barrier,
+            allocator_warp_id=self.epilog_warp_id[0],
+            is_two_cta=use_2cta_instrs,
+            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
+        )
+
+        # Cluster arrive after barrier init
+        pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
+
         # Setup smem tensor A/B/D
         # (MMA, MMA_M, MMA_K, STAGE)
         sA_mma = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
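Note: this hunk replaces the hand-rolled tmem mbarrier/alloc choreography of 0.2.5 with a cutlass.utils.TmemAllocator. Pieced together from the hunks in this file, the new lifecycle reads as below (a sketch, only meaningful inside a CuTe DSL kernel):

    tmem = cutlass.utils.TmemAllocator(
        storage.tmem_holding_buf,
        barrier_for_retrieve=tmem_alloc_barrier,          # NamedBarrier(TmemPtr)
        allocator_warp_id=self.epilog_warp_id[0],
        is_two_cta=use_2cta_instrs,
        two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
    )
    tmem.allocate(self.num_tmem_alloc_cols)    # allocator warp performs the alloc
    tmem.wait_for_alloc()                      # other warps block on the barrier
    acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
    # ... MMA and epilogue warps use the accumulator ...
    tmem.relinquish_alloc_permit()             # after the D store completes
    tmem.free(acc_tmem_ptr)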
@@ -884,22 +887,19 @@ class GemmSm100(GemmSm90):
         )
 
         TileSchedulerCls = partial(
-            TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
+            TileSchedulerCls.create, tile_sched_params, scheduler_data, sched_pipeline
         )
 
-        tmem_alloc_barrier = pipeline.NamedBarrier(
-            barrier_id=int(NamedBarrierGemm.TmemPtr),
-            num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)),
-        )
         epi_load_barrier = None
         if const_expr(has_C):
             epi_load_barrier = pipeline.NamedBarrier(
                 barrier_id=int(NamedBarrierGemm.EpilogueLoad), num_threads=2 * cute.arch.WARP_SIZE
             )
 
-        #
+        # Cluster wait before tensor memory alloc
+        pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
+
         # Specialized AB load warps
-        #
         if warp_idx == self.ab_load_warp_id:
             is_tma_warp = True
             # initialize tensormap for A & B
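Note: every make_*_pipeline in this file now passes defer_sync=True (see the hunks further down), so no pipeline fences its own mbarrier initialization; a single cluster-wide handshake brackets the whole init phase instead:

    # One relaxed arrive once every mbarrier has been initialized ...
    pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
    # ... and one wait before any warp uses the pipelines or allocates tmem.
    pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)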
@@ -949,9 +949,7 @@ class GemmSm100(GemmSm90):
                 self.b_layout,
                 is_tma_warp,
             )
-            # ///////////////////////////////////////////////////////////////////////////
-            # Local_tile partition global tensors
-            # ///////////////////////////////////////////////////////////////////////////
+            # Local_tile partition global tensors
             mma_tile_coord_mnl = (
                 tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
                 tile_coord_mnkl[1],
@@ -1093,9 +1091,7 @@ class GemmSm100(GemmSm90):
             while work_tile.is_valid_tile:
                 tile_coord_mnkl = work_tile.tile_idx
                 batch_idx = tile_coord_mnkl[3]
-                # ///////////////////////////////////////////////////////////////////////////
-                # Local_tile partition global tensors
-                # ///////////////////////////////////////////////////////////////////////////
+                # Local_tile partition global tensors
                 mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
                 if const_expr(varlen_m):
                     # (M, K)
@@ -1153,10 +1149,8 @@ class GemmSm100(GemmSm90):
                 tile_scheduler.advance_to_next_work()
                 work_tile = tile_scheduler.get_current_work()
 
-        #
         # Specialized scheduler warp. Will also prefetch A indices if gatherA
-        #
-        if const_expr(tile_sched_params.tile_count_semaphore is not None or self.gather_A):
+        if const_expr(self.is_persistent or self.gather_A):
             if warp_idx == self.scheduler_warp_id:
                 is_scheduler_warp = True
                 if const_expr(cute.size(cluster_layout_vmnk) > 1):
@@ -1172,7 +1166,7 @@ class GemmSm100(GemmSm90):
                     cute.make_identity_tensor(tile_M if varlen_m else tile_K)
                 )
                 # Persistent tile scheduling loop
-                tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
+                tile_scheduler = TileSchedulerCls()
                 work_tile = tile_scheduler.initial_work_tile_info()
                 a_prefetch_producer_state = None
                 if const_expr(self.gather_A):
@@ -1190,7 +1184,7 @@ class GemmSm100(GemmSm90):
                     tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
                     len_m = varlen_manager.len_m(batch_idx)
                     m_limit = len_m - tile_coord_mnkl[0] * tile_M
-                    tApAIdx_m = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean)
+                    tApAIdx_m = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean)
                    for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
                        tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit
                    a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
@@ -1220,7 +1214,7 @@ class GemmSm100(GemmSm90):
                     if 0 < k_tile_cnt:
                         k_tile = k_tile_cnt - 1
                         k_limit = len_k - k_tile * tile_K
-                        tApAIdx_k = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean)
+                        tApAIdx_k = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean)
                         for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
                             tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit
                         a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
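Note: both predicate hunks above only rename cute.make_fragment to cute.make_rmem_tensor; the masking logic is unchanged. In plain Python, the predicate is a per-index bounds check (values invented):

    m_limit = 5
    tAcAIdx = [0, 2, 4, 6, 8]                 # coordinates covered by this warp
    pred = [c < m_limit for c in tAcAIdx]     # True -> index is in bounds
    assert pred == [True, True, True, False, False]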
@@ -1233,16 +1227,13 @@ class GemmSm100(GemmSm90):
                     a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
                     a_prefetch_producer_state.advance()
                 # Advance to next tile
-                tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
                 tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
                 work_tile = tile_scheduler.get_current_work()
                 # End of persistent scheduler loop
             if is_scheduler_warp:
                 tile_scheduler.producer_tail()
 
-        #
         # Specialized TMA epi load warp
-        #
         if const_expr(mC_mnl is not None):
             if warp_idx == self.epi_load_warp_id:
                 epi_producer_state = pipeline.make_pipeline_state(
@@ -1281,15 +1272,11 @@ class GemmSm100(GemmSm90):
                 # End of persistent scheduler loop
                 epi_pipeline.producer_tail(epi_producer_state)
 
-        #
         # Specialized MMA warp
-        #
         if warp_idx == self.mma_warp_id:
-            tmem_alloc_barrier.arrive_and_wait()
             # Retrieving tensor memory ptr and make accumulator tensor
-            acc_tmem_ptr = cute.arch.retrieve_tmem_ptr(
-                self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
-            )
+            tmem.wait_for_alloc()
+            acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
             # Partition shared/tensor memory tensor for TiledMMA_A/B/D
             # (MMA, MMA_M, MMA_K, STAGE)
             tCrA = tiled_mma.make_fragment_A(sA_mma)
@@ -1394,12 +1381,8 @@ class GemmSm100(GemmSm90):
         #
         if warp_idx < self.mma_warp_id:
             # Alloc tensor memory buffer
-            if warp_idx == self.epilog_warp_id[0]:
-                cute.arch.alloc_tmem(
-                    self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs
-                )
-            # Bar sync for retrieve tensor memory ptr from shared memory
-            tmem_alloc_barrier.arrive_and_wait()
+            tmem.allocate(self.num_tmem_alloc_cols)
+            tmem.wait_for_alloc()
 
             is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0])
             varlen_manager.init_tensormap_epi(
@@ -1409,9 +1392,7 @@ class GemmSm100(GemmSm90):
             tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
 
             # Retrieving tensor memory ptr and make accumulator tensor
-            acc_tmem_ptr = cute.arch.retrieve_tmem_ptr(
-                self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
-            )
+            acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
             # (MMA, MMA_M, MMA_N, STAGE)
             tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
 
@@ -1426,7 +1407,7 @@ class GemmSm100(GemmSm90):
                 epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs
             )
 
-            tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.acc_dtype)
+            tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.acc_dtype)
             tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
                 tiled_copy_t2r, self.d_layout, self.d_dtype, tTR_rD, sD, epi_tidx
             )
@@ -1535,22 +1516,14 @@ class GemmSm100(GemmSm90):
                 tile_scheduler.advance_to_next_work()
                 work_tile = tile_scheduler.get_current_work()
 
-            # Dealloc the tensor memory buffer
-            if warp_idx == self.epilog_warp_id[0]:
-                cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs)
-            epilogue_barrier.arrive_and_wait()
-            if warp_idx == self.epilog_warp_id[0]:
-                if const_expr(use_2cta_instrs):
-                    cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1)
-                    cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
-                cute.arch.dealloc_tmem(
-                    acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs
-                )
-
             # Wait for D store complete
             if is_tma_warp:
                 epi_store_pipeline.producer_tail()
 
+            # Dealloc the tensor memory buffer
+            tmem.relinquish_alloc_permit()
+            tmem.free(acc_tmem_ptr)
+
     @cute.jit
     def load_A_gather_A(
         self,
@@ -1565,9 +1538,7 @@ class GemmSm100(GemmSm90):
         peek_a_empty_status = Boolean(True)
         if 0 < k_tile_cnt:
             peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state)
-        # /////////////////////////////////////////////////////////////////////////
         # cp.async on A
-        # /////////////////////////////////////////////////////////////////////////
         is_tma_warp = False
         for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
             smem_idx = a_producer_state.index
@@ -1787,7 +1758,7 @@ class GemmSm100(GemmSm90):
         # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
         tTR_cAcc = thr_copy_t2r.partition_D(cAcc_epi)
         # (T2R, T2R_M, T2R_N)
-        tTR_rAcc = cute.make_fragment(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype)
+        tTR_rAcc = cute.make_rmem_tensor(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype)
         return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
 
     def epilog_smem_store_and_partition(
@@ -1860,7 +1831,7 @@ class GemmSm100(GemmSm90):
         thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
         # (R2S, R2S_M, R2S_N, PIPE_D)
         tSR_sC = thr_copy_s2r.partition_S(sC)
-        tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
+        tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, dtype)
         # (R2S, R2S_M, R2S_N)
         tSR_rC = tiled_copy_s2r.retile(tRS_rC)
         return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
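Note: this and the earlier make_fragment hunks are one mechanical rename; register-backed fragments are now created with cute.make_rmem_tensor, with identical arguments:

    tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)      # 0.2.5
    tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, dtype)   # 0.2.6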
@@ -1901,6 +1872,7 @@ class GemmSm100(GemmSm90):
                 consumer_group=ab_pipeline_consumer_group,
                 tx_count=self.num_tma_load_bytes,
                 cta_layout_vmnk=cluster_layout_vmnk,
+                defer_sync=True,
             )
         else:
             pipeline_ab = PipelineTmaCpAsyncUmma.create(
@@ -1913,6 +1885,7 @@ class GemmSm100(GemmSm90):
                 producer_drop_count=None
                 if not self.use_2cta_instrs
                 else (2 if not is_leader_cta else 0),
+                defer_sync=True,
             )
         return pipeline_ab
@@ -1930,6 +1903,7 @@ class GemmSm100(GemmSm90):
             producer_group=acc_pipeline_producer_group,
             consumer_group=acc_pipeline_consumer_group,
             cta_layout_vmnk=cluster_layout_vmnk,
+            defer_sync=True,
         )
 
     def make_sched_pipeline(
  def make_sched_pipeline(
@@ -1941,13 +1915,13 @@ class GemmSm100(GemmSm90):
1941
1915
  # Threads/warps participating in this pipeline
1942
1916
  sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1943
1917
  cluster_size = cute.size(cluster_layout_mnk)
1944
- # Each warp that are not the scheduler warp will contribute 1 to the arrive count
1918
+ # Each warp will contribute 1 to the arrive count
1945
1919
  warps_per_cta = self.num_ab_load_warps + len(
1946
1920
  (self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id)
1947
1921
  )
1948
1922
  if has_C:
1949
1923
  warps_per_cta += 1
1950
- consumer_arrive_cnt = warps_per_cta * cluster_size - 1
1924
+ consumer_arrive_cnt = warps_per_cta * cluster_size
1951
1925
  sched_pipeline_consumer_group = pipeline.CooperativeGroup(
1952
1926
  pipeline.Agent.Thread, consumer_arrive_cnt
1953
1927
  )
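Note: the scheduler warp now arrives at the sched pipeline like every other warp, so the `- 1` correction goes away. With invented counts:

    warps_per_cta = 7                           # ab load + mma + epilogue + scheduler
    cluster_size = 2
    old_cnt = warps_per_cta * cluster_size - 1  # 13: scheduler warp excluded
    new_cnt = warps_per_cta * cluster_size      # 14: every warp contributes 1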
@@ -1958,6 +1932,7 @@ class GemmSm100(GemmSm90):
             consumer_group=sched_pipeline_consumer_group,
             # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
             consumer_mask=None if const_expr(cluster_size == 1) else 0,
+            defer_sync=True,
         )
 
     @cute.jit
@@ -1965,9 +1940,7 @@ class GemmSm100(GemmSm90):
         self, a_prefetch_pipeline_mbar_ptr: cute.Pointer
     ) -> pipeline.PipelineAsync:
         producer_cnt = 32
-        a_prefetch_producer_group = pipeline.CooperativeGroup(
-            pipeline.Agent.Thread, producer_cnt, alignment=producer_cnt
-        )
+        a_prefetch_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
         consumer_arrive_cnt = self.num_ab_load_warps - 1
         a_prefetch_consumer_group = pipeline.CooperativeGroup(
             pipeline.Agent.Thread, consumer_arrive_cnt
         )
@@ -1977,6 +1950,7 @@ class GemmSm100(GemmSm90):
             num_stages=self.a_prefetch_stage,
             producer_group=a_prefetch_producer_group,
             consumer_group=a_prefetch_consumer_group,
+            defer_sync=True,
         )
 
     @classmethod
@@ -2721,10 +2695,10 @@ def run(
     tflops = flops / (timing * 1e9)  # Convert to TFlops
     print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}")
 
-    # time.sleep(0.5)
-    # timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
-    # tflops_cublas = flops / (timing_cublas * 1e9)  # Convert to TFlops
-    # print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
+    time.sleep(0.5)
+    timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats)
+    tflops_cublas = flops / (timing_cublas * 1e9)  # Convert to TFlops
+    print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}")
 
 
 if __name__ == "__main__":
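Note: 0.2.6 re-enables the cuBLAS baseline that 0.2.5 left commented out. The 1e9 in the TFLOPS lines combines the ms-to-s and FLOP-to-TFLOP conversions; with invented numbers:

    flops = 2 * 8192**3          # 2*M*N*K for a square GEMM (illustrative size)
    timing = 1.25                # milliseconds, as returned by do_bench
    # ms -> s is /1e3 and FLOP -> TFLOP is /1e12; combined, that is the /1e9
    tflops = flops / (timing * 1e9)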