flash-hog 0.1.0 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,42 @@
+ Metadata-Version: 2.3
+ Name: flash-hog
+ Version: 0.1.0
+ Summary: Add your description here
+ Requires-Dist: absl-py>=2.4.0
+ Requires-Dist: chex>=0.1.91
+ Requires-Dist: einops>=0.8.1
+ Requires-Dist: equinox>=0.13.2
+ Requires-Dist: jax[cuda13]>=0.8.0
+ Requires-Dist: nvidia-cutlass-dsl>=4.2.1
+ Requires-Dist: pytest>=8.4.2
+ Requires-Dist: ruff>=0.14.2
+ Requires-Dist: torch>=2.9.0
+ Requires-Dist: ty>=0.0.1a24
+ Requires-Python: >=3.12, <3.14
+ Description-Content-Type: text/markdown
+
+ # Flash Hog
+ <p align="center">
+ <img src="assets/logo.png" alt="Flash Hog Logo" width="256" />
+ </p>
+
+ This repo contains the code for Flash Higher-Order Gradients, a.k.a. Flash Hog.
+ This kernel achieves around a 3.7x speedup over an XLA-optimized kernel, with linear rather than quadratic memory scaling.
+
+ <p align="center">
+ <img src="assets/speedup.png" alt="Hog Speedup" width="512"/>
+ </p>
+
+ ## Installation
+ TODO
+
+ ## Method
+ Flash Hog performs four recomputation passes so that it never needs atomics and never materializes an intermediate tensor of shape `(N_Q, N_K)`.
+ Concretely, the first three passes tile thread-wise across Q: once to compute `dd`, once for `b`, and once for both `dQ'` and `ddO`.
+ A final pass tiles over K, producing `dK'` and `dV'`.
+ The equations we implement are the following:
+
+
+ <p align="center">
+ <img src="assets/handwritten_equations.png" alt="Equations" width="512"/>
+ </p>
@@ -0,0 +1,25 @@
+ # Flash Hog
+ <p align="center">
+ <img src="assets/logo.png" alt="Flash Hog Logo" width="256" />
+ </p>
+
+ This repo contains the code for Flash Higher-Order Gradients, a.k.a. Flash Hog.
+ This kernel achieves around a 3.7x speedup over an XLA-optimized kernel, with linear rather than quadratic memory scaling.
+
+ <p align="center">
+ <img src="assets/speedup.png" alt="Hog Speedup" width="512"/>
+ </p>
+
+ ## Installation
+ TODO
+
+ ## Method
+ Flash Hog performs four recomputation passes so that it never needs atomics and never materializes an intermediate tensor of shape `(N_Q, N_K)`.
+ Concretely, the first three passes tile thread-wise across Q: once to compute `dd`, once for `b`, and once for both `dQ'` and `ddO`.
+ A final pass tiles over K, producing `dK'` and `dV'`.
+ The equations we implement are the following:
+
+
+ <p align="center">
+ <img src="assets/handwritten_equations.png" alt="Equations" width="512"/>
+ </p>
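
The `(N_Q, N_K)` point above is the standard flash-attention tiling argument: scores are recomputed one K/V tile at a time, so only `O(N_Q × tile)` scratch memory is ever live. Below is a minimal pure-JAX sketch of that tiling idea for the forward pass only; it is not the Flash Hog kernel and not part of this package, just an illustration of the recompute-per-tile strategy that the four second-order passes apply.

```python
import jax
import jax.numpy as jnp


def tiled_attention(q, k, v, tile_k=128):
    """Attention computed one K/V tile at a time: scratch memory is O(n_q * tile_k),
    and the full (n_q, n_k) score matrix is never materialized."""
    scale = q.shape[-1] ** -0.5
    n_q, n_k = q.shape[0], k.shape[0]
    k_tiles = k.reshape(n_k // tile_k, tile_k, -1)
    v_tiles = v.reshape(n_k // tile_k, tile_k, -1)

    def step(carry, kv_tile):
        m, l, acc = carry                      # running max, normalizer, weighted sum
        k_t, v_t = kv_tile
        s = (q @ k_t.T) * scale                # scores for this tile only: (n_q, tile_k)
        m_new = jnp.maximum(m, s.max(axis=-1))
        alpha = jnp.exp(m - m_new)             # rescale the partial results so far
        p = jnp.exp(s - m_new[:, None])
        l_new = l * alpha + p.sum(axis=-1)
        acc_new = acc * alpha[:, None] + p @ v_t
        return (m_new, l_new, acc_new), None

    init = (
        jnp.full((n_q,), -jnp.inf, dtype=q.dtype),
        jnp.zeros((n_q,), dtype=q.dtype),
        jnp.zeros((n_q, v.shape[-1]), dtype=q.dtype),
    )
    (_, l, acc), _ = jax.lax.scan(step, init, (k_tiles, v_tiles))
    return acc / l[:, None]


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (256, 64))
k = jax.random.normal(kk, (512, 64))
v = jax.random.normal(kv, (512, 64))
ref = jax.nn.softmax((q @ k.T) * 64**-0.5, axis=-1) @ v
print(jnp.allclose(tiled_attention(q, k, v), ref, atol=1e-5))  # True
```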
@@ -0,0 +1,3 @@
+ from fhog.triton_bwdbwd import flash_bwdbwd
+
+ __all__ = ["flash_bwdbwd"]
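
For orientation, here is a call sketch for the exported `flash_bwdbwd`, mirroring the Triton benchmark script later in this diff. The argument order `(q, k, v, o, do, ddq, ddk, ddv, l, scale)`, the leading batch dimension, and the return order `(dq2, dk2, dv2, ddo)` are inferred from that script rather than from documented API, and the random `o` and `l` below are stand-ins for the real forward output and per-row log-sum-exp stats.

```python
import torch

from fhog import flash_bwdbwd

nq, nkv, d = 1024, 1024, 64
dev, dt = "cuda", torch.float16

q = torch.randn(1, nq, d, device=dev, dtype=dt)     # queries, with a leading batch dim
k = torch.randn(1, nkv, d, device=dev, dtype=dt)
v = torch.randn(1, nkv, d, device=dev, dtype=dt)
o = torch.randn(1, nq, d, device=dev, dtype=dt)     # forward output (random stand-in)
do = torch.randn(1, nq, d, device=dev, dtype=dt)    # upstream gradient w.r.t. the output
ddq = torch.randn(1, nq, d, device=dev, dtype=dt)   # cotangents of the first-order grads
ddk = torch.randn(1, nkv, d, device=dev, dtype=dt)
ddv = torch.randn(1, nkv, d, device=dev, dtype=dt)
l = torch.randn(1, nq, device=dev, dtype=dt)        # per-row log-sum-exp stats (stand-in)

# Second-order backward: new gradients w.r.t. Q, K, V plus the cotangent for dO.
dq2, dk2, dv2, ddo = flash_bwdbwd(q, k, v, o, do, ddq, ddk, ddv, l, 1 / d**0.5)
print(dq2.shape, dk2.shape, dv2.shape, ddo.shape)
```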
@@ -0,0 +1,110 @@
+ from pathlib import Path
+
+ import jax.numpy as jnp
+ import jax.random as jrandom
+ import torch
+ import triton
+
+ from fhog.jax_refs.jax_impl import attn_bwd_bwd
+ from fhog.triton_bwdbwd import flash_bwdbwd
+
+ N_REPS = 5000
+ N_WARMUP = 300
+
+
+ def full_benchmark(nq, nkv):
+     d_in = 64
+     d_out = 64
+
+     q = torch.randn((nq, d_in), device="cuda", dtype=torch.float16)
+     k = torch.randn((nkv, d_in), device="cuda", dtype=torch.float16)
+     v = torch.randn((nkv, d_out), device="cuda", dtype=torch.float16)
+     o = torch.randn((nq, d_out), device="cuda", dtype=torch.float16)
+     l = torch.randn((nq,), device="cuda", dtype=torch.float16)
+     do = torch.randn((nq, d_out), device="cuda", dtype=torch.float16)
+     ddq = torch.randn((nq, d_in), device="cuda", dtype=torch.float16)
+     ddk = torch.randn((nkv, d_in), device="cuda", dtype=torch.float16)
+     ddv = torch.randn((nkv, d_out), device="cuda", dtype=torch.float16)
+
+     def jax_benchmark():
+         # nq, d_in = q.shape
+         # nkv, _ = k.shape
+         # _, d_out = v.shape
+
+         scale = jnp.array(1.0 / jnp.sqrt(d_in), dtype=jnp.float32)
+
+         # assert nq == 128
+         # assert nkv == 256
+         # assert d_in == 64
+         # assert d_out == 64
+
+         def do_run_jax():
+             # attn_bwd_bwd_stats(
+             attn_bwd_bwd(
+                 # stats=jnp.asarray(l),
+                 q=jnp.asarray(q),
+                 k=jnp.asarray(k),
+                 v=jnp.asarray(v),
+                 # o=jnp.asarray(o),
+                 do=jnp.asarray(do),
+                 ddq=jnp.asarray(ddq),
+                 ddk=jnp.asarray(ddk),
+                 ddv=jnp.asarray(ddv),
+                 scale=jnp.asarray(scale),
+                 is_causal=False,
+                 sliding_window_length=jnp.inf,
+             )
+
+         results = triton.testing.do_bench(do_run_jax, rep=N_REPS, warmup=N_WARMUP)
+         return results
+
+     def triton_benchmark():
+         def do_run_triton():
+             triton_dq2, triton_dk2, triton_dv2, triton_ddo = flash_bwdbwd(
+                 q.unsqueeze(0),
+                 k.unsqueeze(0),
+                 v.unsqueeze(0),
+                 o.unsqueeze(0),
+                 do.unsqueeze(0),
+                 ddq.unsqueeze(0),
+                 ddk.unsqueeze(0),
+                 ddv.unsqueeze(0),
+                 l.unsqueeze(0),
+                 1 / d_in**0.5,
+             )
+
+         results = triton.testing.do_bench(do_run_triton, rep=N_REPS, warmup=N_WARMUP)
+         return results
+
+     try:
+         jax_time = jax_benchmark()
+     except Exception as e:
+         print(f"JAX benchmark failed for nq={nq}, nkv={nkv} with error: {e}")
+         jax_time = float("inf")
+     triton_time = triton_benchmark()
+     print(f"{nq}: {jax_time=} {triton_time=} => speedup: {jax_time / triton_time:.2f}x")
+     return jax_time, triton_time
+
+
+ if __name__ == "__main__":
+     # nq = 512
+     # nkv = 1024
+     # for siz in
+     for size in [
+         # 128,
+         # 512,
+         # 1024,
+         # 2048,
+         # 4096,
+         8192,
+         16384,
+         32768,
+         65536,
+         131072,
+         262144,
+         524288,  # 2^19
+         1048576,  # 2^20
+         2097152,
+     ]:
+         # for size in [128, 512]:
+         full_benchmark(size, size)
@@ -0,0 +1,71 @@
+ """
+ Implementations for the functions used in fhog.jax.attention.
+ """
+
+ from functools import partial
+
+ import jax
+ from jax._src.cudnn.fused_attention_stablehlo import MaskType
+ from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention as cuda_dot_product_attention
+ from jax.tree_util import Partial
+
+ from fhog.pallas_bwdbwd import TuningConfig, flash_bwdbwd0
+
+
+ def dot_product_attention_fwd(query, key, value, mask_type: MaskType, scale: float):
+     """
+     Forward pass, no saving.
+     Only needs to return the output.
+     """
+     return cuda_dot_product_attention(query, key, value, mask_type=mask_type, scale=scale)
+
+
+ def dot_product_attention_fwd_rule(query, key, value, mask_type: MaskType, scale: float):
+     """
+     Forward pass, saving stats, Q, K, V and O.
+     """
+     out, vjp_fun = jax.vjp(partial(cuda_dot_product_attention, mask_type=mask_type, scale=scale), query, key, value)
+     residual = vjp_fun
+     return out, residual
+
+
+ def dot_product_attention_bwd_rule(mask_type: MaskType, scale: float, res, g):
+     """
+     Backward pass, no saving
+     """
+     vjp_fun = res
+     # breakpoint()
+     dQ, dK, dV = vjp_fun(g)
+     return dQ, dK, dV
+
+
+ def dot_product_attention_bwd_rule_fwd_rule(mask_type: MaskType, scale: float, res, g):
+     """
+     Backward pass, saving for higher order backward
+     """
+     vjp_fun = res
+     dO = g
+
+     # query, key, value = vjp_fun.args_res
+     # *_, stats, out = vjp_fun.opaque_residuals
+     dQ, dK, dV = vjp_fun(dO)
+     residual = (vjp_fun, dO)
+
+     return (dQ, dK, dV), residual
+
+
+ def dot_product_attention_bwd_rule_bwd_rule(mask_type: MaskType, scale: float, res, g):
+     """
+     Backward pass through the backward pass
+     """
+     vjp_fun, dO = res
+     query, key, value = vjp_fun.args_res
+     *_, stats, out = vjp_fun.opaque_residuals
+     vjp_fun_structure = jax.tree.structure(vjp_fun)
+     # breakpoint()
+
+     ddQ, ddK, ddV = g
+
+     dQ2, dK2, dV2, ddO = flash_bwdbwd0(query, key, value, out, dO, ddQ, ddK, ddV, stats, scale, config=TuningConfig(tile_q=128, tile_k=32, max_concurrent_steps=4))
+     vjp_fun_grad = jax.tree.unflatten(vjp_fun_structure, [dQ2, dK2, dV2, None, None, None])  # TODO: Don't I need new dO in the last argument here?
+     return vjp_fun_grad, ddO
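
These rules are wired together in the interface module that follows by nesting `jax.custom_vjp`: the backward rule is itself a `custom_vjp` function, so differentiating the gradient dispatches to a hand-written second-order rule instead of tracing through the first-order kernel. Below is a minimal standalone sketch of that pattern, using a hypothetical scalar function rather than attention; it is not part of the package.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    return jnp.sin(x)


def f_fwd(x):
    return jnp.sin(x), x  # primal output, residual


@jax.custom_vjp
def f_bwd(x, g):
    # First-order backward rule: d/dx sin(x) = cos(x).
    return (jnp.cos(x) * g,)


def f_bwd_fwd(x, g):
    return (jnp.cos(x) * g,), (x, g)


def f_bwd_bwd(res, gg):
    # Second-order rule: gradients of cos(x) * g w.r.t. x and g.
    x, g = res
    (ggx,) = gg
    return (-jnp.sin(x) * g * ggx, jnp.cos(x) * ggx)


f_bwd.defvjp(f_bwd_fwd, f_bwd_bwd)
f.defvjp(f_fwd, f_bwd)

x = jnp.array(0.3)
print(jax.grad(f)(x))            # cos(0.3), via f_bwd
print(jax.grad(jax.grad(f))(x))  # -sin(0.3), via f_bwd_bwd
```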
@@ -0,0 +1,107 @@
+ """
+ Main interface for Jax attention with support for higher order memory-efficient backward on GPU.
+ """
+
+ from functools import partial
+
+ import jax
+ import jax.numpy as jnp
+ from jax import Array
+ from jax._src.cudnn.fused_attention_stablehlo import MaskType
+ from jaxtyping import Bool, Float, Int
+
+ import fhog.jax._attention_impl as attn_impl
+
+ # Reimplementation of much of jax._src.cudnn.fused_attention_stablehlo as used in jax.nn.dot_product_attention
+
+
+ @partial(jax.custom_vjp, nondiff_argnames=["mask_type", "scale"])
+ def dot_product_attention(
+     query: Float[Array, "T N H"],
+     key: Float[Array, "S N H"],
+     value: Float[Array, "S N H"],
+     mask_type: MaskType = MaskType.NO_MASK,
+     scale: float | None = None,
+ ):
+     """
+     Dimensions:
+         T: Query length
+         S: Key/value length
+         N: Number of attention heads
+         H: Head dimension
+
+     """
+     if scale is None:
+         scale = query.shape[-1] ** -0.5
+     dtype = query.dtype
+     assert dtype == key.dtype == value.dtype
+     return attn_impl.dot_product_attention_fwd(query, key, value, mask_type=mask_type, scale=scale)
+
+
+ @partial(jax.custom_vjp, nondiff_argnames=["mask_type", "scale"])
+ def _dot_product_attention_fwd(query, key, value, mask_type: MaskType, scale: float):
+     """
+     Forward pass, saving for regular backward.
+     """
+     print("Running _dot_product_attention_fwd")
+     out, res = attn_impl.dot_product_attention_fwd_rule(query, key, value, mask_type=mask_type, scale=scale)
+     return out, res
+
+
+ def _dot_product_attention_fwd_fwd(query, key, value, mask_type: MaskType, scale: float):
+     """
+     Run the forward pass, saving in expectation of a regular backward pass and a higher order backward pass.
+     """
+     print("Running _dot_product_attention_fwd_fwd")
+     out, res = attn_impl.dot_product_attention_fwd_rule(query, key, value, mask_type=mask_type, scale=scale)
+     return (out, res), res
+
+
+ def _dot_product_attention_fwd_bwd(mask_type: MaskType, scale: float, res, g):
+     """
+     Backward through the saving forward pass.
+     """
+     print("Running _dot_product_attention_fwd_bwd")
+     # breakpoint()
+     dO, dvjp_fun = g
+     dQ2, dK2, dV2 = dvjp_fun.args_res
+     # *_, stats, out = dvjp_fun.opaque_residuals  # TODO: Do I need dO from here?
+
+     dQ, dK, dV = attn_impl.dot_product_attention_bwd_rule(mask_type=mask_type, scale=scale, res=res, g=dO)
+     return dQ + dQ2, dK + dK2, dV + dV2
+
+
+ _dot_product_attention_fwd.defvjp(_dot_product_attention_fwd_fwd, _dot_product_attention_fwd_bwd)
+
+
+ @partial(jax.custom_vjp, nondiff_argnames=["mask_type", "scale"])
+ def _dot_product_attention_bwd(mask_type: MaskType, scale: float, res, g):
+     """
+     Regular backward pass.
+     """
+     print("Running _dot_product_attention_bwd")
+     grads = attn_impl.dot_product_attention_bwd_rule(mask_type=mask_type, scale=scale, res=res, g=g)
+     return grads
+
+
+ def _dot_product_attention_bwd_fwd(mask_type: MaskType, scale: float, res, g):
+     """
+     Backward pass, saving for higher order backward.
+     """
+     print("Running _dot_product_attention_bwd_fwd")
+     out, res = attn_impl.dot_product_attention_bwd_rule_fwd_rule(mask_type=mask_type, scale=scale, res=res, g=g)
+     return out, res
+
+
+ def _dot_product_attention_bwd_bwd(mask_type: MaskType, scale: float, res, g):
+     """
+     Backward pass through the backward pass.
+     """
+     print("Running _dot_product_attention_bwd_bwd")
+     grads = attn_impl.dot_product_attention_bwd_rule_bwd_rule(mask_type=mask_type, scale=scale, res=res, g=g)
+     return grads
+
+
+ _dot_product_attention_bwd.defvjp(_dot_product_attention_bwd_fwd, _dot_product_attention_bwd_bwd)
+
+ dot_product_attention.defvjp(_dot_product_attention_fwd, _dot_product_attention_bwd)
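
Putting the pieces together, here is a hedged usage sketch (not from the package) that differentiates twice through `dot_product_attention`, so the second derivative dispatches to `_dot_product_attention_bwd_bwd` and hence the fused second-order kernel. The module path `fhog.jax.attention` is assumed from the `_attention_impl` docstring, shapes follow the `[T, N, H]` annotations above, and running it needs a CUDA GPU with cuDNN flash-attention support and half-precision inputs.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

from fhog.jax.attention import dot_product_attention  # path assumed from the docstring

T, S, N, H = 128, 128, 8, 64
kq, kk, kv = jrandom.split(jrandom.PRNGKey(0), 3)
q = jrandom.normal(kq, (T, N, H), dtype=jnp.float16)
k = jrandom.normal(kk, (S, N, H), dtype=jnp.float16)
v = jrandom.normal(kv, (S, N, H), dtype=jnp.float16)


def loss(q, k, v):
    # Scalar function of the attention output (first-order use).
    return dot_product_attention(q, k, v).astype(jnp.float32).sum()


def grad_penalty(q, k, v):
    # Scalar function of the first-order gradient; differentiating this
    # goes backward through the backward pass.
    dq = jax.grad(loss, argnums=0)(q, k, v)
    return (dq.astype(jnp.float32) ** 2).sum()


dq = jax.grad(loss, argnums=0)(q, k, v)           # regular backward
ddq = jax.grad(grad_penalty, argnums=0)(q, k, v)  # backward-of-backward
print(dq.shape, ddq.shape)
```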
@@ -0,0 +1,89 @@
+ print("Loading libraries...")
+
+ import numpy as np
+
+ import jax.numpy as jnp
+ import jax.random as jrandom
+ import torch
+
+ import fhog.jax_refs.jax_impl as jax_impl
+ import fhog.jax_refs.torch_impl as torch_impl
+
+ print("Imported libraries.")
+
+
+ def compare_jax_and_torch(
+     *, stats, q, k, v, o, do, ddq, ddk, ddv, scale=1.0, is_causal=False
+ ):
+     jax_out = jax_impl.attn_bwd_bwd_stats(
+         stats=stats,
+         q=q,
+         k=k,
+         v=v,
+         o=o,
+         do=do,
+         ddq=ddq,
+         ddk=ddk,
+         ddv=ddv,
+         scale=scale,
+         is_causal=is_causal,
+     )
+     torch_out = torch_impl.attn_bwd_bwd(
+         stats=torch.tensor(np.asarray(stats), dtype=torch.float32),
+         q=torch.tensor(np.asarray(q), dtype=torch.float32),
+         k=torch.tensor(np.asarray(k), dtype=torch.float32),
+         v=torch.tensor(np.asarray(v), dtype=torch.float32),
+         o=torch.tensor(np.asarray(o), dtype=torch.float32),
+         do=torch.tensor(np.asarray(do), dtype=torch.float32),
+         ddq=torch.tensor(np.asarray(ddq), dtype=torch.float32),
+         ddk=torch.tensor(np.asarray(ddk), dtype=torch.float32),
+         ddv=torch.tensor(np.asarray(ddv), dtype=torch.float32),
+         scale=torch.tensor(np.asarray(scale), dtype=torch.float32),
+         is_causal=is_causal,
+     )
+     return jax_out, torch_out
+
+
+ if __name__ == "__main__":
+     key1, key2, key3, key4, key5, key6, key7, key8, key9 = jrandom.split(
+         jrandom.PRNGKey(42), 9
+     )
+     nq = 128
+     nkv = 256
+     d_in = 64
+     d_out = 64
+     dtype = jnp.float32
+     scale = jnp.array(1.0 / jnp.sqrt(d_in), dtype=jnp.float32)
+     q = jrandom.normal(key1, (nq, d_in), dtype=dtype)
+     k = jrandom.normal(key2, (nkv, d_in), dtype=dtype)
+     v = jrandom.normal(key3, (nkv, d_out), dtype=dtype)
+     o = jrandom.normal(key8, (nq, d_out), dtype=dtype)
+     do = jrandom.normal(key4, (nq, d_out), dtype=dtype)
+     ddq = jrandom.normal(key5, (nq, d_in), dtype=dtype)
+     ddk = jrandom.normal(key6, (nkv, d_in), dtype=dtype)
+     ddv = jrandom.normal(key7, (nkv, d_out), dtype=dtype)
+     stats = jrandom.normal(key9, (nq,), dtype=dtype)
+     is_causal = False
+
+     jax_out, torch_out = compare_jax_and_torch(
+         stats=stats,
+         q=q,
+         k=k,
+         v=v,
+         o=o,
+         do=do,
+         ddq=ddq,
+         ddk=ddk,
+         ddv=ddv,
+         scale=scale,
+         is_causal=is_causal,
+     )
+     print(jax_out)
+     print(torch_out)
+
+     for tensor1, tensor2 in zip(jax_out, torch_out):
+         if jnp.allclose(tensor1, jnp.asarray(tensor2.numpy())):
+             print("success")
+         else:
+             print(torch.std(torch.tensor(np.asarray(tensor1)) - tensor2))
+             print("failure")