triton-windows 3.3.1.post21__cp39-cp39-win_amd64.whl → 3.4.0.post21__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +143 -46
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +94 -94
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +296 -125
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +73 -9
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +47 -83
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/METADATA +7 -2
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/RECORD +64 -41
- triton_windows-3.4.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post21.dist-info/top_level.txt +1 -0
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post21.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/WHEEL +0 -0
triton/windows_utils.py
CHANGED
|
@@ -54,14 +54,11 @@ def max_version(
|
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
def check_msvc(msvc_base_path: Path, version: str) -> bool:
|
|
57
|
-
return all(
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
msvc_base_path / version / "lib" / "x64" / "vcruntime.lib",
|
|
63
|
-
]
|
|
64
|
-
)
|
|
57
|
+
return all(x.exists() for x in [
|
|
58
|
+
msvc_base_path / version / "bin" / "Hostx64" / "x64" / "cl.exe",
|
|
59
|
+
msvc_base_path / version / "include" / "vcruntime.h",
|
|
60
|
+
msvc_base_path / version / "lib" / "x64" / "vcruntime.lib",
|
|
61
|
+
])
|
|
65
62
|
|
|
66
63
|
|
|
67
64
|
def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
|
|
@@ -72,20 +69,16 @@ def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
|
|
|
72
69
|
|
|
73
70
|
version = os.getenv("VCToolsVersion")
|
|
74
71
|
if not check_msvc(msvc_base_path, version):
|
|
75
|
-
warnings.warn(
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
"but this MSVC installation is incomplete."
|
|
79
|
-
)
|
|
72
|
+
warnings.warn(f"Environment variables VCINSTALLDIR = {os.getenv('VCINSTALLDIR')}, "
|
|
73
|
+
f"VCToolsVersion = {os.getenv('VCToolsVersion')} are set, "
|
|
74
|
+
"but this MSVC installation is incomplete.")
|
|
80
75
|
return None, None
|
|
81
76
|
|
|
82
77
|
return msvc_base_path, version
|
|
83
78
|
|
|
84
79
|
|
|
85
80
|
def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
|
|
86
|
-
vswhere_path = find_in_program_files(
|
|
87
|
-
r"Microsoft Visual Studio\Installer\vswhere.exe"
|
|
88
|
-
)
|
|
81
|
+
vswhere_path = find_in_program_files(r"Microsoft Visual Studio\Installer\vswhere.exe")
|
|
89
82
|
if vswhere_path is None:
|
|
90
83
|
return None, None
|
|
91
84
|
|
|
@@ -111,9 +104,7 @@ def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
|
|
|
111
104
|
if not msvc_base_path.exists():
|
|
112
105
|
return None, None
|
|
113
106
|
|
|
114
|
-
version = max_version(
|
|
115
|
-
os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
|
|
116
|
-
)
|
|
107
|
+
version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
|
|
117
108
|
if version is None:
|
|
118
109
|
return None, None
|
|
119
110
|
|
|
@@ -132,9 +123,7 @@ def find_msvc_envpath() -> tuple[Optional[Path], Optional[str]]:
|
|
|
132
123
|
if not msvc_base_path.exists():
|
|
133
124
|
continue
|
|
134
125
|
|
|
135
|
-
version = max_version(
|
|
136
|
-
os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
|
|
137
|
-
)
|
|
126
|
+
version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
|
|
138
127
|
if version is None:
|
|
139
128
|
continue
|
|
140
129
|
|
|
@@ -153,9 +142,7 @@ def find_msvc_hardcoded() -> tuple[Optional[Path], Optional[str]]:
|
|
|
153
142
|
paths = sorted(paths)[::-1]
|
|
154
143
|
for msvc_base_path in paths:
|
|
155
144
|
msvc_base_path = Path(msvc_base_path)
|
|
156
|
-
version = max_version(
|
|
157
|
-
os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
|
|
158
|
-
)
|
|
145
|
+
version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
|
|
159
146
|
if version is None:
|
|
160
147
|
continue
|
|
161
148
|
return msvc_base_path, version
|
|
@@ -188,13 +175,10 @@ def find_msvc(env_only: bool) -> tuple[Optional[str], list[str], list[str]]:
|
|
|
188
175
|
|
|
189
176
|
|
|
190
177
|
def check_winsdk(winsdk_base_path: Path, version: str) -> bool:
|
|
191
|
-
return all(
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib",
|
|
196
|
-
]
|
|
197
|
-
)
|
|
178
|
+
return all(x.exists() for x in [
|
|
179
|
+
winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h",
|
|
180
|
+
winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib",
|
|
181
|
+
])
|
|
198
182
|
|
|
199
183
|
|
|
200
184
|
def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
|
|
@@ -207,18 +191,14 @@ def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
|
|
|
207
191
|
if version is None:
|
|
208
192
|
version = os.getenv("WindowsSDKVer")
|
|
209
193
|
if version is None:
|
|
210
|
-
warnings.warn(
|
|
211
|
-
|
|
212
|
-
"but WindowsSDKVersion (or WindowsSDKVer) is not set."
|
|
213
|
-
)
|
|
194
|
+
warnings.warn(f"Environment variable WindowsSdkDir = {winsdk_base_path}, "
|
|
195
|
+
"but WindowsSDKVersion (or WindowsSDKVer) is not set.")
|
|
214
196
|
return None, None
|
|
215
197
|
version = version.rstrip("\\")
|
|
216
198
|
if not check_winsdk(winsdk_base_path, version):
|
|
217
|
-
warnings.warn(
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
"but this Windows SDK installation is incomplete."
|
|
221
|
-
)
|
|
199
|
+
warnings.warn(f"Environment variables WindowsSdkDir = {winsdk_base_path}, "
|
|
200
|
+
f"WindowsSDKVersion (or WindowsSDKVer) = {version} are set, "
|
|
201
|
+
"but this Windows SDK installation is incomplete.")
|
|
222
202
|
return None, None
|
|
223
203
|
|
|
224
204
|
return winsdk_base_path, version
|
|
@@ -227,9 +207,7 @@ def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
|
|
|
227
207
|
def find_winsdk_registry() -> tuple[Optional[Path], Optional[str]]:
|
|
228
208
|
try:
|
|
229
209
|
reg = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
|
|
230
|
-
key = winreg.OpenKeyEx(
|
|
231
|
-
reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0"
|
|
232
|
-
)
|
|
210
|
+
key = winreg.OpenKeyEx(reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0")
|
|
233
211
|
folder = winreg.QueryValueEx(key, "InstallationFolder")[0]
|
|
234
212
|
winreg.CloseKey(key)
|
|
235
213
|
except OSError:
|
|
@@ -295,10 +273,8 @@ def find_winsdk(env_only: bool) -> tuple[list[str], list[str]]:
|
|
|
295
273
|
return [], []
|
|
296
274
|
|
|
297
275
|
|
|
298
|
-
@functools.
|
|
299
|
-
def find_msvc_winsdk(
|
|
300
|
-
env_only: bool = False,
|
|
301
|
-
) -> tuple[Optional[str], list[str], list[str]]:
|
|
276
|
+
@functools.lru_cache
|
|
277
|
+
def find_msvc_winsdk(env_only: bool = False, ) -> tuple[Optional[str], list[str], list[str]]:
|
|
302
278
|
msvc_bin_path, msvc_inc_dirs, msvc_lib_dirs = find_msvc(env_only)
|
|
303
279
|
winsdk_inc_dirs, winsdk_lib_dirs = find_winsdk(env_only)
|
|
304
280
|
return (
|
|
@@ -308,15 +284,18 @@ def find_msvc_winsdk(
|
|
|
308
284
|
)
|
|
309
285
|
|
|
310
286
|
|
|
311
|
-
@functools.
|
|
287
|
+
@functools.lru_cache
|
|
312
288
|
def find_python() -> list[str]:
|
|
289
|
+
version = sysconfig.get_python_version().replace(".", "")
|
|
290
|
+
if sysconfig.get_config_var("Py_GIL_DISABLED"):
|
|
291
|
+
version += "t"
|
|
313
292
|
for python_base_path in [
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
293
|
+
sys.exec_prefix,
|
|
294
|
+
sys.base_exec_prefix,
|
|
295
|
+
os.path.dirname(sys.executable),
|
|
317
296
|
]:
|
|
318
297
|
python_lib_dir = Path(python_base_path) / "libs"
|
|
319
|
-
if (python_lib_dir / "
|
|
298
|
+
if (python_lib_dir / f"python{version}.lib").exists():
|
|
320
299
|
return [str(python_lib_dir)]
|
|
321
300
|
|
|
322
301
|
warnings.warn("Failed to find Python libs.")
|
|
@@ -325,14 +304,11 @@ def find_python() -> list[str]:
|
|
|
325
304
|
|
|
326
305
|
def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list[str]]:
|
|
327
306
|
# pip
|
|
328
|
-
if all(
|
|
329
|
-
x.exists()
|
|
330
|
-
for x in [
|
|
307
|
+
if all(x.exists() for x in [
|
|
331
308
|
base_path / "cuda_nvcc" / "bin" / "ptxas.exe",
|
|
332
309
|
base_path / "cuda_runtime" / "include" / "cuda.h",
|
|
333
310
|
base_path / "cuda_runtime" / "lib" / "x64" / "cuda.lib",
|
|
334
|
-
|
|
335
|
-
):
|
|
311
|
+
]):
|
|
336
312
|
return (
|
|
337
313
|
str(base_path / "cuda_nvcc" / "bin"),
|
|
338
314
|
[str(base_path / "cuda_runtime" / "include")],
|
|
@@ -340,14 +316,11 @@ def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list
|
|
|
340
316
|
)
|
|
341
317
|
|
|
342
318
|
# conda
|
|
343
|
-
if all(
|
|
344
|
-
x.exists()
|
|
345
|
-
for x in [
|
|
319
|
+
if all(x.exists() for x in [
|
|
346
320
|
base_path / "bin" / "ptxas.exe",
|
|
347
321
|
base_path / "include" / "cuda.h",
|
|
348
322
|
base_path / "lib" / "cuda.lib",
|
|
349
|
-
|
|
350
|
-
):
|
|
323
|
+
]):
|
|
351
324
|
return (
|
|
352
325
|
str(base_path / "bin"),
|
|
353
326
|
[str(base_path / "include")],
|
|
@@ -355,14 +328,11 @@ def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list
|
|
|
355
328
|
)
|
|
356
329
|
|
|
357
330
|
# bundled or system-wide
|
|
358
|
-
if all(
|
|
359
|
-
x.exists()
|
|
360
|
-
for x in [
|
|
331
|
+
if all(x.exists() for x in [
|
|
361
332
|
base_path / "bin" / "ptxas.exe",
|
|
362
333
|
base_path / "include" / "cuda.h",
|
|
363
334
|
base_path / "lib" / "x64" / "cuda.lib",
|
|
364
|
-
|
|
365
|
-
):
|
|
335
|
+
]):
|
|
366
336
|
return (
|
|
367
337
|
str(base_path / "bin"),
|
|
368
338
|
[str(base_path / "include")],
|
|
@@ -379,9 +349,7 @@ def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
|
|
|
379
349
|
continue
|
|
380
350
|
|
|
381
351
|
cuda_base_path = Path(cuda_base_path)
|
|
382
|
-
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(
|
|
383
|
-
cuda_base_path
|
|
384
|
-
)
|
|
352
|
+
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
|
|
385
353
|
if cuda_bin_path:
|
|
386
354
|
return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
|
|
387
355
|
|
|
@@ -389,9 +357,7 @@ def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
|
|
|
389
357
|
|
|
390
358
|
|
|
391
359
|
def find_cuda_bundled() -> tuple[Optional[str], list[str], list[str]]:
|
|
392
|
-
cuda_base_path = (
|
|
393
|
-
Path(sysconfig.get_paths()["platlib"]) / "triton" / "backends" / "nvidia"
|
|
394
|
-
)
|
|
360
|
+
cuda_base_path = (Path(sysconfig.get_paths()["platlib"]) / "triton" / "backends" / "nvidia")
|
|
395
361
|
return check_and_find_cuda(cuda_base_path)
|
|
396
362
|
|
|
397
363
|
|
|
@@ -415,23 +381,21 @@ def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
|
|
|
415
381
|
paths = sorted(paths)[::-1]
|
|
416
382
|
for cuda_base_path in paths:
|
|
417
383
|
cuda_base_path = Path(cuda_base_path)
|
|
418
|
-
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(
|
|
419
|
-
cuda_base_path
|
|
420
|
-
)
|
|
384
|
+
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
|
|
421
385
|
if cuda_bin_path:
|
|
422
386
|
return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
|
|
423
387
|
|
|
424
388
|
return None, [], []
|
|
425
389
|
|
|
426
390
|
|
|
427
|
-
@functools.
|
|
391
|
+
@functools.lru_cache
|
|
428
392
|
def find_cuda() -> tuple[Optional[str], list[str], list[str]]:
|
|
429
393
|
for f in [
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
394
|
+
find_cuda_env,
|
|
395
|
+
find_cuda_bundled,
|
|
396
|
+
find_cuda_pip,
|
|
397
|
+
find_cuda_conda,
|
|
398
|
+
find_cuda_hardcoded,
|
|
435
399
|
]:
|
|
436
400
|
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = f()
|
|
437
401
|
if cuda_bin_path:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: triton-windows
|
|
3
|
-
Version: 3.
|
|
3
|
+
Version: 3.4.0.post21
|
|
4
4
|
Summary: A language and compiler for custom Deep Learning operations
|
|
5
5
|
Home-page: https://github.com/woct0rdho/triton-windows
|
|
6
6
|
Author: Philippe Tillet, Dian Wu
|
|
@@ -15,9 +15,12 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.11
|
|
16
16
|
Classifier: Programming Language :: Python :: 3.12
|
|
17
17
|
Classifier: Programming Language :: Python :: 3.13
|
|
18
|
+
Requires-Python: >=3.9,<3.14
|
|
19
|
+
License-File: LICENSE
|
|
18
20
|
Requires-Dist: setuptools>=40.8.0
|
|
21
|
+
Requires-Dist: importlib-metadata; python_version < "3.10"
|
|
19
22
|
Provides-Extra: build
|
|
20
|
-
Requires-Dist: cmake
|
|
23
|
+
Requires-Dist: cmake<4.0,>=3.20; extra == "build"
|
|
21
24
|
Requires-Dist: lit; extra == "build"
|
|
22
25
|
Provides-Extra: tests
|
|
23
26
|
Requires-Dist: autopep8; extra == "tests"
|
|
@@ -37,6 +40,8 @@ Dynamic: author-email
|
|
|
37
40
|
Dynamic: classifier
|
|
38
41
|
Dynamic: home-page
|
|
39
42
|
Dynamic: keywords
|
|
43
|
+
Dynamic: license-file
|
|
40
44
|
Dynamic: provides-extra
|
|
41
45
|
Dynamic: requires-dist
|
|
46
|
+
Dynamic: requires-python
|
|
42
47
|
Dynamic: summary
|
|
@@ -1,56 +1,77 @@
|
|
|
1
|
-
triton/__init__.py,sha256=
|
|
2
|
-
triton/
|
|
3
|
-
triton/
|
|
1
|
+
triton/__init__.py,sha256=CEQdOZ6zLHuEVGoJ7kl2YwcfBBIPBkOQzzDjHr6ibOM,1464
|
|
2
|
+
triton/_filecheck.py,sha256=iWl8uL4LJeV4En4h4mzUbDnrmXB4jXeEgy4_uqRURH8,2845
|
|
3
|
+
triton/_internal_testing.py,sha256=MpXYuQlvJUYtFAmFcfzlU7dyIDEQB_XirJ96IRujMTA,6326
|
|
4
|
+
triton/_utils.py,sha256=XTYb3qDDaVmbhmXbm6ChMTYTW9jeE538jZwJE_eliQg,3539
|
|
4
5
|
triton/errors.py,sha256=8WfnuRKLG578mgY6cBA3ECruVMf9ULEKFNgRcJ6IhWM,89
|
|
5
|
-
triton/
|
|
6
|
-
triton/
|
|
7
|
-
triton/
|
|
8
|
-
triton/
|
|
9
|
-
triton/backends/
|
|
6
|
+
triton/knobs.py,sha256=VOdRM_J0TejBYP2H7QZZzLwcRJq-Eppm9hBxDt1pCgA,14916
|
|
7
|
+
triton/testing.py,sha256=vbEQRNrOnnzRQvVVSaiZrUo8AC0XPV40GJxfvkKYLh0,20276
|
|
8
|
+
triton/windows_utils.py,sha256=JMi6mjOApzh2-cw3Wl_nl6ji7JkwexYI7xgo2Et3ihU,12903
|
|
9
|
+
triton/_C/libtriton.pyd,sha256=DOWk1tvh31NsGgulOpQCJQiAh5k-8tkKuDM1Q4UJrr8,89941504
|
|
10
|
+
triton/backends/__init__.py,sha256=X7290kf96Fk9QnfLScsX4UDG3zPyH_-31E4A7pVOijM,1612
|
|
11
|
+
triton/backends/compiler.py,sha256=MY2_cQG26p68z8VwRv2Nlj_h2DfEhwBbN-30caMgep0,2840
|
|
10
12
|
triton/backends/driver.py,sha256=AN60upJlPgia0JwvZ8vIVgLITNPuI0fdz8zMIIHPpF4,1450
|
|
11
|
-
triton/backends/amd/
|
|
12
|
-
triton/backends/amd/
|
|
13
|
-
triton/backends/amd/driver.
|
|
13
|
+
triton/backends/amd/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
+
triton/backends/amd/compiler.py,sha256=yaF0MqfZ81-fL1Jb5aDxGGGXs2j-pVOVpneWQT6DOUs,19966
|
|
15
|
+
triton/backends/amd/driver.c,sha256=hu5_QLMJVmeyR5zYDWfDAklZckISaAFM7kKOg9MpuWE,8612
|
|
16
|
+
triton/backends/amd/driver.py,sha256=9UU3u5gdqjGl7NYdsOyFk63MjgWW0Vnx4-jJ35ROscs,23718
|
|
14
17
|
triton/backends/amd/lib/asanrtl.bc,sha256=1xv2RlU3WvbdsghHlmhwiHewGM2B5dKts5bERM6S89o,24508
|
|
15
18
|
triton/backends/amd/lib/ockl.bc,sha256=wQKCzkKukIHbu0lyjKUYlhndc7S27xto6L54J0Bn-C0,246124
|
|
16
19
|
triton/backends/amd/lib/ocml.bc,sha256=UPNTXW0gCXUNB-c6orSYwb-mz9_mjUc7zny_vfFza44,205964
|
|
17
20
|
triton/backends/nvidia/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
|
-
triton/backends/nvidia/compiler.py,sha256=
|
|
19
|
-
triton/backends/nvidia/driver.c,sha256=
|
|
20
|
-
triton/backends/nvidia/driver.py,sha256=
|
|
21
|
-
triton/backends/nvidia/bin/ptxas.exe,sha256=
|
|
21
|
+
triton/backends/nvidia/compiler.py,sha256=7Yf79DdocAgYaWLR-MSTAv_QAJjVAZkhbr1-ovsnfnQ,19452
|
|
22
|
+
triton/backends/nvidia/driver.c,sha256=rH8RDtMMv_UHr7qiLnrSdNg3xojOQe_fF1zW67LFjaE,17882
|
|
23
|
+
triton/backends/nvidia/driver.py,sha256=22YT4HhTPI1aJ-mjw31Lvw1mcffj17pGJry3m9W6jio,26391
|
|
24
|
+
triton/backends/nvidia/bin/ptxas.exe,sha256=f28E0l5aerLAfBXk7yagfOwIEE6_6_NkMx-vqPPEQ9Y,24753152
|
|
22
25
|
triton/backends/nvidia/include/cuda.h,sha256=Fn44OjeRImxegJ39apYUspseEfTWNGwpqSGUOnHj5WY,1183268
|
|
23
26
|
triton/backends/nvidia/lib/libdevice.10.bc,sha256=XC-uN8huaMOjhgWpX1EtfRLV89uYYxC-R_VzBKpype4,473728
|
|
24
27
|
triton/backends/nvidia/lib/x64/cuda.lib,sha256=I5DZfR8aQ9wodYo3trskSbJpJd9lHvZXsnEZ3NV30LQ,160840
|
|
25
28
|
triton/compiler/__init__.py,sha256=0NEunzjGCNEVOhYZLDI4pDi_zAaYAgTXNm8U5uxbdL0,242
|
|
26
|
-
triton/compiler/code_generator.py,sha256=
|
|
27
|
-
triton/compiler/compiler.py,sha256=
|
|
29
|
+
triton/compiler/code_generator.py,sha256=Rj5tFwqfjMR1-cr2CxOr3nP3Ez2V4EE3ENKQAp66uXc,66887
|
|
30
|
+
triton/compiler/compiler.py,sha256=Qm-71vUVkafzjYpU3ttBCTzWheIs08Z7zfBusHODEJY,21556
|
|
28
31
|
triton/compiler/errors.py,sha256=I9Y15pDWcL9heY4SWWdLeMDtW6Iiq2pFXzKfJ6dY_C0,1732
|
|
29
32
|
triton/compiler/make_launcher.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
|
-
triton/
|
|
31
|
-
triton/
|
|
32
|
-
triton/
|
|
33
|
-
triton/
|
|
34
|
-
triton/language/
|
|
35
|
-
triton/language/
|
|
36
|
-
triton/language/
|
|
33
|
+
triton/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
|
+
triton/experimental/gluon/__init__.py,sha256=e2NX3d9SND2hKGmMDCix8_Sg12BMpK3zR3NWHy-fioQ,76
|
|
35
|
+
triton/experimental/gluon/_compiler.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
36
|
+
triton/experimental/gluon/_runtime.py,sha256=b9KD4D-CuuYxwCwj2YPPpT_JnB9e1G4_ztaESaQjgHw,3395
|
|
37
|
+
triton/experimental/gluon/language/__init__.py,sha256=EEDociOXDeGu9y2FwRKUoGNxJKu7eg2jIi94j7TQxx8,451
|
|
38
|
+
triton/experimental/gluon/language/_core.py,sha256=8QyCvlTYQnqIDDrVa1AG84u7K3zaBAS_81QaJg3RfqE,8911
|
|
39
|
+
triton/experimental/gluon/language/_layouts.py,sha256=H3Zkjmkl-IwJxCC4pnPxsBw-2mGNCNfn9MbY1c1l6s0,8975
|
|
40
|
+
triton/experimental/gluon/language/_math.py,sha256=R9fMFusmpDy3rdA-zwxIIB5nigEg08MeTnSdslA0DQA,329
|
|
41
|
+
triton/experimental/gluon/language/_semantic.py,sha256=SKRJPTtnhfWpuz-deaun845FpHvBvpeW2erGqO-UkdQ,13898
|
|
42
|
+
triton/experimental/gluon/language/_standard.py,sha256=gunBiUjdix_LDVONM-OZ5wuR0i8O5fZLXHmOBEyvbLk,1063
|
|
43
|
+
triton/experimental/gluon/language/nvidia/__init__.py,sha256=SFBuACK5P2XoYcutHEnKjqgRTboU4CPDmJz0hT6dFRQ,80
|
|
44
|
+
triton/experimental/gluon/language/nvidia/blackwell/__init__.py,sha256=KHmHVhgmvhLQXsPfwExi_O9AfjQzpKYmqb3zWSy6W4Q,7599
|
|
45
|
+
triton/experimental/gluon/language/nvidia/blackwell/tma.py,sha256=15xbiY2QmgjVPHvXPJ3MKrAkbqcaZjgx-V3B-kCgo9w,1086
|
|
46
|
+
triton/experimental/gluon/language/nvidia/hopper/__init__.py,sha256=NVReVRoDY3OuXQUrKokuxY86z0QehVSje63H52u-5hs,295
|
|
47
|
+
triton/experimental/gluon/language/nvidia/hopper/mbarrier.py,sha256=MezUtSQr-FzYM1kbuto1xNE4NimSBZIxP25xkn7nnp8,1603
|
|
48
|
+
triton/experimental/gluon/language/nvidia/hopper/tma.py,sha256=Jih4obM2oGcUU7DYg0T7nSindQQd40zgQh0e1MeLh6A,3508
|
|
49
|
+
triton/experimental/gluon/nvidia/__init__.py,sha256=ISXB4RV7RcCLsU-JhcRFeA29gCBDVk8cTwO2j99ivLc,80
|
|
50
|
+
triton/experimental/gluon/nvidia/blackwell.py,sha256=cllwlUCE5_YKWqySQZk7wt7Fierz345E5VwztxNRGMs,69
|
|
51
|
+
triton/experimental/gluon/nvidia/hopper.py,sha256=SKDi2fPCB87vMZAF6Em3gfZgif95U4Omeiexn7c969o,1518
|
|
52
|
+
triton/language/__init__.py,sha256=XJPQq1rq0SoPcfqkw9YVln7XqP0I25nIGzTa-QmMpiY,6418
|
|
53
|
+
triton/language/core.py,sha256=UsbrSmv92MdNW9JfKkyDhDOAZI8drrS1PK7Av2un7A4,116128
|
|
54
|
+
triton/language/math.py,sha256=CKvuIc5iMKhz7Qgx9w-VcLfOOZadv5svKK4aGZLuHMc,7399
|
|
55
|
+
triton/language/random.py,sha256=jkuFmfgZ8yvKuub9EY27zPvsC6nhkJIk05xf4y-7SR8,7102
|
|
56
|
+
triton/language/semantic.py,sha256=ERQftF6yeuTMmPO722GWpfwug7r9ZDWhxThL1kwLfmI,96950
|
|
57
|
+
triton/language/standard.py,sha256=EwpxORaDOiOdT4QgCnxiFXrsC652no7RptVoF9OszN0,16152
|
|
37
58
|
triton/language/extra/__init__.py,sha256=XRXFvr7416pRsh_Rh-X6qV66SiEyVDVbxp4GSAE1mfc,655
|
|
38
59
|
triton/language/extra/libdevice.py,sha256=Dki14elRNmQsz-Ytw9CnOaLCCnte4T6cI8bOzWjN63A,6318
|
|
39
|
-
triton/language/extra/cuda/__init__.py,sha256=
|
|
40
|
-
triton/language/extra/cuda/
|
|
41
|
-
triton/language/extra/cuda/libdevice.py,sha256=
|
|
42
|
-
triton/language/extra/cuda/utils.py,sha256=
|
|
60
|
+
triton/language/extra/cuda/__init__.py,sha256=MBBu2EUYxsp6ygjiwO4Yh1X1EswMstfaiRTMSMGtbcw,407
|
|
61
|
+
triton/language/extra/cuda/gdc.py,sha256=QAqc_E1INKjYlW6ERSnb9uWoEBDAQlnxkn2yiIWHJPQ,2185
|
|
62
|
+
triton/language/extra/cuda/libdevice.py,sha256=J7Kl0ejbAIus7-YBn2OSK71lkm3pC7G1J-5ZdHfS82U,56764
|
|
63
|
+
triton/language/extra/cuda/utils.py,sha256=phDcXCFViaq3p4ThwHrO8-FtU-8A8I3nk4mZZJVvTio,4426
|
|
43
64
|
triton/language/extra/hip/__init__.py,sha256=ieSER4LeX9_0horChGUUVwpuKAprkuka8uGAkEBDyDM,49
|
|
44
|
-
triton/language/extra/hip/libdevice.py,sha256=
|
|
65
|
+
triton/language/extra/hip/libdevice.py,sha256=Rf-AmBzcO6ORVzSxSuLXOy0lpoZTsnRAuTvjSF83r-E,17313
|
|
45
66
|
triton/runtime/__init__.py,sha256=mKL5cqIBDUw2WO80NRCh4s1G8KYaqgM59TTAbTkPPjQ,621
|
|
46
67
|
triton/runtime/_allocation.py,sha256=zaW4B7I7c-2rkVuN7IZaUB6IQSI1t4FvnTPZH-r7DTk,798
|
|
47
|
-
triton/runtime/autotuner.py,sha256=
|
|
48
|
-
triton/runtime/build.py,sha256=
|
|
49
|
-
triton/runtime/cache.py,sha256=
|
|
50
|
-
triton/runtime/driver.py,sha256=
|
|
68
|
+
triton/runtime/autotuner.py,sha256=cfWBuLpL6-eBv-J2tFIbL0gE3ZGHOFSJ0e3n0GfrzLw,20244
|
|
69
|
+
triton/runtime/build.py,sha256=hiFHKRV-fwDoT5lMbWpzA3hLp6LrE6ccw7zTo2AvHo0,6203
|
|
70
|
+
triton/runtime/cache.py,sha256=uMV-CwCaS9cthIzKoLlTHXjhw_RoaIUaVP7zmgsdsIo,9689
|
|
71
|
+
triton/runtime/driver.py,sha256=seGhU4efCFPVN0KVzd4gmZ1x5s0I_sFyM5NC8brXWF8,1798
|
|
51
72
|
triton/runtime/errors.py,sha256=CwfJXciwel_-K3BfQfKUpLPDWrSyTnGsfJkqJojrdfQ,1052
|
|
52
|
-
triton/runtime/interpreter.py,sha256=
|
|
53
|
-
triton/runtime/jit.py,sha256=
|
|
73
|
+
triton/runtime/interpreter.py,sha256=IJ0kLHhnoRd9-lwtPk7l8IDs_reaDzLpIBJaH-h3S8g,60811
|
|
74
|
+
triton/runtime/jit.py,sha256=M1bqBguYaMuztS3hH4-T5CI8UrvADyEC2mVuuwKVrvg,36766
|
|
54
75
|
triton/runtime/tcc/libtcc.dll,sha256=4IVp00uvXFRsmhnF5tC1mT8Zb0Hl6uuxDlTHG1kQkrw,156160
|
|
55
76
|
triton/runtime/tcc/tcc.exe,sha256=6cs-ieIKnv6tg8yeaxADFCdWNML3BQVtpx9CTqmwzfA,23552
|
|
56
77
|
triton/runtime/tcc/include/_mingw.h,sha256=q0vn005_oOci8JSODJCtTZW4oexknC8Ybfo4e1e-eDM,3865
|
|
@@ -151,13 +172,15 @@ triton/runtime/tcc/lib/user32.def,sha256=EcYohyyDgmz9fLBoOR-vszLeJ2YkBUoNGvSnuXr
|
|
|
151
172
|
triton/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
152
173
|
triton/tools/build_extern.py,sha256=jCr-2hu3nLGBIJhCGUQ1jAyzLttughjkiPGEwRFjLR0,13673
|
|
153
174
|
triton/tools/compile.py,sha256=CP_-yqEd55ejkc2_OYVE7q0Eyh9xErk8KJy2BcdCV0Y,7129
|
|
154
|
-
triton/tools/disasm.py,sha256=
|
|
155
|
-
triton/tools/experimental_descriptor.py,sha256=0Wqy96Cc6YLh9o0eTknW-Lfvha6lfRSfe8bswkcPHMs,1260
|
|
175
|
+
triton/tools/disasm.py,sha256=T9jiTkdK_0nI3R_4uea0zvfioYdcR-zIZwTfuucgw6g,5026
|
|
156
176
|
triton/tools/link.py,sha256=u7qtfZRLriZkAMEGNvj8YF-k1cthmLL7BwHYqBgT63E,11871
|
|
157
177
|
triton/tools/mxfp.py,sha256=YQdpBrGkOVNOtnLeRjMCeVFHWkSwUubGeWsItIjO8TU,11737
|
|
178
|
+
triton/tools/tensor_descriptor.py,sha256=mt4iVVRcNg0gjoytb6iCP4l5vt-H2V3MGeAQfJcStJo,1289
|
|
158
179
|
triton/tools/extra/cuda/compile.c,sha256=TdIENsqk6wrvv1C4Mk-sq9keXe3SJuMQcf0UpxmjNZk,2153
|
|
159
180
|
triton/tools/extra/cuda/compile.h,sha256=n9QKIFZTL4RSsiXtAxBP9XGSnxjyaevQQ9bBpwDsvAg,332
|
|
160
|
-
triton_windows-3.
|
|
161
|
-
triton_windows-3.
|
|
162
|
-
triton_windows-3.
|
|
163
|
-
triton_windows-3.
|
|
181
|
+
triton_windows-3.4.0.post21.dist-info/licenses/LICENSE,sha256=kmQPuXIi_Qppj_KM4MN4LBcmI_jWxgm1V2NqgPKPuUY,1132
|
|
182
|
+
triton_windows-3.4.0.post21.dist-info/METADATA,sha256=IHOZbw7LyV2c8dx0dvpgjzw8LmdA99209iNKB1UC01o,1794
|
|
183
|
+
triton_windows-3.4.0.post21.dist-info/WHEEL,sha256=XkFE14KmFh7mutkkb-qn_ueuH2lwfT8rLdfc5xpQ7wE,99
|
|
184
|
+
triton_windows-3.4.0.post21.dist-info/entry_points.txt,sha256=cztF9ZYXxoMhibI_OttiKCl1EBP2LQaV8naJ-BcuES4,76
|
|
185
|
+
triton_windows-3.4.0.post21.dist-info/top_level.txt,sha256=WBiIZyv6n9Y7MIh-HPHSv2w1RDk7EFL__7ZgQRrmHYs,7
|
|
186
|
+
triton_windows-3.4.0.post21.dist-info/RECORD,,
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright 2018-2020 Philippe Tillet
|
|
3
|
+
* Copyright 2020-2022 OpenAI
|
|
4
|
+
*
|
|
5
|
+
* Permission is hereby granted, free of charge, to any person obtaining
|
|
6
|
+
* a copy of this software and associated documentation files
|
|
7
|
+
* (the "Software"), to deal in the Software without restriction,
|
|
8
|
+
* including without limitation the rights to use, copy, modify, merge,
|
|
9
|
+
* publish, distribute, sublicense, and/or sell copies of the Software,
|
|
10
|
+
* and to permit persons to whom the Software is furnished to do so,
|
|
11
|
+
* subject to the following conditions:
|
|
12
|
+
*
|
|
13
|
+
* The above copyright notice and this permission notice shall be
|
|
14
|
+
* included in all copies or substantial portions of the Software.
|
|
15
|
+
*
|
|
16
|
+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
17
|
+
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
18
|
+
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
19
|
+
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
20
|
+
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
21
|
+
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
22
|
+
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
23
|
+
*/
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
triton
|
triton/language/_utils.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
from typing import List
|
|
2
|
-
|
|
3
|
-
TRITON_MAX_TENSOR_NUMEL = 1048576
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
def is_power_of_two(x):
|
|
7
|
-
return (x & (x - 1)) == 0
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def validate_block_shape(shape: List[int]):
|
|
11
|
-
numel = 1
|
|
12
|
-
for i, d in enumerate(shape):
|
|
13
|
-
if not isinstance(d, int):
|
|
14
|
-
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
|
|
15
|
-
if not is_power_of_two(d):
|
|
16
|
-
raise ValueError(f"Shape element {i} must be a power of 2")
|
|
17
|
-
numel *= d
|
|
18
|
-
|
|
19
|
-
if numel > TRITON_MAX_TENSOR_NUMEL:
|
|
20
|
-
raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
|
|
21
|
-
return numel
|
|
@@ -1,106 +0,0 @@
|
|
|
1
|
-
from typing import Sequence
|
|
2
|
-
|
|
3
|
-
from triton.language import core
|
|
4
|
-
from triton.language import semantic
|
|
5
|
-
from triton._C.libtriton import ir
|
|
6
|
-
|
|
7
|
-
__all__ = [
|
|
8
|
-
"experimental_device_tensormap_create1d",
|
|
9
|
-
"experimental_device_tensormap_create2d",
|
|
10
|
-
"experimental_tensormap_fenceproxy_acquire",
|
|
11
|
-
]
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def _determine_elem_type(element_ty: core.dtype):
|
|
15
|
-
if element_ty.primitive_bitwidth == 8:
|
|
16
|
-
return 0
|
|
17
|
-
elif element_ty.primitive_bitwidth == 16:
|
|
18
|
-
return 1
|
|
19
|
-
elif element_ty.primitive_bitwidth == 32:
|
|
20
|
-
return 2
|
|
21
|
-
else:
|
|
22
|
-
raise ValueError("element_ty must be a primitive of size 1, 2, or 4 bytes but got")
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
@core.builtin
|
|
26
|
-
def experimental_device_tensormap_create1d(
|
|
27
|
-
desc_ptr: core.tensor,
|
|
28
|
-
global_address: core.tensor,
|
|
29
|
-
load_size: core.tensor,
|
|
30
|
-
global_size: core.tensor,
|
|
31
|
-
element_ty: core.dtype,
|
|
32
|
-
_builder: ir.builder = None,
|
|
33
|
-
):
|
|
34
|
-
load_size = core._constexpr_to_value(load_size)
|
|
35
|
-
global_size = semantic.to_tensor(global_size, _builder)
|
|
36
|
-
element_ty = core._constexpr_to_value(element_ty)
|
|
37
|
-
element_stride = [core.full([], 1, core.int32, _builder=_builder)]
|
|
38
|
-
|
|
39
|
-
semantic.tensormap_create(
|
|
40
|
-
desc_ptr=desc_ptr,
|
|
41
|
-
global_address=global_address,
|
|
42
|
-
box_dim=[semantic.to_tensor(load_size, _builder)],
|
|
43
|
-
global_dim=[global_size],
|
|
44
|
-
global_stride=[],
|
|
45
|
-
element_stride=element_stride,
|
|
46
|
-
elem_type=_determine_elem_type(element_ty),
|
|
47
|
-
interleave_layout=0,
|
|
48
|
-
swizzle_mode=0,
|
|
49
|
-
fill_mode=0,
|
|
50
|
-
builder=_builder,
|
|
51
|
-
)
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
@core.builtin
|
|
55
|
-
def experimental_device_tensormap_create2d(
|
|
56
|
-
desc_ptr: core.tensor,
|
|
57
|
-
global_address: core.tensor,
|
|
58
|
-
load_size: Sequence[core.constexpr],
|
|
59
|
-
global_size: Sequence[core.tensor],
|
|
60
|
-
element_ty: core.dtype,
|
|
61
|
-
_builder: ir.builder = None,
|
|
62
|
-
):
|
|
63
|
-
assert len(load_size) == 2
|
|
64
|
-
assert len(global_size) == 2
|
|
65
|
-
load_size = [core._constexpr_to_value(x) for x in load_size]
|
|
66
|
-
global_size = [semantic.to_tensor(x, _builder) for x in global_size]
|
|
67
|
-
|
|
68
|
-
element_size = element_ty.primitive_bitwidth // 8
|
|
69
|
-
element_size_t = core.full([], element_size, core.int64, _builder=_builder)
|
|
70
|
-
global_stride = semantic.mul(element_size_t, global_size[-1], True, _builder)
|
|
71
|
-
|
|
72
|
-
contig_dim_size_in_bytes = element_size * load_size[-1]
|
|
73
|
-
if contig_dim_size_in_bytes > 128:
|
|
74
|
-
load_size[-1] = 128 // element_size
|
|
75
|
-
|
|
76
|
-
elem_stride = core.full([], 1, core.int32, _builder=_builder)
|
|
77
|
-
|
|
78
|
-
semantic.tensormap_create(
|
|
79
|
-
desc_ptr=desc_ptr,
|
|
80
|
-
global_address=global_address,
|
|
81
|
-
box_dim=[semantic.to_tensor(x, _builder) for x in load_size[::-1]],
|
|
82
|
-
global_dim=global_size[::-1],
|
|
83
|
-
global_stride=[global_stride],
|
|
84
|
-
element_stride=[elem_stride, elem_stride],
|
|
85
|
-
elem_type=_determine_elem_type(element_ty),
|
|
86
|
-
interleave_layout=0,
|
|
87
|
-
swizzle_mode=_determine_swizzle_mode_2d(contig_dim_size_in_bytes, load_size),
|
|
88
|
-
fill_mode=0,
|
|
89
|
-
builder=_builder,
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def _determine_swizzle_mode_2d(contig_dim_size_in_bytes, load_size):
|
|
94
|
-
if contig_dim_size_in_bytes >= 128:
|
|
95
|
-
return 3
|
|
96
|
-
elif contig_dim_size_in_bytes >= 64:
|
|
97
|
-
return 2
|
|
98
|
-
elif contig_dim_size_in_bytes >= 32:
|
|
99
|
-
return 1
|
|
100
|
-
else:
|
|
101
|
-
raise ValueError("block size too small")
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
@core.builtin
|
|
105
|
-
def experimental_tensormap_fenceproxy_acquire(desc_ptr: core.tensor, _builder: ir.builder = None):
|
|
106
|
-
semantic.tensormap_fenceproxy_acquire(desc_ptr, _builder)
|
|
@@ -1,32 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
import triton
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class TmaDescKernelParam:
|
|
7
|
-
TMA_DESC_SIZE = 128
|
|
8
|
-
|
|
9
|
-
def __init__(self, ptr, dims, block_dims, element_size):
|
|
10
|
-
self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device="cpu")
|
|
11
|
-
assert len(dims) == len(block_dims)
|
|
12
|
-
assert 1 <= len(dims) <= 2
|
|
13
|
-
assert self.desc.data_ptr() % 64 == 0
|
|
14
|
-
|
|
15
|
-
if len(dims) == 1:
|
|
16
|
-
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size,
|
|
17
|
-
self.desc.data_ptr())
|
|
18
|
-
else:
|
|
19
|
-
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0],
|
|
20
|
-
block_dims[1], element_size, self.desc.data_ptr())
|
|
21
|
-
|
|
22
|
-
# Return a CUtensorMap* pointer in host memory
|
|
23
|
-
def tma_desc_cpu_ptr(self):
|
|
24
|
-
return self.desc.data_ptr()
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def create_1d_tma_descriptor(ptr, dim, block_dim, element_size):
|
|
28
|
-
return TmaDescKernelParam(ptr, [dim], [block_dim], element_size)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size):
|
|
32
|
-
return TmaDescKernelParam(ptr, [dim1, dim0], [block_dim1, block_dim0], element_size)
|
|
@@ -1,14 +0,0 @@
|
|
|
1
|
-
triton
|
|
2
|
-
triton/_C
|
|
3
|
-
triton/backends
|
|
4
|
-
triton/backends/amd
|
|
5
|
-
triton/backends/nvidia
|
|
6
|
-
triton/compiler
|
|
7
|
-
triton/language
|
|
8
|
-
triton/language/extra
|
|
9
|
-
triton/language/extra\cuda
|
|
10
|
-
triton/language/extra\hip
|
|
11
|
-
triton/runtime
|
|
12
|
-
triton/tools
|
|
13
|
-
triton/tools/extra
|
|
14
|
-
triton/tools/extra\cuda
|
|
File without changes
|