flash-hog 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash_hog-0.1.0/PKG-INFO +42 -0
- flash_hog-0.1.0/README.md +25 -0
- flash_hog-0.1.0/fhog/__init__.py +3 -0
- flash_hog-0.1.0/fhog/benchhhh.py +110 -0
- flash_hog-0.1.0/fhog/jax/_attention_impl.py +71 -0
- flash_hog-0.1.0/fhog/jax/attention.py +107 -0
- flash_hog-0.1.0/fhog/jax_refs/compare.py +89 -0
- flash_hog-0.1.0/fhog/jax_refs/flash_impl.py +442 -0
- flash_hog-0.1.0/fhog/jax_refs/flash_impl_herman.py +591 -0
- flash_hog-0.1.0/fhog/jax_refs/jax_impl.py +727 -0
- flash_hog-0.1.0/fhog/jax_refs/torch_impl.py +250 -0
- flash_hog-0.1.0/fhog/mgpu_tests.py +1240 -0
- flash_hog-0.1.0/fhog/pallas_attention.py +665 -0
- flash_hog-0.1.0/fhog/pallas_bwdbwd.py +343 -0
- flash_hog-0.1.0/fhog/test_bwdbwd.py +106 -0
- flash_hog-0.1.0/fhog/triton_bwdbwd.py +654 -0
- flash_hog-0.1.0/fhog/triton_flash.py +257 -0
- flash_hog-0.1.0/pyproject.toml +65 -0
flash_hog-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: flash-hog
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Memory-efficient higher-order (backward-of-backward) gradients for flash attention
|
|
5
|
+
Requires-Dist: absl-py>=2.4.0
|
|
6
|
+
Requires-Dist: chex>=0.1.91
|
|
7
|
+
Requires-Dist: einops>=0.8.1
|
|
8
|
+
Requires-Dist: equinox>=0.13.2
|
|
9
|
+
Requires-Dist: jax[cuda13]>=0.8.0
|
|
10
|
+
Requires-Dist: nvidia-cutlass-dsl>=4.2.1
|
|
11
|
+
Requires-Dist: pytest>=8.4.2
|
|
12
|
+
Requires-Dist: ruff>=0.14.2
|
|
13
|
+
Requires-Dist: torch>=2.9.0
|
|
14
|
+
Requires-Dist: ty>=0.0.1a24
|
|
15
|
+
Requires-Python: >=3.12, <3.14
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
|
|
18
|
+
# Flash Hog
|
|
19
|
+
<p align="center">
|
|
20
|
+
<img src="assets/logo.png" alt="Flash Hog Logo" width="256" />
|
|
21
|
+
</p>
|
|
22
|
+
|
|
23
|
+
This repo contains the code for Flash Higher-Order Gradients, a.k.a. Flash Hog.
|
|
24
|
+
This kernel achieves roughly a 3.7x speedup over an XLA-optimized implementation, and its memory use scales linearly rather than quadratically.
|
|
25
|
+
|
|
26
|
+
<p align="center">
|
|
27
|
+
<img src="assets/speedup.png" alt="Hog Speedup" width="512"/>
|
|
28
|
+
</p>
|
|
29
|
+
|
|
30
|
+
## Installation
|
|
31
|
+
TODO
|
|
32
|
+
|
|
33
|
+
## Method
|
|
34
|
+
Flash Hog performs 4 recomputation passes so that it needs no atomics and never stores intermediate tensors of shape `(N_Q, N_K)`.
|
|
35
|
+
Concretely, the first 3 passes are tiled across Q: one computes `dd`, one computes `b`, and one computes both `dQ'` and `ddO`.
|
|
36
|
+
Finally, a fourth pass tiled across K produces `dK'` and `dV'`.
|
|
37
|
+
The equations we implement are the following:
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
<p align="center">
|
|
41
|
+
<img src="assets/handwritten_equations.png" alt="Equations" width="512"/>
|
|
42
|
+
</p>
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Flash Hog
|
|
2
|
+
<p align="center">
|
|
3
|
+
<img src="assets/logo.png" alt="Flash Hog Logo" width="256" />
|
|
4
|
+
</p>
|
|
5
|
+
|
|
6
|
+
This repo contains the code for Flash Higher-Order Gradients, a.k.a. Flash Hog.
|
|
7
|
+
This kernel achieves roughly a 3.7x speedup over an XLA-optimized implementation, and its memory use scales linearly rather than quadratically.
|
|
8
|
+
|
|
9
|
+
<p align="center">
|
|
10
|
+
<img src="assets/speedup.png" alt="Hog Speedup" width="512"/>
|
|
11
|
+
</p>
|
|
12
|
+
|
|
13
|
+
## Installation
|
|
14
|
+
TODO
|
|
15
|
+
|
|
16
|
+
## Method
|
|
17
|
+
Flash Hog performs 4 recomputation passes so that it needs no atomics and never stores intermediate tensors of shape `(N_Q, N_K)`.
|
|
18
|
+
Concretely, the first 3 passes are tiled across Q: one computes `dd`, one computes `b`, and one computes both `dQ'` and `ddO`.
|
|
19
|
+
Finally, a fourth pass tiled across K produces `dK'` and `dV'`.
|
|
20
|
+
The equations we implement are the following:
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
<p align="center">
|
|
24
|
+
<img src="assets/handwritten_equations.png" alt="Equations" width="512"/>
|
|
25
|
+
</p>
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import jax.random as jrandom
|
|
5
|
+
import torch
|
|
6
|
+
import triton
|
|
7
|
+
|
|
8
|
+
from fhog.jax_refs.jax_impl import attn_bwd_bwd
|
|
9
|
+
from fhog.triton_bwdbwd import flash_bwdbwd
|
|
10
|
+
|
|
11
|
+
# Passed to triton.testing.do_bench as its `rep` argument.
# NOTE(review): Triton documents `rep` as a target benchmarking *time* in ms,
# not a repetition count -- confirm against the installed Triton version.
N_REPS = 5000
|
|
12
|
+
# Passed to triton.testing.do_bench as its `warmup` argument.
# NOTE(review): Triton documents `warmup` as a duration in ms, not a count --
# confirm against the installed Triton version.
N_WARMUP = 300
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def full_benchmark(nq, nkv):
    """Time the JAX reference and the Triton kernel for the attention double-backward.

    Random fp16 CUDA inputs are built for a single problem with ``nq`` query
    rows, ``nkv`` key/value rows and head dimension 64. Both implementations
    are timed with ``triton.testing.do_bench``. The timings are printed
    together with the JAX/Triton speedup ratio.

    If the JAX benchmark raises (e.g. running out of memory at large sizes),
    the error is printed and the JAX time is reported as ``inf``.

    Args:
        nq: Number of query rows.
        nkv: Number of key/value rows.

    Returns:
        ``(jax_time, triton_time)`` exactly as returned by ``do_bench``.
    """
    import jax

    d_in = 64
    d_out = 64

    q = torch.randn((nq, d_in), device="cuda", dtype=torch.float16)
    k = torch.randn((nkv, d_in), device="cuda", dtype=torch.float16)
    v = torch.randn((nkv, d_out), device="cuda", dtype=torch.float16)
    o = torch.randn((nq, d_out), device="cuda", dtype=torch.float16)
    # Per-query softmax statistics for the Triton kernel. They are random here
    # because only timing is measured.
    stats = torch.randn((nq,), device="cuda", dtype=torch.float16)
    do = torch.randn((nq, d_out), device="cuda", dtype=torch.float16)
    ddq = torch.randn((nq, d_in), device="cuda", dtype=torch.float16)
    ddk = torch.randn((nkv, d_in), device="cuda", dtype=torch.float16)
    ddv = torch.randn((nkv, d_out), device="cuda", dtype=torch.float16)

    def jax_benchmark():
        scale = jnp.array(1.0 / jnp.sqrt(d_in), dtype=jnp.float32)
        # Convert the torch inputs once, outside the timed closure, so the
        # measurement covers the computation rather than torch->JAX transfers.
        jax_inputs = dict(
            q=jnp.asarray(q),
            k=jnp.asarray(k),
            v=jnp.asarray(v),
            do=jnp.asarray(do),
            ddq=jnp.asarray(ddq),
            ddk=jnp.asarray(ddk),
            ddv=jnp.asarray(ddv),
        )

        def do_run_jax():
            result = attn_bwd_bwd(
                **jax_inputs,
                scale=scale,
                is_causal=False,
                sliding_window_length=jnp.inf,
            )
            # JAX dispatches asynchronously. Blocking here makes do_bench (which
            # times with torch CUDA events) include the actual JAX work.
            jax.block_until_ready(result)

        return triton.testing.do_bench(do_run_jax, rep=N_REPS, warmup=N_WARMUP)

    def triton_benchmark():
        # Add the leading batch axis once, outside the timed closure.
        batched = [t.unsqueeze(0) for t in (q, k, v, o, do, ddq, ddk, ddv, stats)]
        softmax_scale = 1 / d_in**0.5

        def do_run_triton():
            flash_bwdbwd(*batched, softmax_scale)

        return triton.testing.do_bench(do_run_triton, rep=N_REPS, warmup=N_WARMUP)

    try:
        jax_time = jax_benchmark()
    except Exception as e:
        # Best-effort: report the failure and continue with the Triton timing.
        print(f"JAX benchmark failed for nq={nq}, nkv={nkv} with error: {e}")
        jax_time = float("inf")
    triton_time = triton_benchmark()
    print(f"{nq}: {jax_time=} {triton_time=} => speedup: {jax_time / triton_time:.2f}x")
    return jax_time, triton_time
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
if __name__ == "__main__":
    # Square problems (nq == nkv) from 2**13 = 8192 up to 2**21 = 2097152 rows.
    # Smaller sizes (128 .. 4096) are intentionally not benchmarked.
    for exponent in range(13, 22):
        n_rows = 2**exponent
        full_benchmark(n_rows, n_rows)
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implementations for the functions used in fhog.jax.attention.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from functools import partial
|
|
6
|
+
|
|
7
|
+
import jax
|
|
8
|
+
from jax._src.cudnn.fused_attention_stablehlo import MaskType
|
|
9
|
+
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention as cuda_dot_product_attention
|
|
10
|
+
from jax.tree_util import Partial
|
|
11
|
+
|
|
12
|
+
from fhog.pallas_bwdbwd import TuningConfig, flash_bwdbwd0
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def dot_product_attention_fwd(query, key, value, mask_type: MaskType, scale: float | None):
    """
    Forward pass without saving anything; only the attention output is returned.

    Args:
        query, key, value: Attention inputs, forwarded unchanged to the cuDNN
            fused ``dot_product_attention``.
        mask_type: Mask type forwarded to the cuDNN kernel.
        scale: Softmax scale. ``None`` is resolved to ``1/sqrt(head_dim)``, using
            the last axis of ``query``.

    Returns:
        The output of the cuDNN fused attention.
    """
    if scale is None:
        # The cuDNN entry point takes a float scale, so resolve the default here.
        scale = query.shape[-1] ** -0.5
    return cuda_dot_product_attention(query, key, value, mask_type=mask_type, scale=scale)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def dot_product_attention_fwd_rule(query, key, value, mask_type: MaskType, scale: float | None):
    """
    Forward pass that keeps what the backward passes need.

    The cuDNN attention is differentiated with ``jax.vjp``. The returned VJP
    closure is used directly as the residual. The backward rules in this
    module read the saved Q, K, V from its ``args_res`` and the softmax stats
    and output from its ``opaque_residuals``.

    ``scale`` may arrive as ``None`` from ``jax.custom_vjp`` rules, because the
    primal body that normally resolves the default is not executed on the
    differentiated path. ``None`` is resolved to ``1/sqrt(head_dim)``.

    Returns:
        ``(out, vjp_fun)``.
    """
    if scale is None:
        scale = query.shape[-1] ** -0.5
    attention = partial(cuda_dot_product_attention, mask_type=mask_type, scale=scale)
    out, vjp_fun = jax.vjp(attention, query, key, value)
    return out, vjp_fun
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def dot_product_attention_bwd_rule(mask_type: MaskType, scale: float, res, g):
    """
    First-order backward pass; nothing is saved.

    ``res`` is a VJP closure (as produced by ``dot_product_attention_fwd_rule``).
    Applying it to the output cotangent ``g`` yields the three input gradients.
    ``mask_type`` and ``scale`` only keep the rule signature uniform and are
    not read here.

    Returns:
        ``(dQ, dK, dV)``.
    """
    grad_query, grad_key, grad_value = res(g)
    return grad_query, grad_key, grad_value
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def dot_product_attention_bwd_rule_fwd_rule(mask_type: MaskType, scale: float, res, g):
    """
    First-order backward pass that also saves state for a second-order backward.

    ``res`` is the VJP closure of the forward pass and ``g`` is the output
    cotangent ``dO``. ``mask_type`` and ``scale`` are not read here.

    Returns:
        ``((dQ, dK, dV), (vjp_fun, dO))``. The second element is the residual
        consumed by ``dot_product_attention_bwd_rule_bwd_rule``.
    """
    grad_query, grad_key, grad_value = res(g)
    saved_for_bwd_bwd = (res, g)
    return (grad_query, grad_key, grad_value), saved_for_bwd_bwd
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def dot_product_attention_bwd_rule_bwd_rule(mask_type: MaskType, scale: float, res, g):
|
|
58
|
+
"""
|
|
59
|
+
Backward pass through the backward pass
|
|
60
|
+
"""
|
|
61
|
+
vjp_fun, dO = res
|
|
62
|
+
query, key, value = vjp_fun.args_res
|
|
63
|
+
*_, stats, out = vjp_fun.opaque_residuals
|
|
64
|
+
vjp_fun_structure = jax.tree.structure(vjp_fun)
|
|
65
|
+
# breakpoint()
|
|
66
|
+
|
|
67
|
+
ddQ, ddK, ddV = g
|
|
68
|
+
|
|
69
|
+
dQ2, dK2, dV2, ddO = flash_bwdbwd0(query, key, value, out, dO, ddQ, ddK, ddV, stats, scale, config=TuningConfig(tile_q=128, tile_k=32, max_concurrent_steps=4))
|
|
70
|
+
vjp_fun_grad = jax.tree.unflatten(vjp_fun_structure, [dQ2, dK2, dV2, None, None, None]) # TODO: Don't I need new dO in the last argument here?
|
|
71
|
+
return vjp_fun_grad, ddO
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Main interface for Jax attention with support for higher order memory-efficient backward on GPU.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from functools import partial
|
|
6
|
+
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
from jax import Array
|
|
10
|
+
from jax._src.cudnn.fused_attention_stablehlo import MaskType
|
|
11
|
+
from jaxtyping import Bool, Float, Int
|
|
12
|
+
|
|
13
|
+
import fhog.jax._attention_impl as attn_impl
|
|
14
|
+
|
|
15
|
+
# Reimplementation of much of jax._src.cudnn.fused_attention_stablehlo as used in jax.nn.dot_product_attention
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@partial(jax.custom_vjp, nondiff_argnames=["mask_type", "scale"])
def dot_product_attention(
    query: Float[Array, "T N H"],
    key: Float[Array, "S N H"],
    value: Float[Array, "S N H"],
    mask_type: MaskType = MaskType.NO_MASK,
    scale: float | None = None,
):
    """
    Fused dot-product attention with a custom (higher-order capable) VJP.

    Dimensions:
        T: Query length
        S: Key/value length
        N: Number of attention heads
        H: Head dimension

    Args:
        query, key, value: Attention inputs; all three must share a dtype.
        mask_type: cuDNN mask type.
        scale: Softmax scale; defaults to ``1/sqrt(H)``.

    Returns:
        The attention output.

    Raises:
        ValueError: if ``query``, ``key`` and ``value`` have different dtypes.
    """
    if scale is None:
        scale = query.shape[-1] ** -0.5
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (query.dtype == key.dtype == value.dtype):
        raise ValueError(
            f"query, key and value must share a dtype, got {query.dtype}, {key.dtype} and {value.dtype}."
        )
    return attn_impl.dot_product_attention_fwd(query, key, value, mask_type=mask_type, scale=scale)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@partial(jax.custom_vjp, nondiff_argnames=["mask_type", "scale"])
def _dot_product_attention_fwd(query, key, value, mask_type: MaskType, scale: float | None):
    """
    Forward pass, saving for regular backward.

    This is the forward rule of ``dot_product_attention``. It returns
    ``(out, res)``, where ``res`` is the residual for the backward rule.

    ``scale`` can arrive here as ``None``. Rules receive the caller's
    arguments, so the ``None`` handling in ``dot_product_attention``'s body
    does not run on this path. ``None`` is therefore resolved to
    ``1/sqrt(head_dim)`` here.
    """
    print("Running _dot_product_attention_fwd")
    if scale is None:
        scale = query.shape[-1] ** -0.5
    out, res = attn_impl.dot_product_attention_fwd_rule(query, key, value, mask_type=mask_type, scale=scale)
    return out, res
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _dot_product_attention_fwd_fwd(query, key, value, mask_type: MaskType, scale: float | None):
    """
    Run the forward pass, saving in expectation of a regular backward pass and a higher order backward pass.

    This is the forward rule of ``_dot_product_attention_fwd``. Its primal
    output ``(out, res)`` is returned together with ``res`` as this rule's own
    residual. ``scale=None`` is resolved to ``1/sqrt(head_dim)`` for the same
    reason as in ``_dot_product_attention_fwd``.
    """
    print("Running _dot_product_attention_fwd_fwd")
    if scale is None:
        scale = query.shape[-1] ** -0.5
    out, res = attn_impl.dot_product_attention_fwd_rule(query, key, value, mask_type=mask_type, scale=scale)
    return (out, res), res
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _dot_product_attention_fwd_bwd(mask_type: MaskType, scale: float | None, res, g):
    """
    Backward through the saving forward pass.

    This is the backward rule of ``_dot_product_attention_fwd``.

    Args:
        mask_type, scale: Non-differentiable arguments of the forward pass.
        res: Residual saved by ``_dot_product_attention_fwd_fwd``.
        g: Cotangent of that rule's output ``(out, res)``, unpacked as
            ``(dO, dvjp_fun)``. ``dvjp_fun`` is the cotangent pytree of the
            saved VJP closure.

    Returns:
        ``(dQ, dK, dV)``: the first-order gradients for ``dO``, plus the
        cotangents that ``dvjp_fun`` carries for the saved Q, K, V.
    """
    print("Running _dot_product_attention_fwd_bwd")
    dO, dvjp_fun = g
    # Cotangents for the Q, K, V stored in the VJP closure's args_res.
    dQ2, dK2, dV2 = dvjp_fun.args_res
    # *_, stats, out = dvjp_fun.opaque_residuals # TODO: Do I need dO from here?

    dQ, dK, dV = attn_impl.dot_product_attention_bwd_rule(mask_type=mask_type, scale=scale, res=res, g=dO)
    return dQ + dQ2, dK + dK2, dV + dV2
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Give the saving forward rule its own VJP, so the forward pass can be
# differentiated again for the higher-order backward.
_dot_product_attention_fwd.defvjp(_dot_product_attention_fwd_fwd, _dot_product_attention_fwd_bwd)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@partial(jax.custom_vjp, nondiff_argnames=["mask_type", "scale"])
def _dot_product_attention_bwd(mask_type: MaskType, scale: float, res, g):
    """
    Regular (first-order) backward rule of ``dot_product_attention``.

    Delegates to ``attn_impl.dot_product_attention_bwd_rule``. It is itself a
    ``custom_vjp`` so that differentiating the backward pass uses the
    dedicated second-order rules.
    """
    print("Running _dot_product_attention_bwd")
    return attn_impl.dot_product_attention_bwd_rule(mask_type=mask_type, scale=scale, res=res, g=g)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _dot_product_attention_bwd_fwd(mask_type: MaskType, scale: float, res, g):
    """
    Forward rule of ``_dot_product_attention_bwd``.

    Computes the first-order gradients and keeps the residual needed to
    differentiate them again.
    """
    print("Running _dot_product_attention_bwd_fwd")
    first_order_grads, saved = attn_impl.dot_product_attention_bwd_rule_fwd_rule(
        mask_type=mask_type, scale=scale, res=res, g=g
    )
    return first_order_grads, saved
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _dot_product_attention_bwd_bwd(mask_type: MaskType, scale: float, res, g):
    """
    Backward rule of ``_dot_product_attention_bwd`` (the backward of the backward).

    Thin wrapper around ``attn_impl.dot_product_attention_bwd_rule_bwd_rule``.
    """
    print("Running _dot_product_attention_bwd_bwd")
    return attn_impl.dot_product_attention_bwd_rule_bwd_rule(mask_type=mask_type, scale=scale, res=res, g=g)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# Make the first-order backward rule differentiable. Its forward rule saves
# residuals and its backward rule computes the backward-of-backward.
_dot_product_attention_bwd.defvjp(_dot_product_attention_bwd_fwd, _dot_product_attention_bwd_bwd)
|
|
106
|
+
|
|
107
|
+
# Wire up the public function. Both its forward and backward rules are
# themselves custom_vjp functions, so its gradient can be differentiated again.
dot_product_attention.defvjp(_dot_product_attention_fwd, _dot_product_attention_bwd)
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
import fhog.jax_refs.jax_impl as jax_impl
|
|
4
|
+
import fhog.jax_refs.torch_impl as torch_impl
|
|
5
|
+
|
|
6
|
+
print("Loading libraries...")
|
|
7
|
+
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import jax.random as jrandom
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
print("Imported libraries.")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def compare_jax_and_torch(
    *, stats, q, k, v, o, do, ddq, ddk, ddv, scale=1.0, is_causal=False
):
    """Run the JAX and PyTorch double-backward references on identical inputs.

    The JAX reference receives the arrays as given. The PyTorch reference
    receives float32 CPU tensors built from them via NumPy (including
    ``scale``, which is also converted to a tensor).

    Returns:
        ``(jax_out, torch_out)``: the raw results of
        ``jax_impl.attn_bwd_bwd_stats`` and ``torch_impl.attn_bwd_bwd``.
    """
    arrays = dict(stats=stats, q=q, k=k, v=v, o=o, do=do, ddq=ddq, ddk=ddk, ddv=ddv)

    jax_result = jax_impl.attn_bwd_bwd_stats(**arrays, scale=scale, is_causal=is_causal)

    def as_torch(value):
        return torch.tensor(np.asarray(value), dtype=torch.float32)

    torch_arrays = {name: as_torch(value) for name, value in arrays.items()}
    torch_result = torch_impl.attn_bwd_bwd(
        **torch_arrays, scale=as_torch(scale), is_causal=is_causal
    )
    return jax_result, torch_result
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
if __name__ == "__main__":
    # One independent PRNG key per random input below.
    key1, key2, key3, key4, key5, key6, key7, key8, key9 = jrandom.split(
        jrandom.PRNGKey(42), 9
    )
    nq = 128
    nkv = 256
    d_in = 64
    d_out = 64
    dtype = jnp.float32
    scale = jnp.array(1.0 / jnp.sqrt(d_in), dtype=jnp.float32)
    q = jrandom.normal(key1, (nq, d_in), dtype=dtype)
    k = jrandom.normal(key2, (nkv, d_in), dtype=dtype)
    v = jrandom.normal(key3, (nkv, d_out), dtype=dtype)
    o = jrandom.normal(key8, (nq, d_out), dtype=dtype)
    do = jrandom.normal(key4, (nq, d_out), dtype=dtype)
    ddq = jrandom.normal(key5, (nq, d_in), dtype=dtype)
    ddk = jrandom.normal(key6, (nkv, d_in), dtype=dtype)
    ddv = jrandom.normal(key7, (nkv, d_out), dtype=dtype)
    stats = jrandom.normal(key9, (nq,), dtype=dtype)
    is_causal = False

    jax_out, torch_out = compare_jax_and_torch(
        stats=stats,
        q=q,
        k=k,
        v=v,
        o=o,
        do=do,
        ddq=ddq,
        ddk=ddk,
        ddv=ddv,
        scale=scale,
        is_causal=is_causal,
    )
    print(jax_out)
    print(torch_out)

    # strict=True: if the two references return a different number of tensors,
    # fail loudly instead of silently comparing only a prefix.
    for tensor1, tensor2 in zip(jax_out, torch_out, strict=True):
        # .numpy() refuses tensors that track gradients; detach() is a no-op otherwise.
        tensor2 = tensor2.detach()
        if jnp.allclose(tensor1, jnp.asarray(tensor2.numpy())):
            print("success")
        else:
            print(torch.std(torch.tensor(np.asarray(tensor1)) - tensor2))
            print("failure")
|