triton-windows 3.3.1.post21__cp313-cp313-win_amd64.whl → 3.4.0.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows has been flagged as potentially problematic; consult the package's registry page for more details.

Files changed (68):
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +143 -46
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +94 -94
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +296 -125
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +73 -9
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +47 -83
  59. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/METADATA +7 -2
  60. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/RECORD +64 -41
  61. triton_windows-3.4.0.post21.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post21.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post21.dist-info/top_level.txt +1 -0
  64. triton/language/_utils.py +0 -21
  65. triton/language/extra/cuda/_experimental_tma.py +0 -106
  66. triton/tools/experimental_descriptor.py +0 -32
  67. triton_windows-3.3.1.post21.dist-info/top_level.txt +0 -14
  68. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/WHEEL +0 -0
@@ -3,19 +3,19 @@ import hashlib
3
3
  import json
4
4
  from .._C.libtriton import get_cache_invalidating_env_vars, ir
5
5
  from ..backends import backends
6
- from ..backends.compiler import GPUTarget
7
- from .. import __version__
6
+ from ..backends.compiler import Language
7
+ from ..backends.compiler import BaseBackend, GPUTarget
8
+ from .. import __version__, knobs
8
9
  from ..runtime.autotuner import OutOfResources
9
10
  from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
10
11
  from ..runtime.driver import driver
11
12
  from ..tools.disasm import get_sass
12
- # TODO: this shouldn't be here
13
- from .code_generator import ast_to_ttir
14
13
  from pathlib import Path
15
14
  import re
16
15
  import functools
17
16
  import os
18
17
  import sysconfig
18
+ import time
19
19
 
20
20
  # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
21
21
  # and any following whitespace
@@ -53,6 +53,7 @@ class ASTSource:
53
53
 
54
54
  def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
55
55
  self.fn = fn
56
+ self.language = Language.TRITON
56
57
  self.ext = "ttir"
57
58
  self.name = fn.__name__
58
59
  self.signature = signature
@@ -78,6 +79,7 @@ class ASTSource:
78
79
  return hashlib.sha256(key.encode("utf-8")).hexdigest()
79
80
 
80
81
  def make_ir(self, options, codegen_fns, module_map, context):
82
+ from .code_generator import ast_to_ttir
81
83
  return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
82
84
  module_map=module_map)
83
85
 
@@ -91,6 +93,7 @@ class IRSource:
91
93
  self.path = path
92
94
  path = Path(path)
93
95
  self.ext = path.suffix[1:]
96
+ self.language = Language.TRITON
94
97
  self.src = path.read_text()
95
98
  ir.load_dialects(context)
96
99
  backend.load_dialects(context)
@@ -162,6 +165,11 @@ def triton_key():
162
165
  return f'{__version__}' + '-'.join(contents)
163
166
 
164
167
 
168
+ @functools.lru_cache()
169
+ def max_shared_mem(device):
170
+ return driver.active.utils.get_device_properties(device)["max_shared_mem"]
171
+
172
+
165
173
  def parse(full_name, ext, context):
166
174
  if ext == "ttir" or ext == "ttgir":
167
175
  module = ir.parse_mlir_module(full_name, context)
@@ -179,7 +187,7 @@ def filter_traceback(e: BaseException):
179
187
 
180
188
  These are uninteresting to the user -- "just show me *my* code!"
181
189
  """
182
- if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1":
190
+ if knobs.compilation.front_end_debugging:
183
191
  return
184
192
 
185
193
  if e.__cause__ is not None:
@@ -211,7 +219,50 @@ def filter_traceback(e: BaseException):
211
219
  e.__traceback__ = frames[0]
212
220
 
213
221
 
222
+ class CompileTimer:
223
+
224
+ def __init__(self) -> None:
225
+ self.start: float = time.perf_counter()
226
+ self.ir_initialization_end: float | None = None
227
+ self.lowering_stage_ends: list[tuple[str, float]] = []
228
+ self.store_results_end: float | None = None
229
+
230
+ def finished_ir_initialization(self) -> None:
231
+ self.ir_initialization_end = time.perf_counter()
232
+
233
+ def stage_finished(self, stage_name: str) -> None:
234
+ self.lowering_stage_ends.append((stage_name, time.perf_counter()))
235
+
236
+ def end(self) -> knobs.CompileTimes:
237
+ timestamp = time.perf_counter()
238
+ if self.ir_initialization_end is None:
239
+ self.ir_initialization_end = timestamp
240
+ else:
241
+ self.store_results_end = timestamp
242
+
243
+ def delta(start: float, end: float | None) -> int:
244
+ if end is None:
245
+ return 0
246
+ return int((end - start) * 1000000)
247
+
248
+ lowering_stage_durations = []
249
+ stage_start = self.ir_initialization_end
250
+ for stage_name, stage_end in self.lowering_stage_ends:
251
+ lowering_stage_durations.append((stage_name, delta(stage_start, stage_end)))
252
+ stage_start = stage_end
253
+
254
+ return knobs.CompileTimes(
255
+ ir_initialization=delta(self.start, self.ir_initialization_end),
256
+ lowering_stages=lowering_stage_durations,
257
+ store_results=delta(stage_start, self.store_results_end),
258
+ )
259
+
260
+
214
261
  def compile(src, target=None, options=None):
262
+ compilation_listener = knobs.compilation.listener
263
+ if compilation_listener:
264
+ timer = CompileTimer()
265
+
215
266
  if target is None:
216
267
  target = driver.active.get_current_target()
217
268
  assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
@@ -232,9 +283,9 @@ def compile(src, target=None, options=None):
232
283
  fn_cache_manager = get_cache_manager(hash)
233
284
  # For dumping/overriding only hash the source as we want it to be independent of triton
234
285
  # core changes to make it easier to track kernels by hash.
235
- enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
236
- enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1"
237
- store_only_binary = os.environ.get("TRITON_STORE_BINARY_ONLY", "0") == "1"
286
+ enable_override = knobs.compilation.override
287
+ enable_ir_dump = knobs.compilation.dump_ir
288
+ store_only_binary = knobs.compilation.store_binary_only
238
289
  fn_override_manager = get_override_manager(src.hash()) if enable_override else None
239
290
  fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
240
291
  # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
@@ -245,10 +296,20 @@ def compile(src, target=None, options=None):
245
296
  metadata_filename = f"{file_name}.json"
246
297
  metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
247
298
  metadata_path = metadata_group.get(metadata_filename)
248
- always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1"
299
+ always_compile = knobs.compilation.always_compile
249
300
  if not always_compile and metadata_path is not None:
250
301
  # cache hit!
251
- return CompiledKernel(src, metadata_group, hash)
302
+ res = CompiledKernel(src, metadata_group, hash)
303
+ if compilation_listener:
304
+ compilation_listener(
305
+ src=src,
306
+ metadata=res.metadata._asdict(),
307
+ metadata_group=metadata_group,
308
+ times=timer.end(),
309
+ cache_hit=True,
310
+ )
311
+ return res
312
+
252
313
  # initialize metadata
253
314
  metadata = {
254
315
  "hash": hash,
@@ -259,7 +320,7 @@ def compile(src, target=None, options=None):
259
320
  metadata["triton_version"] = __version__
260
321
  # run compilation pipeline and populate metadata
261
322
  stages = dict()
262
- backend.add_stages(stages, options)
323
+ backend.add_stages(stages, options, src.language)
263
324
  first_stage = list(stages.keys()).index(src.ext)
264
325
  # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
265
326
  if ir_source:
@@ -279,11 +340,30 @@ def compile(src, target=None, options=None):
279
340
  except Exception as e:
280
341
  filter_traceback(e)
281
342
  raise
282
- use_ir_loc = os.environ.get("USE_IR_LOC", None)
343
+
344
+ if ir_source:
345
+ ir_filename = f"{file_name}.{src.ext}"
346
+ metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
347
+ else:
348
+ ir_filename = f"{file_name}.source"
349
+ metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
350
+
351
+ use_ir_loc = knobs.compilation.use_ir_loc
352
+ if ir_source and use_ir_loc:
353
+ module.create_location_snapshot(src.path)
354
+ print(f"Creating new locations for {src.path}")
355
+
356
+ if compilation_listener:
357
+ timer.finished_ir_initialization()
283
358
  for ext, compile_ir in list(stages.items())[first_stage:]:
284
359
  next_module = compile_ir(module, metadata)
285
360
  ir_filename = f"{file_name}.{ext}"
286
- if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
361
+ if fn_override_manager is None:
362
+ # Users can override kernels at scale by setting `ir_override` in autotune config
363
+ # without TRITON_KERNEL_OVERRIDE
364
+ if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"):
365
+ next_module = parse(ir_override, ext, context)
366
+ elif full_name := fn_override_manager.get_file(ir_filename):
287
367
  print(f"\nOverriding kernel with file {full_name}")
288
368
  next_module = parse(full_name, ext, context)
289
369
  # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
@@ -297,6 +377,8 @@ def compile(src, target=None, options=None):
297
377
  next_module.create_location_snapshot(ir_full_name)
298
378
  print(f"Creating new locations for {ir_full_name}")
299
379
  module = next_module
380
+ if compilation_listener:
381
+ timer.stage_finished(ext)
300
382
  # write-back metadata
301
383
  metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
302
384
  binary=False)
@@ -310,13 +392,18 @@ def compile(src, target=None, options=None):
310
392
  # this is likely due to the llvm-symbolizer forking a process
311
393
  # TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling
312
394
  # multithreading in the MLIR context
313
- if not os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
395
+ if not knobs.compilation.enable_asan:
314
396
  context.disable_multithreading()
397
+
398
+ # notify any listener
399
+ if compilation_listener:
400
+ compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(),
401
+ cache_hit=False)
315
402
  # return handle to compiled kernel
316
403
  return CompiledKernel(src, metadata_group, hash)
317
404
 
318
405
 
319
- def make_backend(target):
406
+ def make_backend(target: GPUTarget) -> BaseBackend:
320
407
  actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
321
408
  if len(actives) != 1:
322
409
  raise RuntimeError(
@@ -330,7 +417,7 @@ class LazyDict:
330
417
  self.data = data
331
418
  self.extras = []
332
419
 
333
- def get(self) -> None:
420
+ def get(self):
334
421
  for func, args in self.extras:
335
422
  self.data = self.data | func(*args)
336
423
  self.extras.clear()
@@ -355,11 +442,6 @@ class AsmDict(dict):
355
442
 
356
443
  class CompiledKernel:
357
444
 
358
- # Hooks for external tools to monitor the execution of triton kernels
359
- # TODO: move out of this namespace since it's a runtime thing
360
- launch_enter_hook = None
361
- launch_exit_hook = None
362
-
363
445
  def __init__(self, src, metadata_group, hash):
364
446
  from collections import namedtuple
365
447
  metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
@@ -396,7 +478,7 @@ class CompiledKernel:
396
478
  # create launcher
397
479
  self.run = driver.active.launcher_cls(self.src, self.metadata)
398
480
  # not enough shared memory to run the kernel
399
- max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
481
+ max_shared = max_shared_mem(device)
400
482
  if self.metadata.shared > max_shared:
401
483
  raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
402
484
  if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
@@ -405,8 +487,11 @@ class CompiledKernel:
405
487
  if self.metadata.tmem_size > max_tmem_size:
406
488
  raise OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory")
407
489
  # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
408
- self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
490
+ self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
409
491
  self.name, self.kernel, self.metadata.shared, device)
492
+ warp_size = driver.active.get_current_target().warp_size
493
+ if self.metadata.num_warps * warp_size > self.n_max_threads:
494
+ raise OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads")
410
495
 
411
496
  def __getattribute__(self, name):
412
497
  if name == 'run':
@@ -414,7 +499,7 @@ class CompiledKernel:
414
499
  return super().__getattribute__(name)
415
500
 
416
501
  def launch_metadata(self, grid, stream, *args):
417
- if CompiledKernel.launch_enter_hook is None:
502
+ if knobs.runtime.launch_enter_hook is None:
418
503
  return None
419
504
  ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
420
505
  if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
@@ -436,6 +521,6 @@ class CompiledKernel:
436
521
  stream = driver.active.get_current_stream(device)
437
522
  launch_metadata = self.launch_metadata(grid, stream, *args)
438
523
  self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
439
- CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args)
524
+ knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)
440
525
 
441
526
  return runner
File without changes
@@ -0,0 +1,4 @@
1
+ from . import nvidia
2
+ from ._runtime import jit
3
+
4
+ __all__ = ["jit", "nvidia"]
File without changes
@@ -0,0 +1,99 @@
1
+ from __future__ import annotations
2
+ import triton
3
+ from triton.compiler.compiler import ASTSource
4
+ from triton.backends.compiler import Language
5
+ from triton.runtime.jit import JITFunction
6
+ from typing import TypeVar, Optional, Callable, Iterable, Union
7
+ from triton._C.libtriton import ir
8
+
9
+ T = TypeVar("T")
10
+
11
+
12
+ class GluonASTSource(ASTSource):
13
+
14
+ def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
15
+ super().__init__(fn, signature, constexprs, attrs)
16
+ self.language = Language.GLUON
17
+ self.ext = "ttgir"
18
+
19
+ def make_ir(self, options, codegen_fns, module_map, context):
20
+ from triton.compiler.compiler import make_backend
21
+ from triton.compiler.code_generator import ast_to_ttir
22
+
23
+ builder = ir.builder(context)
24
+ module = builder.create_module()
25
+
26
+ # Assign module attributes eagerly, as they are needed to verify layouts
27
+ target = triton.runtime.driver.active.get_current_target()
28
+ backend = make_backend(target)
29
+ target = backend.get_target_name(options)
30
+ module.set_attr("ttg.target", builder.get_string_attr(target))
31
+ module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
32
+ module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
33
+ module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(32))
34
+ if options.maxnreg is not None:
35
+ module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
36
+
37
+ module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
38
+ module_map=module_map, module=module)
39
+ return module
40
+
41
+
42
+ class GluonJITFunction(JITFunction[T]):
43
+
44
+ def create_binder(self):
45
+ result = super().create_binder()
46
+ self.ASTSource = GluonASTSource
47
+ return result
48
+
49
+ def is_gluon(self):
50
+ return True
51
+
52
+
53
+ def jit(
54
+ fn: Optional[T] = None,
55
+ *,
56
+ version=None,
57
+ repr: Optional[Callable] = None,
58
+ launch_metadata: Optional[Callable] = None,
59
+ do_not_specialize: Optional[Iterable[int | str]] = None,
60
+ do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
61
+ debug: Optional[bool] = None,
62
+ noinline: Optional[bool] = None,
63
+ ) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]:
64
+ """
65
+ Decorator for JIT-compiling a function using the Triton compiler.
66
+
67
+ :note: When a jit'd function is called, arguments are
68
+ implicitly converted to pointers if they have a :code:`.data_ptr()` method
69
+ and a `.dtype` attribute.
70
+
71
+ :note: This function will be compiled and run on the GPU. It will only have access to:
72
+
73
+ * python primitives,
74
+ * builtins within the triton package,
75
+ * arguments to this function,
76
+ * other jit'd functions
77
+
78
+ :param fn: the function to be jit-compiled
79
+ :type fn: Callable
80
+ """
81
+
82
+ def decorator(fn: T) -> JITFunction[T]:
83
+ assert callable(fn)
84
+ return GluonJITFunction(
85
+ fn,
86
+ version=version,
87
+ do_not_specialize=do_not_specialize,
88
+ do_not_specialize_on_alignment=do_not_specialize_on_alignment,
89
+ debug=debug,
90
+ noinline=noinline,
91
+ repr=repr,
92
+ launch_metadata=launch_metadata,
93
+ )
94
+
95
+ if fn is not None:
96
+ return decorator(fn)
97
+
98
+ else:
99
+ return decorator
@@ -0,0 +1,18 @@
1
+ from ._core import * # NOQA: F403
2
+ from ._core import __all__ as __core_all
3
+ from ._layouts import * # NOQA: F403
4
+ from ._layouts import __all__ as __layouts_all
5
+ from ._math import * # NOQA: F403
6
+ from ._math import __all__ as __math_all
7
+ from ._standard import * # NOQA: F403
8
+ from ._standard import __all__ as __standard_all
9
+
10
+ from . import nvidia
11
+
12
+ __all__ = [
13
+ *__core_all,
14
+ *__layouts_all,
15
+ *__math_all,
16
+ *__standard_all,
17
+ "nvidia",
18
+ ]