triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.4.0.post20__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
triton/compiler/compiler.py
CHANGED
@@ -3,19 +3,19 @@ import hashlib
 import json
 from .._C.libtriton import get_cache_invalidating_env_vars, ir
 from ..backends import backends
-from ..backends.compiler import
-from .. import
+from ..backends.compiler import Language
+from ..backends.compiler import BaseBackend, GPUTarget
+from .. import __version__, knobs
 from ..runtime.autotuner import OutOfResources
 from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
 from ..runtime.driver import driver
 from ..tools.disasm import get_sass
-# TODO: this shouldn't be here
-from .code_generator import ast_to_ttir
 from pathlib import Path
 import re
 import functools
 import os
 import sysconfig
+import time

 # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
 # and any following whitespace
@@ -53,6 +53,7 @@ class ASTSource:

     def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
         self.fn = fn
+        self.language = Language.TRITON
         self.ext = "ttir"
         self.name = fn.__name__
         self.signature = signature
@@ -78,6 +79,7 @@ class ASTSource:
         return hashlib.sha256(key.encode("utf-8")).hexdigest()

     def make_ir(self, options, codegen_fns, module_map, context):
+        from .code_generator import ast_to_ttir
         return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
                            module_map=module_map)

@@ -91,6 +93,7 @@ class IRSource:
         self.path = path
         path = Path(path)
         self.ext = path.suffix[1:]
+        self.language = Language.TRITON
         self.src = path.read_text()
         ir.load_dialects(context)
         backend.load_dialects(context)
@@ -162,6 +165,11 @@ def triton_key():
     return f'{__version__}' + '-'.join(contents)


+@functools.lru_cache()
+def max_shared_mem(device):
+    return driver.active.utils.get_device_properties(device)["max_shared_mem"]
+
+
 def parse(full_name, ext, context):
     if ext == "ttir" or ext == "ttgir":
         module = ir.parse_mlir_module(full_name, context)
@@ -179,7 +187,7 @@ def filter_traceback(e: BaseException):

     These are uninteresting to the user -- "just show me *my* code!"
     """
-    if
+    if knobs.compilation.front_end_debugging:
         return

     if e.__cause__ is not None:
@@ -211,7 +219,50 @@ def filter_traceback(e: BaseException):
     e.__traceback__ = frames[0]


+class CompileTimer:
+
+    def __init__(self) -> None:
+        self.start: float = time.perf_counter()
+        self.ir_initialization_end: float | None = None
+        self.lowering_stage_ends: list[tuple[str, float]] = []
+        self.store_results_end: float | None = None
+
+    def finished_ir_initialization(self) -> None:
+        self.ir_initialization_end = time.perf_counter()
+
+    def stage_finished(self, stage_name: str) -> None:
+        self.lowering_stage_ends.append((stage_name, time.perf_counter()))
+
+    def end(self) -> knobs.CompileTimes:
+        timestamp = time.perf_counter()
+        if self.ir_initialization_end is None:
+            self.ir_initialization_end = timestamp
+        else:
+            self.store_results_end = timestamp
+
+        def delta(start: float, end: float | None) -> int:
+            if end is None:
+                return 0
+            return int((end - start) * 1000000)
+
+        lowering_stage_durations = []
+        stage_start = self.ir_initialization_end
+        for stage_name, stage_end in self.lowering_stage_ends:
+            lowering_stage_durations.append((stage_name, delta(stage_start, stage_end)))
+            stage_start = stage_end
+
+        return knobs.CompileTimes(
+            ir_initialization=delta(self.start, self.ir_initialization_end),
+            lowering_stages=lowering_stage_durations,
+            store_results=delta(stage_start, self.store_results_end),
+        )
+
+
 def compile(src, target=None, options=None):
+    compilation_listener = knobs.compilation.listener
+    if compilation_listener:
+        timer = CompileTimer()
+
     if target is None:
         target = driver.active.get_current_target()
     assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
@@ -232,9 +283,9 @@ def compile(src, target=None, options=None):
     fn_cache_manager = get_cache_manager(hash)
     # For dumping/overriding only hash the source as we want it to be independent of triton
     # core changes to make it easier to track kernels by hash.
-    enable_override =
-    enable_ir_dump =
-    store_only_binary =
+    enable_override = knobs.compilation.override
+    enable_ir_dump = knobs.compilation.dump_ir
+    store_only_binary = knobs.compilation.store_binary_only
     fn_override_manager = get_override_manager(src.hash()) if enable_override else None
     fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
     # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
@@ -245,10 +296,20 @@ def compile(src, target=None, options=None):
     metadata_filename = f"{file_name}.json"
     metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
     metadata_path = metadata_group.get(metadata_filename)
-    always_compile =
+    always_compile = knobs.compilation.always_compile
     if not always_compile and metadata_path is not None:
         # cache hit!
-
+        res = CompiledKernel(src, metadata_group, hash)
+        if compilation_listener:
+            compilation_listener(
+                src=src,
+                metadata=res.metadata._asdict(),
+                metadata_group=metadata_group,
+                times=timer.end(),
+                cache_hit=True,
+            )
+        return res
+
     # initialize metadata
     metadata = {
         "hash": hash,
@@ -259,7 +320,7 @@ def compile(src, target=None, options=None):
     metadata["triton_version"] = __version__
     # run compilation pipeline and populate metadata
     stages = dict()
-    backend.add_stages(stages, options)
+    backend.add_stages(stages, options, src.language)
     first_stage = list(stages.keys()).index(src.ext)
     # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
     if ir_source:
@@ -279,11 +340,30 @@ def compile(src, target=None, options=None):
     except Exception as e:
         filter_traceback(e)
         raise
-
+
+    if ir_source:
+        ir_filename = f"{file_name}.{src.ext}"
+        metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
+    else:
+        ir_filename = f"{file_name}.source"
+        metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
+
+    use_ir_loc = knobs.compilation.use_ir_loc
+    if ir_source and use_ir_loc:
+        module.create_location_snapshot(src.path)
+        print(f"Creating new locations for {src.path}")
+
+    if compilation_listener:
+        timer.finished_ir_initialization()
     for ext, compile_ir in list(stages.items())[first_stage:]:
         next_module = compile_ir(module, metadata)
         ir_filename = f"{file_name}.{ext}"
-        if
+        if fn_override_manager is None:
+            # Users can override kernels at scale by setting `ir_override` in autotune config
+            # without TRITON_KERNEL_OVERRIDE
+            if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"):
+                next_module = parse(ir_override, ext, context)
+        elif full_name := fn_override_manager.get_file(ir_filename):
             print(f"\nOverriding kernel with file {full_name}")
             next_module = parse(full_name, ext, context)
         # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
@@ -297,6 +377,8 @@ def compile(src, target=None, options=None):
             next_module.create_location_snapshot(ir_full_name)
             print(f"Creating new locations for {ir_full_name}")
         module = next_module
+        if compilation_listener:
+            timer.stage_finished(ext)
     # write-back metadata
     metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
                                                              binary=False)
@@ -310,13 +392,18 @@ def compile(src, target=None, options=None):
     # this is likely due to the llvm-symbolizer forking a process
     # TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling
     # multithreading in the MLIR context
-    if not
+    if not knobs.compilation.enable_asan:
         context.disable_multithreading()
+
+    # notify any listener
+    if compilation_listener:
+        compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(),
+                             cache_hit=False)
     # return handle to compiled kernel
     return CompiledKernel(src, metadata_group, hash)


-def make_backend(target):
+def make_backend(target: GPUTarget) -> BaseBackend:
     actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
     if len(actives) != 1:
         raise RuntimeError(
@@ -330,7 +417,7 @@ class LazyDict:
         self.data = data
         self.extras = []

-    def get(self)
+    def get(self):
         for func, args in self.extras:
             self.data = self.data | func(*args)
         self.extras.clear()
@@ -355,11 +442,6 @@ class AsmDict(dict):

 class CompiledKernel:

-    # Hooks for external tools to monitor the execution of triton kernels
-    # TODO: move out of this namespace since it's a runtime thing
-    launch_enter_hook = None
-    launch_exit_hook = None
-
     def __init__(self, src, metadata_group, hash):
         from collections import namedtuple
         metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
@@ -396,7 +478,7 @@ class CompiledKernel:
         # create launcher
         self.run = driver.active.launcher_cls(self.src, self.metadata)
         # not enough shared memory to run the kernel
-        max_shared =
+        max_shared = max_shared_mem(device)
         if self.metadata.shared > max_shared:
             raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
         if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
@@ -405,8 +487,11 @@ class CompiledKernel:
             if self.metadata.tmem_size > max_tmem_size:
                 raise OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory")
         # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
-        self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
+        self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
             self.name, self.kernel, self.metadata.shared, device)
+        warp_size = driver.active.get_current_target().warp_size
+        if self.metadata.num_warps * warp_size > self.n_max_threads:
+            raise OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads")

     def __getattribute__(self, name):
         if name == 'run':
@@ -414,7 +499,7 @@ class CompiledKernel:
         return super().__getattribute__(name)

     def launch_metadata(self, grid, stream, *args):
-        if
+        if knobs.runtime.launch_enter_hook is None:
            return None
        ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
        if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
@@ -436,6 +521,6 @@ class CompiledKernel:
                stream = driver.active.get_current_stream(device)
            launch_metadata = self.launch_metadata(grid, stream, *args)
            self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
-
+                     knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)

        return runner
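Example (not part of the diff): the compile() changes above route configuration through the new triton.knobs module and add an optional compilation listener that receives per-stage timings (a knobs.CompileTimes built by CompileTimer). Below is a minimal, hypothetical sketch of how a tool might hook these knobs. The attribute names are taken from the reads shown in this diff (knobs.compilation.listener, knobs.runtime.launch_enter_hook / launch_exit_hook); the CompileTimes field names and the single-argument launch-hook signature are assumptions not verified against triton/knobs.py.

# Hypothetical sketch -- assumes the knob attributes read in compile() above
# are also writable from user code.
from triton import knobs

def on_compile(src, metadata, metadata_group, times, cache_hit):
    # `times` mirrors the knobs.CompileTimes built by CompileTimer.end();
    # per the delta() helper above, durations are integers in microseconds.
    lowering_us = sum(duration for _, duration in times.lowering_stages)
    total_us = times.ir_initialization + lowering_us + times.store_results
    print(f"compiled {metadata.get('hash', '?')}: cache_hit={cache_hit}, ~{total_us} us")

knobs.compilation.listener = on_compile

# Launch hooks now live on knobs.runtime instead of CompiledKernel class
# attributes; the single launch-metadata argument is an assumption.
knobs.runtime.launch_enter_hook = lambda launch_md: print("kernel enter", launch_md)
knobs.runtime.launch_exit_hook = lambda launch_md: print("kernel exit", launch_md)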
File without changes
File without changes
triton/experimental/gluon/_runtime.py
ADDED
@@ -0,0 +1,99 @@
+from __future__ import annotations
+import triton
+from triton.compiler.compiler import ASTSource
+from triton.backends.compiler import Language
+from triton.runtime.jit import JITFunction
+from typing import TypeVar, Optional, Callable, Iterable, Union
+from triton._C.libtriton import ir
+
+T = TypeVar("T")
+
+
+class GluonASTSource(ASTSource):
+
+    def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
+        super().__init__(fn, signature, constexprs, attrs)
+        self.language = Language.GLUON
+        self.ext = "ttgir"
+
+    def make_ir(self, options, codegen_fns, module_map, context):
+        from triton.compiler.compiler import make_backend
+        from triton.compiler.code_generator import ast_to_ttir
+
+        builder = ir.builder(context)
+        module = builder.create_module()
+
+        # Assign module attributes eagerly, as they are needed to verify layouts
+        target = triton.runtime.driver.active.get_current_target()
+        backend = make_backend(target)
+        target = backend.get_target_name(options)
+        module.set_attr("ttg.target", builder.get_string_attr(target))
+        module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
+        module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
+        module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(32))
+        if options.maxnreg is not None:
+            module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
+
+        module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
+                             module_map=module_map, module=module)
+        return module
+
+
+class GluonJITFunction(JITFunction[T]):
+
+    def create_binder(self):
+        result = super().create_binder()
+        self.ASTSource = GluonASTSource
+        return result
+
+    def is_gluon(self):
+        return True
+
+
+def jit(
+    fn: Optional[T] = None,
+    *,
+    version=None,
+    repr: Optional[Callable] = None,
+    launch_metadata: Optional[Callable] = None,
+    do_not_specialize: Optional[Iterable[int | str]] = None,
+    do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
+    debug: Optional[bool] = None,
+    noinline: Optional[bool] = None,
+) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]:
+    """
+    Decorator for JIT-compiling a function using the Triton compiler.
+
+    :note: When a jit'd function is called, arguments are
+        implicitly converted to pointers if they have a :code:`.data_ptr()` method
+        and a `.dtype` attribute.
+
+    :note: This function will be compiled and run on the GPU. It will only have access to:
+
+        * python primitives,
+        * builtins within the triton package,
+        * arguments to this function,
+        * other jit'd functions
+
+    :param fn: the function to be jit-compiled
+    :type fn: Callable
+    """
+
+    def decorator(fn: T) -> JITFunction[T]:
+        assert callable(fn)
+        return GluonJITFunction(
+            fn,
+            version=version,
+            do_not_specialize=do_not_specialize,
+            do_not_specialize_on_alignment=do_not_specialize_on_alignment,
+            debug=debug,
+            noinline=noinline,
+            repr=repr,
+            launch_metadata=launch_metadata,
+        )
+
+    if fn is not None:
+        return decorator(fn)
+
+    else:
+        return decorator
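Example (not part of the diff): a hypothetical usage sketch of the Gluon front end defined in _runtime.py above. The import path shown is the module added by this release; whether triton.experimental.gluon also re-exports jit (its __init__.py adds only 4 lines) is not verified here, and the kernel body is left empty because real Gluon kernels would use the gluon language ops added elsewhere in this release.

# Hypothetical sketch -- not part of this diff.
from triton.experimental.gluon._runtime import jit as gluon_jit

@gluon_jit
def noop_kernel():
    pass  # a real kernel would use triton.experimental.gluon.language ops

# The decorator wraps the function in a GluonJITFunction, whose binder swaps in
# GluonASTSource: compilation starts at the "ttgir" stage with Language.GLUON
# instead of the "ttir" stage used by regular @triton.jit kernels.
assert noop_kernel.is_gluon()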
triton/experimental/gluon/language/__init__.py
ADDED
@@ -0,0 +1,18 @@
+from ._core import * # NOQA: F403
+from ._core import __all__ as __core_all
+from ._layouts import * # NOQA: F403
+from ._layouts import __all__ as __layouts_all
+from ._math import * # NOQA: F403
+from ._math import __all__ as __math_all
+from ._standard import * # NOQA: F403
+from ._standard import __all__ as __standard_all
+
+from . import nvidia
+
+__all__ = [
+    *__core_all,
+    *__layouts_all,
+    *__math_all,
+    *__standard_all,
+    "nvidia",
+]