cuequivariance-ops-cu12 0.4.0__py3-none-manylinux_2_39_aarch64.whl → 0.5.1__py3-none-manylinux_2_39_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cuequivariance-ops-cu12 might be problematic.

Files changed (28)
  1. cuequivariance_ops/VERSION +1 -1
  2. cuequivariance_ops/__init__.py +3 -2
  3. cuequivariance_ops/equivariance/dtypes.hh +21 -0
  4. cuequivariance_ops/equivariance/indexed_linear.hh +36 -0
  5. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  6. cuequivariance_ops/equivariance/run_fmha_cudafree.h +77 -0
  7. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +17 -35
  8. cuequivariance_ops/lib/libcue_ops.so +0 -0
  9. cuequivariance_ops/triton/__init__.py +29 -0
  10. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37192 -0
  11. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  12. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  13. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  14. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  15. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  16. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  17. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  18. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  19. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  20. cuequivariance_ops/triton/cache_manager.py +244 -0
  21. cuequivariance_ops/triton/fused_layer_norm_triton.py +324 -0
  22. cuequivariance_ops/triton/gated_gemm_triton.py +340 -0
  23. cuequivariance_ops/triton/tuning_decorator.py +272 -0
  24. {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/METADATA +5 -1
  25. cuequivariance_ops_cu12-0.5.1.dist-info/RECORD +32 -0
  26. {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/WHEEL +1 -1
  27. cuequivariance_ops_cu12-0.4.0.dist-info/RECORD +0 -13
  28. {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/licenses/LICENSE +0 -0
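
The two largest code additions, cuequivariance_ops/triton/cache_manager.py and cuequivariance_ops/triton/fused_layer_norm_triton.py, are reproduced in full below. The cache manager is controlled by three environment variables that appear in its source: CUEQ_TRITON_TUNING ("AOT" or "ONDEMAND"), CUEQ_TRITON_CACHE_DIR, and CUEQ_TRITON_IGNORE_EXISTING_CACHE. A minimal sketch of how a user might set them before import (the module path follows from the file list above; the values are illustrative, and an NVIDIA driver must be present because the module queries the GPU through pynvml at import time):

    import os

    # Illustrative values; cache_manager.py reads these when it is imported.
    os.environ["CUEQ_TRITON_TUNING"] = "ONDEMAND"            # or "AOT"; any other value logs an error
    os.environ["CUEQ_TRITON_CACHE_DIR"] = "/tmp/cueq-cache"  # overrides the platformdirs default

    from cuequivariance_ops.triton.cache_manager import get_cache_manager

    mgr = get_cache_manager()
    print(mgr.json_path)  # directory that will hold the per-kernel JSON tuning files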
cuequivariance_ops/triton/cache_manager.py
@@ -0,0 +1,244 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import hashlib
+ import json
+ import logging
+ import math
+ import os
+ from multiprocessing import Lock
+ from pathlib import Path
+ from typing import Any
+
+ import pynvml
+ from platformdirs import user_cache_dir
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ FILE_LOCK = Lock()
+
+
+ def get_triton_tuning_mode():
+     cueq_at = os.getenv("CUEQ_TRITON_TUNING")
+     if cueq_at is not None and cueq_at not in ["AOT", "ONDEMAND"]:
+         logger.error(f"CUEQ_TRITON_TUNING setting not recognized: {cueq_at}.\n")
+     return cueq_at
+
+
+ def is_docker():
+     cgroup = Path("/proc/self/cgroup")
+     return Path("/.dockerenv").is_file() or (
+         cgroup.is_file() and "docker" in cgroup.read_text()
+     )
+
+
+ def overridden_cache_dir():
+     return os.getenv("CUEQ_TRITON_CACHE_DIR")
+
+
+ def get_triton_cache_dir() -> Path:
+     cache_dir = overridden_cache_dir()
+     if cache_dir is None:
+         cache_dir = user_cache_dir(appname="cuequivariance-triton", ensure_exists=False)
+     cache_dir = Path(cache_dir)
+     if cache_dir.exists():
+         return cache_dir
+     cache_dir.mkdir(parents=True, exist_ok=True)
+     return cache_dir
+
+
+ def get_gpu_information():
+     pynvml.nvmlInit()
+     # Note: non-uniform multi-GPU setups are not supported
+     handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+     name = pynvml.nvmlDeviceGetName(handle)
+     # pci_info = pynvml.nvmlDeviceGetPciInfo(handle)
+     # device_id = pci_info.pciDeviceId
+     # sub_device_id = pci_info.pciSubSystemId
+     power_limit = pynvml.nvmlDeviceGetPowerManagementLimit(handle)
+     max_clock_rate = pynvml.nvmlDeviceGetMaxClockInfo(
+         handle, pynvml.NVML_CLOCK_GRAPHICS
+     )
+     mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+     gpu_core_count = pynvml.nvmlDeviceGetNumGpuCores(handle)
+     major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+
+     pynvml.nvmlShutdown()
+     return {
+         "name": name,
+         # "device_id": device_id,
+         # "sub_device_id": sub_device_id,
+         "total_memory": math.ceil(mem_info.total / (1024**3)),
+         "multi_processor_count": gpu_core_count // 128,
+         "power_limit": power_limit // 1000,
+         "clock_rate": max_clock_rate,
+         "major": major,
+         "minor": minor,
+     }
+
+
+ def gpu_information_to_key(information: dict) -> str:
+     information.pop("name", None)
+     key_string = "_".join(f"{value}" for value in information.values()).replace(
+         " ", "_"
+     )
+     hash_object = hashlib.sha256(key_string.encode())
+     hash_str = hash_object.hexdigest()
+     return hash_str
+
+
+ class CacheManager:
+     """Singleton managing the cache"""
+
+     def __init__(self):
+         self.gpu_cache = {}
+         self.gpu_information = get_gpu_information()
+         self.gpu_key = gpu_information_to_key(self.gpu_information)
+         self.site_json_path = os.path.join(os.path.dirname(__file__), "cache")
+         self.json_path = get_triton_cache_dir()
+         self.aot_mode = get_triton_tuning_mode()
+         self.dirty = {}
+
+         if os.getenv("CUEQ_TRITON_IGNORE_EXISTING_CACHE") == "1":
+             logger.warning(
+                 "\n!!!!!! CUEQ_TRITON_IGNORE_EXISTING_CACHE is ON - previously saved settings will be ignored !!!!!!\n"
+                 f"CUEQ_TRITON_TUNING is set to {self.aot_mode}\n"
+                 f"The tuning changes will be written to {self.json_path}"
+             )
+
+         if (
+             self.aot_mode is not None
+             and is_docker()
+             and os.getenv("HOME") == "/root"
+             and not overridden_cache_dir()
+         ):
+             logger.warning(
+                 f"\n!!!!!! CUEQ_TRITON_TUNING is set to {self.aot_mode} and you are running as root in a Docker container. !!!!!!\n"
+                 f"The tuning changes will be written to {self.json_path}\n"
+                 "Please remember to commit the container - otherwise any tuning changes will be lost on container restart."
+             )
+
+     def load_cache(self, fn_key: str) -> dict:
+         # Load the JSON file and store it in the cache dict.
+         # If the file does not exist, create an empty dict for the specified function.
+         fn_cache = {}
+         gpu_cache = {}
+         best_key = None
+         major, minor = self.gpu_information["major"], self.gpu_information["minor"]
+         basename = f"{fn_key}.{major}.{minor}.json"
+         json_file = os.path.join(self.json_path, basename)
+
+         def result(self, gpu_cache):
+             # empty cache or fuzzy match: update for a possible save
+             if best_key or not gpu_cache:
+                 gpu_cache["gpu_information"] = self.gpu_information
+             self.gpu_cache[fn_key] = gpu_cache
+             return gpu_cache
+
+         if os.getenv("CUEQ_TRITON_IGNORE_EXISTING_CACHE"):
+             return result(self, gpu_cache)
+         try:
+             with FILE_LOCK, open(json_file, "rb") as f:
+                 fn_cache = json.load(f)
+         except Exception as e0:
+             site_json_file = os.path.join(self.site_json_path, basename)
+             try:
+                 with FILE_LOCK, open(site_json_file, "rb") as f:
+                     fn_cache = json.load(f)
+             except Exception as e:
+                 logger.warning(
+                     f"Error reading system-wide Triton tuning cache file: {site_json_file}\n{e}\n"
+                     f"Error reading user's Triton tuning cache file {json_file}:\n{e0}"
+                 )
+         if fn_cache:
+             gpu_cache = fn_cache.get(self.gpu_key)
+             if gpu_cache is None:
+                 # do a fuzzy match of the configuration:
+                 def within_10_percent(a, b, key):
+                     # NB: the 0.2 threshold on |a - b| / (a + b) is looser than the
+                     # name suggests; it admits up to a 1.5x ratio between a and b.
+                     a = int(a[key])
+                     b = int(b[key])
+                     return abs(a - b) / (a + b) < 0.2
+
+                 def full_match(a, b):
+                     # matching clock & memory
+                     return (
+                         a["total_memory"] == b["total_memory"]
+                         and a["clock_rate"] == b["clock_rate"]
+                     )
+
+                 def partial_match(a, b):
+                     # either clock or memory roughly matches
+                     return within_10_percent(a, b, "total_memory") or within_10_percent(
+                         a, b, "clock_rate"
+                     )
+
+                 for key in fn_cache:
+                     conf = fn_cache[key].get("gpu_information")
+                     if conf:
+                         if full_match(conf, self.gpu_information):
+                             best_key = key
+                             break
+                         elif partial_match(conf, self.gpu_information):
+                             best_key = key
+                 if best_key is None:
+                     # just pick the first entry there
+                     best_key = next(iter(fn_cache))
+                 gpu_cache = fn_cache[best_key]
+
+         return result(self, gpu_cache)
+
+     def save_cache(self, fn_key: str) -> None:
+         # save the cache dict to a JSON file
+         major, minor = self.gpu_information["major"], self.gpu_information["minor"]
+         basename = f"{fn_key}.{major}.{minor}.json"
+         json_file = os.path.join(self.json_path, basename)
+         # Load existing data from the file if it exists
+         if os.path.exists(json_file):
+             with FILE_LOCK, open(json_file, "rb") as f:
+                 existing_data = json.load(f)
+         else:
+             existing_data = {}
+         # Update the entry for our GPU key with our data
+         existing_data.setdefault(self.gpu_key, {}).update(self.gpu_cache[fn_key])
+         self.gpu_cache[fn_key] = existing_data[self.gpu_key]
+         merged_data = existing_data
+         temp_file = f"{json_file}.{os.getpid()}.tmp"
+         try:
+             # Save the merged data back to the file
+             with FILE_LOCK:
+                 with open(temp_file, "w") as f:
+                     json.dump(merged_data, f, indent=4)
+                 os.replace(temp_file, json_file)
+         except Exception as e:
+             logger.warning(f"Warning: Failed to write autotune cache: {e}")
+
+         # Clear the dirty flag
+         del self.dirty[fn_key]
+
+     def get(self, fn_key: str, inp_key: str) -> Any:
+         # Get a value from the cache, loading the JSON file first if necessary.
+         gpu_cache = self.gpu_cache.get(fn_key) or self.load_cache(fn_key)
+         # check if fn_key and inp_key exist in the cache
+         return gpu_cache.get(inp_key)
+
+     def set(self, fn_key: str, inp_key: str, value: Any) -> None:
+         # Write a value to the cache dict and mark it dirty for a later save.
+         self.gpu_cache[fn_key][inp_key] = value
+         self.dirty[fn_key] = 1
+
+
+ cache_manager = CacheManager()
+
+
+ def get_cache_manager():
+     return cache_manager
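
For orientation, here is a short, hypothetical round trip through the CacheManager API defined above. The fn_key/inp_key values are made up for illustration; in the package they come from the tuning machinery in tuning_decorator.py, which is not shown in this hunk:

    from cuequivariance_ops.triton.cache_manager import get_cache_manager

    mgr = get_cache_manager()
    fn_key = "my_kernel_wrapper"  # hypothetical function key
    inp_key = "M128_N256_fp16"    # hypothetical input-shape key

    cfg = mgr.get(fn_key, inp_key)     # first call loads <cache_dir>/my_kernel_wrapper.<major>.<minor>.json
    if cfg is None:
        cfg = {"TILE_N": 64}           # placeholder for a freshly measured best configuration
        mgr.set(fn_key, inp_key, cfg)  # updates the in-memory dict and marks fn_key dirty
        mgr.save_cache(fn_key)         # merges and rewrites the JSON atomically (tmp file + os.replace)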
cuequivariance_ops/triton/fused_layer_norm_triton.py
@@ -0,0 +1,324 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import enum
+
+ import triton
+ import triton.language as tl
+
+
+ class Layout(enum.IntEnum):
+     BND_BND = 0
+     BDN_BND = 1
+     BND_BDN = 2
+     DBN_BND = 3
+     BND_DBN = 4
+
+
+ @triton.jit
+ def layer_norm_transpose_forward_kernel(
+     x_ptr,
+     out_ptr,
+     w_ptr,
+     b_ptr,
+     mean_ptr,
+     rstd_ptr,
+     B,
+     N,
+     D: tl.constexpr,
+     EPS: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_D: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     LAYOUT: tl.constexpr,
+ ):
+     pid_n = tl.program_id(0)
+     pid_b = tl.program_id(1)
+
+     offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
+     offs_d = tl.arange(0, TILE_D)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs = x_ptr + pid_b * D * N + offs_d[None, :] * N + offs_n[:, None]
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs = x_ptr + offs_d[None, :] * B * N + pid_b * N + offs_n[:, None]
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+
+     mean_ptrs = mean_ptr + pid_b * N + offs_n
+     rstd_ptrs = rstd_ptr + pid_b * N + offs_n
+     mask_n = offs_n < N
+
+     num_tiles = D // TILE_D
+
+     _mean = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+     for _ in range(0, num_tiles):
+         x = tl.load(x_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+         _mean += x
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs += TILE_D
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs += TILE_D * N
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs += TILE_D
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs += TILE_D * B * N
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs += TILE_D
+
+     mean = tl.sum(_mean, axis=1) / D
+     tl.store(mean_ptrs, mean, mask=mask_n)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs -= D
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs -= D * N
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs -= D
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs -= D * B * N
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs -= D
+
+     _var = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+     for _ in range(0, num_tiles):
+         x = tl.load(x_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+         x = x - mean[:, None]
+         _var += x * x
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs += TILE_D
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs += TILE_D * N
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs += TILE_D
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs += TILE_D * B * N
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs += TILE_D
+
+     var = tl.sum(_var, axis=1) / D
+     rstd = 1.0 / tl.sqrt(var + EPS)
+     tl.store(rstd_ptrs, rstd, mask=mask_n)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs -= D
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs -= D * N
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs -= D
+         out_ptrs = out_ptr + pid_b * N * D + offs_d[None, :] * N + offs_n[:, None]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs -= D * B * N
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs -= D
+         out_ptrs = out_ptr + offs_d[None, :] * B * N + pid_b * N + offs_n[:, None]
+
+     if ELEMENTWISE_AFFINE:
+         w_ptrs = w_ptr + offs_d
+         b_ptrs = b_ptr + offs_d
+
+     for _ in range(0, num_tiles):
+         if ELEMENTWISE_AFFINE:
+             w = tl.load(w_ptrs)
+             b = tl.load(b_ptrs)
+         else:
+             w = 1.0
+             b = 0.0
+
+         x = tl.load(x_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+         x_hat = (x - mean[:, None]) * rstd[:, None]
+         y = x_hat * w[None, :] + b[None, :]
+         tl.store(out_ptrs, y, mask=mask_n[:, None])
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs += TILE_D
+             out_ptrs += TILE_D
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs += TILE_D * N
+             out_ptrs += TILE_D
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs += TILE_D
+             out_ptrs += TILE_D * N
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs += TILE_D * B * N
+             out_ptrs += TILE_D
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs += TILE_D
+             out_ptrs += TILE_D * B * N
+
+         if ELEMENTWISE_AFFINE:
+             w_ptrs += TILE_D
+             b_ptrs += TILE_D
+
+
+ @triton.jit
+ def layer_norm_transpose_backward_kernel(
+     grad_out_ptr,
+     grad_x_ptr,
+     grad_w_ptr,
+     grad_b_ptr,
+     x_ptr,
+     w_ptr,
+     mean_ptr,
+     rstd_ptr,
+     B,
+     N,
+     D: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_D: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     LAYOUT: tl.constexpr,
+ ):
+     pid_n = tl.program_id(0)
+     pid_b = tl.program_id(1)
+
+     num_tiles = D // TILE_D
+     num_tiles_n = tl.cdiv(N, TILE_N)
+
+     offs_d = tl.arange(0, TILE_D)
+     offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
+     mask_n = offs_n < N
+
+     mean_ptrs = mean_ptr + pid_b * N + offs_n
+     rstd_ptrs = rstd_ptr + pid_b * N + offs_n
+     mean = tl.load(mean_ptrs, mask=mask_n, other=0.0).to(tl.float32)
+     rstd = tl.load(rstd_ptrs, mask=mask_n, other=0.0).to(tl.float32)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+     elif LAYOUT == 1:  # bdn->bnd
+         x_base_ptrs = x_ptr + pid_b * D * N + offs_n[:, None]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * D * N + offs_n[:, None]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+     elif LAYOUT == 2:  # bnd->bdn
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_base_ptrs = x_ptr + pid_b * N + offs_n[:, None]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N + offs_n[:, None]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+     elif LAYOUT == 4:  # bnd->dbn
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N + offs_n[:, None]
+
+     grad_w_base_ptrs = grad_w_ptr + pid_b * num_tiles_n * D + pid_n * D
+     grad_b_base_ptrs = grad_b_ptr + pid_b * num_tiles_n * D + pid_n * D
+
+     c1 = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+     c2 = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+
+     for _ in range(num_tiles):
+         if ELEMENTWISE_AFFINE:
+             w_ptrs = w_ptr + offs_d
+             w = tl.load(w_ptrs).to(tl.float32)
+         else:
+             w = 1.0
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * N
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * B * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * B * N
+
+         x = tl.load(x_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+         grad_out = tl.load(grad_out_ptrs, mask=mask_n[:, None], other=0.0).to(
+             tl.float32
+         )
+
+         xhat = (x - mean[:, None]) * rstd[:, None]
+
+         if ELEMENTWISE_AFFINE:
+             grad_b = grad_out
+             grad_w = grad_out * xhat
+
+             grad_b = tl.sum(grad_b, axis=0)
+             grad_w = tl.sum(grad_w, axis=0)
+
+             grad_w_ptrs = grad_w_base_ptrs + offs_d
+             grad_b_ptrs = grad_b_base_ptrs + offs_d
+
+             tl.store(grad_w_ptrs, grad_w)
+             tl.store(grad_b_ptrs, grad_b)
+
+         wdo = w * grad_out
+
+         c1 += xhat * wdo
+         c2 += wdo
+
+         offs_d += TILE_D
+
+     c1_dot = tl.sum(c1, axis=1) / D
+     c2_dot = tl.sum(c2, axis=1) / D
+
+     offs_d -= TILE_D * num_tiles
+
+     for _ in range(num_tiles):
+         if ELEMENTWISE_AFFINE:
+             w_ptrs = w_ptr + offs_d
+             w = tl.load(w_ptrs).to(tl.float32)
+         else:
+             w = 1.0
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * N
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :] * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * N
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * B * N
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :] * B * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * B * N
+
+         x = tl.load(x_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+         grad_out = tl.load(grad_out_ptrs, mask=mask_n[:, None], other=0.0).to(
+             tl.float32
+         )
+
+         xhat = (x - mean[:, None]) * rstd[:, None]
+         wdo = w * grad_out
+
+         dx = (wdo - (xhat * c1_dot[:, None] + c2_dot[:, None])) * rstd[:, None]
+         tl.store(grad_x_ptrs, dx, mask=mask_n[:, None])
+
+         offs_d += TILE_D
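
To make the kernel interface above concrete, here is a hedged launch sketch for the forward kernel. The grid shape is inferred from pid_n = tl.program_id(0) and pid_b = tl.program_id(1), and TILE_D must divide D because the kernel computes num_tiles = D // TILE_D. The package's own Python wrapper lives elsewhere in fused_layer_norm_triton.py and is not shown in this hunk, so this call is illustrative only:

    import torch
    import triton

    from cuequivariance_ops.triton.fused_layer_norm_triton import (
        Layout,
        layer_norm_transpose_forward_kernel,
    )

    B, N, D = 4, 1000, 256
    TILE_N, TILE_D = 64, 128  # TILE_D must divide D

    x = torch.randn(B, N, D, device="cuda")
    w = torch.ones(D, device="cuda")
    b = torch.zeros(D, device="cuda")
    out = torch.empty(B, N, D, device="cuda")
    mean = torch.empty(B, N, device="cuda", dtype=torch.float32)
    rstd = torch.empty(B, N, device="cuda", dtype=torch.float32)

    grid = (triton.cdiv(N, TILE_N), B)  # axis 0 tiles the N dimension, axis 1 walks batches
    layer_norm_transpose_forward_kernel[grid](
        x, out, w, b, mean, rstd, B, N,
        D=D, EPS=1e-5, TILE_N=TILE_N, TILE_D=TILE_D,
        ELEMENTWISE_AFFINE=True, LAYOUT=int(Layout.BND_BND),
    )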