quack-kernels 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

quack/utils.py CHANGED
@@ -6,7 +6,7 @@ from typing import Optional, Tuple, Type, Union
  import cutlass
  import cutlass.cute as cute

- from cutlass import Float32, Int32
+ from cutlass import Float32, Int32, const_expr
  from cutlass.cutlass_dsl import T, dsl_user_op
  from cutlass._mlir.dialects import llvm, nvvm, vector
  from cutlass.cute.runtime import from_dlpack
@@ -22,6 +22,59 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
      )


+ def transpose_view(a: cute.Tensor) -> cute.Tensor:
+     """Transpose the first two dimensions of a tensor on smem."""
+     shape = (a.shape[1], a.shape[0], *a.shape[2:])
+     order = (1, 0, *range(2, cute.rank(a)))
+     return cute.composition(a, cute.make_ordered_layout(shape, order=order))
+
+
+ def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
+     return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
+
+
+ @dsl_user_op
+ def get_copy_atom(
+     dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
+ ) -> cute.CopyAtom:
+     num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
+     copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+     return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+
+
+ @dsl_user_op
+ def copy(
+     src: cute.Tensor,
+     dst: cute.Tensor,
+     *,
+     pred: Optional[cute.Tensor] = None,
+     num_copy_elems: int = 1,
+     is_async: bool = False,
+     loc=None,
+     ip=None,
+     **kwargs,
+ ) -> None:
+     copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
+     cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+
+
+ def tiled_copy_2d(
+     dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = True
+ ) -> cute.TiledCopy:
+     num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
+     copy_elems = num_copy_bits // dtype.width
+     copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+     copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+     gmem_threads_per_row = major_mode_size // copy_elems
+     assert num_threads % gmem_threads_per_row == 0
+     thr_layout = cute.make_ordered_layout(
+         (num_threads // gmem_threads_per_row, gmem_threads_per_row),
+         order=(1, 0),
+     )
+     val_layout = cute.make_layout((1, copy_elems))
+     return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+
+
  @dsl_user_op
  def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
      return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
@@ -29,7 +82,7 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut

  @cute.jit
  def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
-     if cutlass.const_expr(isinstance(x, cute.Pointer)):
+     if const_expr(isinstance(x, cute.Pointer)):
          return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
      else:
          assert isinstance(x, Float32)
@@ -71,7 +124,7 @@ def store_shared_remote(
      remote_mbar_ptr_i32 = set_block_rank(
          mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
      ).ir_value()
-     if cutlass.const_expr(isinstance(val, float)):
+     if const_expr(isinstance(val, float)):
          val = Float32(val)
      assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
      suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
@@ -196,7 +249,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Nu
      tXrX_fill.fill(fill_value)
      for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
          for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
-             if cutlass.const_expr(tXpX is not None):
+             if const_expr(tXpX is not None):
                  if not tXpX[rest_v, 0, rest_k]:
                      cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
              else:
@@ -232,9 +285,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
  def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
      flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
      flat_stride = cute.flatten_to_tuple(tensor.stride)
-     assert len(flat_coord_i64) == len(
-         flat_stride
-     ), "Coordinate and stride must have the same length"
+     assert len(flat_coord_i64) == len(flat_stride), (
+         "Coordinate and stride must have the same length"
+     )
      offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
      assert isinstance(tensor.iterator, cute.Pointer)
      # HACK: we assume that applying the offset does not change the pointer alignment
@@ -265,7 +318,7 @@ def coord_offset_i64(

  @cute.jit
  def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
-     if cutlass.const_expr(lane is None):
+     if const_expr(lane is None):
          lane = cute.arch.lane_idx()
      for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
          offset = 1 << i
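The vector width picked by the added `tiled_copy_2d` helper follows from the gcd arithmetic shown above: each thread copies the widest chunk, up to 128 bits, whose element count evenly divides the contiguous row. A minimal sketch of that arithmetic with plain integers (illustrative only; `copy_vector_elems` is a hypothetical stand-in and not part of the package):

```python
import math

def copy_vector_elems(major_mode_size: int, dtype_width_bits: int) -> int:
    """Mirror the width logic in tiled_copy_2d: the widest per-thread copy
    (capped at 128 bits) whose element count divides the row length."""
    num_copy_bits = math.gcd(major_mode_size, 128 // dtype_width_bits) * dtype_width_bits
    return num_copy_bits // dtype_width_bits

# A 16-bit dtype (e.g. bf16) with a 256-element row allows full 128-bit
# (8-element) copies; a 12-element row falls back to 64-bit (4-element) copies.
assert copy_vector_elems(256, 16) == 8
assert copy_vector_elems(12, 16) == 4
```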
quack/varlen_utils.py CHANGED
@@ -14,9 +14,4 @@ class VarlenArguments(ArgumentsBase):
      mCuSeqlensM: Optional[cute.Tensor] = None
      mCuSeqlensK: Optional[cute.Tensor] = None
      mTensormaps: Optional[cute.Tensor] = None
-
-     def __post_init__(self):
-         if self.mCuSeqlensM is not None or self.mCuSeqlensK is not None:
-             assert (
-                 self.mTensormaps is not None
-             ), "mTensormaps must be provided if mCuSeqlensM or mCuSeqlensK is provided"
+     mAIdx: Optional[cute.Tensor] = None
quack_kernels-0.2.1.dist-info/METADATA → quack_kernels-0.2.2.dist-info/METADATA CHANGED
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.2.1
+ Version: 0.2.2
  Requires-Python: >=3.10
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.2.0
+ Requires-Dist: nvidia-cutlass-dsl==4.2.1
  Requires-Dist: torch
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
quack_kernels-0.2.1.dist-info/RECORD → quack_kernels-0.2.2.dist-info/RECORD CHANGED
@@ -1,16 +1,16 @@
- quack/__init__.py,sha256=H1m0CnfPidSSmprZeTGJc8LVh7stdBPmPLEuZwgN_7M,364
+ quack/__init__.py,sha256=sJum67V7jEQPUDWz4FKJ5Sk7MqmBtbMXjZPVboQnDdE,364
  quack/activation.py,sha256=SzQDUCB-kccqsy1aYUrHYJ2cGxKMXxxqpjJaJoqBYaE,10017
- quack/autotuner.py,sha256=czO6JrYL0EJpOeJOYDSsVdrJaFuwfL3vTdG8QfL1F34,10792
+ quack/autotuner.py,sha256=atw0ntedi22RPwSdjWOoge4S56S8VFvRocJQcYhpAlo,13454
  quack/cross_entropy.py,sha256=TE8j21c-7E4cInKtFjcKsgKXNhKCRFkNfhCJpgpasj8,28409
- quack/cute_dsl_utils.py,sha256=D2Pw7rzX9jY8u8wikIPvPvinmFLCDeZg95HPBLqGej4,4635
- quack/dense_gemm_sm100.py,sha256=hKBNC34UxdctrTKVP68nvANZl4Dq2rnUjRcweESEq3g,109965
- quack/dense_gemm_sm90.py,sha256=TjnjHnjhAwWH5YQWsFlADq07xSxtsprkw_p2Cy0yw7I,100407
+ quack/cute_dsl_utils.py,sha256=d8xLD17a9EsSQgmgWDO8rUWWCTRM8e1kDq1wzilaYC8,4563
+ quack/dense_gemm_sm90.py,sha256=LvcR178zzzWClkEerhIx940Sg-AF_BpQdnjqC8s9W1o,113832
  quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
- quack/gemm_act_sm90.py,sha256=N5UAFWZvw1na22Vh5JSGgcdqZ2zI6kQMBVOLxYbCAUU,14332
+ quack/gemm_act_sm90.py,sha256=yJEkwCtKjldxzJYq78CpCV6fxoqoZJSpd7KvnglHqfo,16206
  quack/gemm_config.py,sha256=gbYjPFeyT5wAhVwFQroRHlHoMKEJqAWX9P8wWy04l8Q,2258
- quack/gemm_dact_sm90.py,sha256=KCXgjOzdamSDexwrwf_pX2r-ippPRirbClrlU6BP7b8,4990
- quack/gemm_interface.py,sha256=_JTpE7zQw6NUw-v65Wql_XUOZBfW0oSEgiMnharTJU4,20501
- quack/gemm_wrapper_utils.py,sha256=aMMtu-Ojhtjay_5xJH4AjP-JRVks1AB8jmtNme_DIqU,5960
+ quack/gemm_dact_sm90.py,sha256=QOACq-v9XHfY6p5frKzYCvkCbqGDq69beYcfCfl-5Kc,6458
+ quack/gemm_interface.py,sha256=qEbQRsvTrwKdLLlGVCMH76diMCKOsA6GqsC0PaepLow,39636
+ quack/gemm_sm100.py,sha256=T-2BUrUBXROxQ9Iz-6pB5T8j9go29Vlw4ZCJQ_oM7yg,110396
+ quack/gemm_wrapper_utils.py,sha256=oDCXngJuH-qbDI9DJuXkDHUogXleWZrF1mRpI1DAcI8,12687
  quack/layernorm.py,sha256=AOe95-YqhFPw96x8pJq7FfBe26ROX9ZTvH025lM1ILs,13579
  quack/linear.py,sha256=SrhRiAFjC7ONIMVmiNu-kSPLHNUyaCXt59a1f_5nNXo,9383
  quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
@@ -18,20 +18,20 @@ quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
  quack/pipeline.py,sha256=DyCwZX8WvoUBFcMBz7CeYm9VUM31haEGgBhAzmxu8cE,5519
  quack/reduce.py,sha256=0hRFMFfn6xC5QLk32Qmgc17XVkQ1yKC-3TfksccSBaU,10341
  quack/reduction_base.py,sha256=CT-t_j7z8H1ByD9FkQYDRik_-THMDFv9QoXHmr9Xx9E,3636
- quack/rmsnorm.py,sha256=PrW2zuaQs_Gr6g8B6DMsGSJFZdEsWf32if_EwUR_IDQ,49386
+ quack/rmsnorm.py,sha256=Ak3EL-qzwgaKGZl7O2upiR3FC93776Cgse_B5PZhTu0,45643
  quack/softmax.py,sha256=WFWtgc40iLPFBpdStBBTC9803Npnv9rZjOzb_nK-RDs,17110
  quack/symmetric_dense_gemm_sm90.py,sha256=2UXooIpClT2izdyGis1XaIgYYlLj-7MrcOMg2yR7YCk,88694
  quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
- quack/tile_scheduler.py,sha256=BQ-SeW5wxulKuwmpq0CAIjkuirv4KWdUdoIGQB88aGE,42319
+ quack/tile_scheduler.py,sha256=5lcprf3VIXWCNusWHBCveHpCWRzQ0nzcIMhaQbXher8,41727
  quack/topk.py,sha256=RQl-23lIicQ9ry9Njur8i0JGem_WbO_Gchr6jy8EtVM,9185
- quack/utils.py,sha256=wOgNw9VL40FCsLwN52juPfk48zVpX-rta3MQhAQe8Wc,12767
- quack/varlen_utils.py,sha256=vkduMEpo5bJJvZRNnIcKPb6pp1wD34vaIpMIB0ZGIZA,681
+ quack/utils.py,sha256=DVMSbMngPBnIRrHuGDXKqVueiNv9DFCfGv076hxzJms,14747
+ quack/varlen_utils.py,sha256=GwXc8tO6BrYoYszhSeJ0u_KmreJAEodP1EAizLS-jaA,464
  quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
  quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
  quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
  quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
- quack_kernels-0.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.2.1.dist-info/METADATA,sha256=_AFigx6aFt-25GzUP6YWalDBwHvwzgK9EU85WjZXvsI,285
- quack_kernels-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.2.1.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.2.1.dist-info/RECORD,,
+ quack_kernels-0.2.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.2.2.dist-info/METADATA,sha256=ZZofR2edTztufmX_0ExiJ7CpFsT80koJf-pRRUm3ssg,285
+ quack_kernels-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.2.2.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.2.2.dist-info/RECORD,,