quack-kernels 0.1.9__tar.gz → 0.1.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. {quack_kernels-0.1.9/quack_kernels.egg-info → quack_kernels-0.1.10}/PKG-INFO +3 -3
  2. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/pyproject.toml +3 -2
  3. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/quack/__init__.py +1 -1
  4. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/quack/cross_entropy.py +2 -5
  5. quack_kernels-0.1.10/quack/dense_gemm_sm90.py +1430 -0
  6. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/quack/utils.py +1 -1
  7. {quack_kernels-0.1.9 → quack_kernels-0.1.10/quack_kernels.egg-info}/PKG-INFO +3 -3
  8. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/quack_kernels.egg-info/SOURCES.txt +1 -1
  9. quack_kernels-0.1.10/quack_kernels.egg-info/requires.txt +6 -0
  10. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/tests/test_rmsnorm.py +7 -184
  11. quack_kernels-0.1.9/quack_kernels.egg-info/requires.txt +0 -6
  12. quack_kernels-0.1.9/setup.py +0 -3
  13. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/LICENSE +0 -0
  14. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/README.md +0 -0
  15. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/quack/layernorm.py +0 -0
  16. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/quack/reduction_base.py +0 -0
  17. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/quack/rmsnorm.py +0 -0
  18. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/quack/softmax.py +0 -0
  19. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/quack_kernels.egg-info/dependency_links.txt +0 -0
  20. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/quack_kernels.egg-info/top_level.txt +0 -0
  21. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/setup.cfg +0 -0
  22. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/tests/test_cross_entropy.py +0 -0
  23. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/tests/test_layernorm.py +0 -0
  24. {quack_kernels-0.1.9 → quack_kernels-0.1.10}/tests/test_softmax.py +0 -0
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quack-kernels
3
- Version: 0.1.9
4
- Requires-Python: >=3.9
3
+ Version: 0.1.10
4
+ Requires-Python: >=3.12
5
5
  License-File: LICENSE
6
- Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
6
+ Requires-Dist: nvidia-cutlass-dsl==4.1.0
7
7
  Requires-Dist: torch
8
8
  Provides-Extra: dev
9
9
  Requires-Dist: pre-commit; extra == "dev"
@@ -5,9 +5,9 @@ build-backend = "setuptools.build_meta"
5
5
  [project]
6
6
  name = "quack-kernels"
7
7
  dynamic = ["version"]
8
- requires-python = ">=3.9"
8
+ requires-python = ">=3.12"
9
9
  dependencies = [
10
- "nvidia-cutlass-dsl==4.1.0.dev0",
10
+ "nvidia-cutlass-dsl==4.1.0",
11
11
  "torch",
12
12
  ]
13
13
 
@@ -29,5 +29,6 @@ line-length = 100
29
29
  [tool.ruff.lint]
30
30
  ignore = [
31
31
  "E731", # do not assign a lambda expression, use a def
32
+ "E741", # Do not use variables named 'I', 'O', or 'l'
32
33
  "F841", # local variable is assigned to but never used
33
34
  ]
@@ -1,4 +1,4 @@
1
- __version__ = "0.1.9"
1
+ __version__ = "0.1.10"
2
2
 
3
3
  from quack.rmsnorm import rmsnorm
4
4
  from quack.softmax import softmax
@@ -446,13 +446,10 @@ class CrossEntropyBackward:
446
446
  log2_e = math.log2(math.e)
447
447
  probs = utils.exp2f((x - lse) * log2_e)
448
448
  prob_shifted = probs - 1.0
449
-
450
449
  mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
451
- for i in cutlass.range_constexpr(cute.size(tXcFull)):
450
+ for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
452
451
  mask[i] = tXcFull[i][1] == label
453
-
454
- mask = mask.load()
455
- grad = cute.where(mask, prob_shifted, probs)
452
+ grad = cute.where(mask.load(), prob_shifted, probs)
456
453
  grad = grad * dloss
457
454
 
458
455
  tXrO.store(grad.to(tXrO.element_type))