quack-kernels 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.1.8"
1
+ __version__ = "0.1.10"
2
2
 
3
3
  from quack.rmsnorm import rmsnorm
4
4
  from quack.softmax import softmax
quack/cross_entropy.py CHANGED
@@ -446,13 +446,10 @@ class CrossEntropyBackward:
446
446
  log2_e = math.log2(math.e)
447
447
  probs = utils.exp2f((x - lse) * log2_e)
448
448
  prob_shifted = probs - 1.0
449
-
450
449
  mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
451
- for i in cutlass.range_constexpr(cute.size(tXcFull)):
450
+ for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
452
451
  mask[i] = tXcFull[i][1] == label
453
-
454
- mask = mask.load()
455
- grad = cute.where(mask, prob_shifted, probs)
452
+ grad = cute.where(mask.load(), prob_shifted, probs)
456
453
  grad = grad * dloss
457
454
 
458
455
  tXrO.store(grad.to(tXrO.element_type))