cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/cudart.hpp +286 -0
  6. cuequivariance_ops/common/error.hpp +66 -0
  7. cuequivariance_ops/common/error_raft.hpp +323 -0
  8. cuequivariance_ops/common/nvtx.hpp +29 -0
  9. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  10. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  11. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  12. cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
  13. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  14. cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
  15. cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
  16. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  17. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  18. cuequivariance_ops/gpu_timing_kernels.hh +42 -0
  19. cuequivariance_ops/lib/libcue_ops.so +0 -0
  20. cuequivariance_ops/sleep.hh +40 -0
  21. cuequivariance_ops/triton/__init__.py +66 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  27. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  28. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  29. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
  30. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  31. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  32. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  33. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  34. cuequivariance_ops/triton/cache_manager.py +336 -0
  35. cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
  36. cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
  37. cuequivariance_ops/triton/pair_bias.py +365 -0
  38. cuequivariance_ops/triton/tuning_decorator.py +188 -0
  39. cuequivariance_ops/triton/utils.py +29 -0
  40. cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
  41. cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
  42. cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
  43. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
  44. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
  45. cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  46. cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
cuequivariance_ops/triton/cache_manager.py
@@ -0,0 +1,336 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import hashlib
+ import json
+ import logging
+ import math
+ import os
+ import subprocess
+ from multiprocessing import Lock
+ from pathlib import Path
+ from typing import Any
+
+ import pynvml
+ from platformdirs import user_cache_dir
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ FILE_LOCK = Lock()
+
+
+ def get_triton_tuning_mode():
+     cueq_at = os.getenv("CUEQ_TRITON_TUNING")
+     if cueq_at is not None and cueq_at not in ["AOT", "ONDEMAND"]:
+         logger.error(f"CUEQ_TRITON_TUNING setting not recognized: {cueq_at}.\n")
+     return cueq_at
+
+
+ def is_docker():
+     cgroup = Path("/proc/self/cgroup")
+     return Path("/.dockerenv").is_file() or (
+         cgroup.is_file() and "docker" in cgroup.read_text()
+     )
+
+
+ def overridden_cache_dir():
+     return os.getenv("CUEQ_TRITON_CACHE_DIR")
+
+
+ def get_triton_cache_dir() -> Path:
+     cache_dir = overridden_cache_dir()
+     if cache_dir is None:
+         cache_dir = user_cache_dir(appname="cuequivariance-triton", ensure_exists=False)
+     cache_dir = Path(cache_dir)
+     if cache_dir.exists():
+         return cache_dir
+     cache_dir.mkdir(parents=True, exist_ok=True)
+     return cache_dir
+
+
+ def get_gpu_name():
+     """Get GPU name from nvidia-smi."""
+     try:
+         result = subprocess.run(
+             ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"],
+             capture_output=True,
+             text=True,
+         )
+         return result.stdout.strip()
+     except Exception as e:
+         print(f"Error getting GPU name: {e}")
+         return "NVIDIA RTX A6000"  # default
+
+
+ def get_gpu_information():
+     # Default values (NVIDIA RTX A6000 specs)
+     default_map = {
+         "NVIDIA RTX A6000": {
+             "name": "NVIDIA RTX A6000",
+             "major": 8,
+             "minor": 6,
+             "total_memory": 45,
+             "multi_processor_count": 84,
+             "power_limit": 300,
+             "clock_rate": 2100,
+         },
+         "NVIDIA GB10": {
+             "name": "NVIDIA GB10",
+             "major": 12,
+             "minor": 0,  # actually 1; set to 0 here to point to the 12.0 spec
+             "total_memory": 96,  # unified 128 GB; only part is used here to match the 12.0 spec
+             "multi_processor_count": 48,
+             "power_limit": 140,
+             "clock_rate": 1700,
+         },
+     }
+     try:
+         pynvml.nvmlInit()
+         # Note: non-uniform multi-GPU setups are not supported
+         handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+         name = pynvml.nvmlDeviceGetName(handle)
+         # pci_info = pynvml.nvmlDeviceGetPciInfo(handle)
+         # device_id = pci_info.pciDeviceId
+         # sub_device_id = pci_info.pciSubSystemId
+         power_limit = pynvml.nvmlDeviceGetPowerManagementLimit(handle)
+         max_clock_rate = pynvml.nvmlDeviceGetMaxClockInfo(
+             handle, pynvml.NVML_CLOCK_GRAPHICS
+         )
+         mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+         gpu_core_count = pynvml.nvmlDeviceGetNumGpuCores(handle)
+         major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+
+         total_memory = math.ceil(mem_info.total / (1024**3))
+         multi_processor_count = gpu_core_count // 128
+         power_limit_kw = power_limit // 1000
+
+         # Validate all values - if any are invalid, raise exception to use defaults
+         if not name or not isinstance(name, str):
+             raise ValueError("Invalid GPU name")
+         if not isinstance(major, int) or major <= 0:
+             raise ValueError("Invalid compute capability major version")
+         if not isinstance(minor, int) or minor < 0:
+             raise ValueError("Invalid compute capability minor version")
+         if not isinstance(total_memory, int) or total_memory <= 0:
+             raise ValueError("Invalid total memory")
+         if not isinstance(multi_processor_count, int) or multi_processor_count <= 0:
+             raise ValueError("Invalid multi-processor count")
+         if not isinstance(power_limit_kw, int) or power_limit_kw <= 0:
+             raise ValueError("Invalid power limit")
+         if not isinstance(max_clock_rate, int) or max_clock_rate <= 0:
+             raise ValueError("Invalid clock rate")
+
+         logger.debug(
+             f"GPU information: {name}, {major}, {minor}, {total_memory}, {multi_processor_count}, {power_limit_kw}, {max_clock_rate}"
+         )
+         return {
+             "name": name,
+             # "device_id": device_id,
+             # "sub_device_id": sub_device_id,
+             "total_memory": total_memory,
+             "multi_processor_count": multi_processor_count,
+             "power_limit": power_limit_kw,
+             "clock_rate": max_clock_rate,
+             "major": major,
+             "minor": minor,
+         }
+     except Exception as e:
+         gpu_name = get_gpu_name()
+         logger.warning(
+             f"Failed to get GPU information from pynvml: {e}. Using default values for {gpu_name}."
+         )
+         if gpu_name in default_map:
+             defaults = default_map[gpu_name]
+         else:
+             defaults = default_map["NVIDIA RTX A6000"]
+         print(
+             f"GPU information: {defaults['name']}, {defaults['major']}, {defaults['minor']}, {defaults['total_memory']}, {defaults['multi_processor_count']}, {defaults['power_limit']}, {defaults['clock_rate']}"
+         )
+         return defaults
+     finally:
+         try:
+             pynvml.nvmlShutdown()
+         except Exception:
+             pass
+
+
+ def gpu_information_to_key(information: dict) -> str:
+     information.pop("name", None)
+     key_string = "_".join(f"{value}" for value in information.values()).replace(
+         " ", "_"
+     )
+     hash_object = hashlib.sha256(key_string.encode())
+     hash_str = hash_object.hexdigest()
+     return hash_str
+
+
+ def load_json(json_file):
+     with FILE_LOCK:
+         with open(json_file, "rb") as f:
+             fn_cache = json.load(f)
+     return fn_cache
+
+
+ class CacheManager:
+     """Singleton managing the cache"""
+
+     def __init__(self):
+         self.gpu_cache = {}
+         self.gpu_information = get_gpu_information()
+         self.gpu_key = gpu_information_to_key(self.gpu_information)
+         self.site_json_path = str(os.path.join(os.path.dirname(__file__), "cache"))
+         self.json_path = str(get_triton_cache_dir())
+         self.dirty = {}
+
+         if os.getenv("CUEQ_TRITON_IGNORE_EXISTING_CACHE") == "1":
+             logger.warning(
+                 "\n!!!!!! CUEQ_TRITON_IGNORE_EXISTING_CACHE is ON - previously saved settings will be ignored !!!!!!\n"
+                 f"CUEQ_TRITON_TUNING is set to {self.aot_mode}\n"
+                 f"The tuning changes will be written to {self.json_path}"
+             )
+
+         if (
+             self.aot_mode is not None
+             and is_docker()
+             and os.getenv("HOME") == "/root"
+             and not overridden_cache_dir()
+         ):
+             logger.warning(
+                 f"\n!!!!!! CUEQ_TRITON_TUNING is set to {self.aot_mode} and you are running as root in a Docker container. !!!!!!\n"
+                 f"The tuning changes will be written to {self.json_path}"
+                 "\nPlease remember to commit the container - otherwise any tuning changes will be lost on container restart."
+             )
+
+     # define aot_mode as a property so that changes to the environment variable take effect at runtime
+     @property
+     def aot_mode(self):
+         return get_triton_tuning_mode()
+
+     def load_cache(self, fn_key: str) -> dict:
+         # load the json file and store it in the cache-dict
+         # if the file does not exist, create an empty dict for the specified function
+         fn_cache = {}
+         gpu_cache = {}
+         best_key = None
+
+         major, minor = self.gpu_information["major"], self.gpu_information["minor"]
+         basename = f"{fn_key}.{major}.{minor}.json"
+         json_file = f"{self.json_path}/{basename}"
+
+         def result(self, gpu_cache):
+             # empty cache or fuzzy match, update for possible save
+             if best_key or not gpu_cache:
+                 gpu_cache["gpu_information"] = self.gpu_information
+             self.gpu_cache[fn_key] = gpu_cache
+             return gpu_cache
+
+         if os.getenv("CUEQ_TRITON_IGNORE_EXISTING_CACHE"):
+             return result(self, gpu_cache)
+
+         try:
+             fn_cache = load_json(json_file)
+
+         except Exception as e0:
+             site_json_file = f"{self.site_json_path}/{basename}"
+             try:
+                 fn_cache = load_json(site_json_file)
+             except Exception as e:
+                 logger.warning(
+                     f"Error reading system-wide triton tuning cache file: {site_json_file}\n{e}\n"
+                     f"Error reading user's triton tuning cache file {json_file}:\n{e0}"
+                 )
+                 pass
+         if fn_cache:
+             gpu_cache = fn_cache.get(self.gpu_key)
+             if gpu_cache is None:
+                 # do a fuzzy match of config:
+                 def within_10_percent(a, b, key):
+                     a = int(a[key])
+                     b = int(b[key])
+                     return abs(a - b) / (a + b) < 0.2
+
+                 def full_match(a, b):
+                     # matching clock & memory
+                     return (
+                         a["total_memory"] == b["total_memory"]
+                         and a["clock_rate"] == b["clock_rate"]
+                     )
+
+                 def partial_match(a, b):
+                     # match if either memory or clock is close
+                     return within_10_percent(a, b, "total_memory") or within_10_percent(
+                         a, b, "clock_rate"
+                     )
+
+                 for key in fn_cache:
+                     conf = fn_cache[key].get("gpu_information")
+                     if conf:
+                         if full_match(conf, self.gpu_information):
+                             best_key = key
+                             break
+                         elif partial_match(conf, self.gpu_information):
+                             best_key = key
+                 if best_key is None:
+                     # just pick the first entry there
+                     best_key = next(iter(fn_cache))
+                 gpu_cache = fn_cache[best_key]
+
+         return result(self, gpu_cache)
+
+     def save_cache(self, fn_key: str) -> None:
+         # save cache-dict to json file
+         major, minor = self.gpu_information["major"], self.gpu_information["minor"]
+         basename = f"{fn_key}.{major}.{minor}.json"
+         json_file = os.path.join(self.json_path, basename)
+
+         # Load existing data from the file if it exists
+         if os.path.exists(json_file):
+             with FILE_LOCK, open(json_file, "rb") as f:
+                 existing_data = json.load(f)
+         else:
+             existing_data = {}
+         # Update the entry for our GPU key with our data
+         existing_data.setdefault(self.gpu_key, {}).update(self.gpu_cache[fn_key])
+         self.gpu_cache[fn_key] = existing_data[self.gpu_key]
+         merged_data = existing_data
+         temp_file = f"{json_file}.{os.getpid()}.tmp"
+         try:
+             # Save the merged data back to the file
+             with FILE_LOCK:
+                 with open(temp_file, "w") as f:
+                     json.dump(merged_data, f, indent=4)
+                 os.replace(temp_file, json_file)
+         except Exception as e:
+             logger.warning(f"Warning: Failed to write autotune cache: {e}")
+
+         # Clear the dirty flag
+         del self.dirty[fn_key]
+
+     def get(self, fn_key: str, inp_key: str) -> Any:
+         # get value from cache
+         # if necessary, load json first
+         gpu_cache = self.gpu_cache.get(fn_key)
+         if gpu_cache is None:
+             gpu_cache = self.load_cache(fn_key)
+         # check if fn_key and inp_key exist in cache
+         return gpu_cache.get(inp_key)
+
+     def set(self, fn_key: str, inp_key: str, value: Any) -> None:
+         # write value to cache-dict
+         self.gpu_cache[fn_key][inp_key] = value
+         self.dirty[fn_key] = 1
+
+
+ cache_manager = CacheManager()
+
+
+ def get_cache_manager():
+     return cache_manager
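
For orientation, below is a minimal usage sketch of the cache manager above, assuming the wheel is installed. It is not code from the package: pick_block_size, benchmark_block_size, the inp_key layout, and caching a bare tile size are all hypothetical simplifications, and the fn_key merely mirrors one of the shipped cache file names (fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper); the real consumer is cuequivariance_ops/triton/tuning_decorator.py, which this diff does not reproduce.

# Hypothetical usage sketch - not part of the packaged module.
import os

# Optional overrides read by cache_manager.py; set them before the first import.
os.environ.setdefault("CUEQ_TRITON_TUNING", "ONDEMAND")  # or "AOT"
os.environ.setdefault("CUEQ_TRITON_CACHE_DIR", "/tmp/cueq-triton-cache")

from cuequivariance_ops.triton.cache_manager import get_cache_manager


def benchmark_block_size(m: int, n: int, k: int) -> int:
    # Stand-in for a real Triton autotuning loop (hypothetical helper).
    return 128


def pick_block_size(m: int, n: int, k: int) -> int:
    """Return a cached tile size for this shape, tuning it once if missing."""
    manager = get_cache_manager()
    fn_key = "fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper"  # one JSON file per kernel
    inp_key = f"{m}x{n}x{k}"  # one entry per input shape

    best = manager.get(fn_key, inp_key)  # loads <fn_key>.<major>.<minor>.json on first use
    if best is None:
        best = benchmark_block_size(m, n, k)
        manager.set(fn_key, inp_key, best)  # marks fn_key as dirty
        manager.save_cache(fn_key)  # merge with on-disk entries, then atomic replace
    return best

Note that save_cache merges the new entries into whatever already exists on disk for this GPU key and writes through a per-PID temporary file followed by os.replace, so concurrent processes do not leave a half-written JSON file behind.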