triton-windows 3.4.0.post20-cp312-cp312-win_amd64.whl → 3.5.0.post21-cp312-cp312-win_amd64.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of triton-windows is flagged in the registry.

Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/compiler/compiler.py

@@ -7,15 +7,15 @@ from ..backends.compiler import Language
 from ..backends.compiler import BaseBackend, GPUTarget
 from .. import __version__, knobs
 from ..runtime.autotuner import OutOfResources
-from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
+from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager, get_cache_key
 from ..runtime.driver import driver
 from ..tools.disasm import get_sass
 from pathlib import Path
 import re
 import functools
 import os
-import sysconfig
 import time
+import copy

 # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
 # and any following whitespace
@@ -64,12 +64,9 @@ class ASTSource:
             assert isinstance(k, tuple)
             self.constants[k] = v
         self.attrs = attrs or dict()
-        if isinstance(self.signature, str):
-            self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))}
-        else:
-            for k in self.signature.keys():
-                if not isinstance(k, str):
-                    raise TypeError("Signature keys must be string")
+        for k in self.signature.keys():
+            if not isinstance(k, str):
+                raise TypeError("Signature keys must be string")

     def hash(self):
         sorted_sig = [v for k, v in sorted(self.signature.items())]
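
Note: the dropped branch above removes the comma-separated string shorthand for `signature`; 3.5 accepts only a dict with string keys. A minimal migration sketch (the kernel is illustrative, and the `constexprs` keyword is assumed from the surrounding constructor code, so exact argument names may differ):

    import triton
    import triton.language as tl
    from triton.compiler.compiler import ASTSource

    @triton.jit
    def add_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

    # 3.4 also accepted signature="*fp32, *fp32, i32" and numbered the keys
    # itself; 3.5 raises TypeError unless the keys are strings.
    src = ASTSource(
        fn=add_kernel,
        signature={"x_ptr": "*fp32", "y_ptr": "*fp32", "n": "i32"},
        constexprs={"BLOCK": 1024},
    )
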
@@ -78,7 +75,7 @@ class ASTSource:
         key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
         return hashlib.sha256(key.encode("utf-8")).hexdigest()

-    def make_ir(self, options, codegen_fns, module_map, context):
+    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
         from .code_generator import ast_to_ttir
         return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
                            module_map=module_map)
@@ -117,7 +114,7 @@ class IRSource:
     def hash(self):
         return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

-    def make_ir(self, options, codegen_fns, module_map, context):
+    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
         self.module.context = context
         return self.module

@@ -129,42 +126,6 @@ class IRSource:
         return dict()


-@functools.lru_cache()
-def triton_key():
-    import pkgutil
-    TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-    contents = []
-    # frontend
-    with open(__file__, "rb") as f:
-        contents += [hashlib.sha256(f.read()).hexdigest()]
-    # compiler
-    path_prefixes = [
-        (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
-        (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
-    ]
-    for path, prefix in path_prefixes:
-        for lib in pkgutil.walk_packages([path], prefix=prefix):
-            with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
-                contents += [hashlib.sha256(f.read()).hexdigest()]
-
-    # backend
-    libtriton_hash = hashlib.sha256()
-    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
-    with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
-        while True:
-            chunk = f.read(1024**2)
-            if not chunk:
-                break
-            libtriton_hash.update(chunk)
-    contents.append(libtriton_hash.hexdigest())
-    # language
-    language_path = os.path.join(TRITON_PATH, 'language')
-    for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
-        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
-            contents += [hashlib.sha256(f.read()).hexdigest()]
-    return f'{__version__}' + '-'.join(contents)
-
-
 @functools.lru_cache()
 def max_shared_mem(device):
     return driver.active.utils.get_device_properties(device)["max_shared_mem"]
@@ -258,7 +219,7 @@ class CompileTimer:
     )


-def compile(src, target=None, options=None):
+def compile(src, target=None, options=None, _env_vars=None):
    compilation_listener = knobs.compilation.listener
    if compilation_listener:
        timer = CompileTimer()
@@ -277,8 +238,8 @@ def compile(src, target=None, options=None):
     extra_options = src.parse_options()
     options = backend.parse_options(dict(options or dict(), **extra_options))
     # create cache manager
-    env_vars = get_cache_invalidating_env_vars()
-    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}"
+    env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars
+    key = get_cache_key(src, backend, options, env_vars=env_vars)
     hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
     fn_cache_manager = get_cache_manager(hash)
     # For dumping/overriding only hash the source as we want it to be independent of triton
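
Note: two things change here. compile() grows a private `_env_vars` parameter so a caller that has already snapshotted the cache-relevant environment (plausibly the new triton/runtime/_async_compile.py) can pass it in, and the cache-key assembly moves behind `get_cache_key` in triton/runtime/cache.py, which also absorbs the `triton_key()` deleted above. A hedged sketch of computing the same key outside compile(), assuming `get_cache_key` folds together the same four hashes plus the env snapshot the removed f-string did:

    import hashlib
    from triton.runtime.cache import get_cache_key, get_cache_manager

    def cache_manager_for(src, backend, options, env_vars):
        # src, backend, options are the objects compile() builds earlier;
        # env_vars is the dict from get_cache_invalidating_env_vars().
        key = get_cache_key(src, backend, options, env_vars=env_vars)
        return get_cache_manager(hashlib.sha256(key.encode("utf-8")).hexdigest())
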
@@ -336,7 +297,7 @@ def compile(src, target=None, options=None):
     codegen_fns = backend.get_codegen_implementation(options)
     module_map = backend.get_module_map()
     try:
-        module = src.make_ir(options, codegen_fns, module_map, context)
+        module = src.make_ir(target, options, codegen_fns, module_map, context)
     except Exception as e:
         filter_traceback(e)
         raise
@@ -371,6 +332,9 @@
         metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
         if fn_dump_manager is not None:
             fn_dump_manager.put(next_module, ir_filename)
+            if ext == "cubin":
+                sass = get_sass(next_module)
+                fn_dump_manager.put(sass, file_name + ".sass")
         # use an env variable to parse ir from file
         if use_ir_loc == ext:
             ir_full_name = fn_cache_manager.get_file(ir_filename)
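
Note: with a dump manager active, the CUDA path now also disassembles the final cubin via get_sass and writes a .sass artifact next to the other IR stages. A hedged way to observe it, assuming the pre-existing TRITON_KERNEL_DUMP / TRITON_DUMP_DIR knobs and a CUDA device:

    import os
    # Set before importing triton so the dump knobs are picked up.
    os.environ["TRITON_KERNEL_DUMP"] = "1"
    os.environ["TRITON_DUMP_DIR"] = "./triton_dump"

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy_one(x_ptr, y_ptr):
        tl.store(y_ptr, tl.load(x_ptr))

    x = torch.ones(1, device="cuda")
    y = torch.empty_like(x)
    copy_one[(1,)](x, y)
    # ./triton_dump/<kernel hash>/ should now contain a copy_one.sass file
    # alongside the ttir/ttgir/ptx/cubin stages (exact layout assumed).
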
@@ -440,6 +404,10 @@ class AsmDict(dict):
         return value


+def _raise_error(err, *args, **kwargs):
+    raise copy.deepcopy(err)
+
+
 class CompiledKernel:

     def __init__(self, src, metadata_group, hash):
@@ -464,51 +432,66 @@ class CompiledKernel:
             file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
             for file in asm_files
         })
+        self.metadata_group = metadata_group
         self.kernel = self.asm[binary_ext]
         # binaries are lazily initialized
         # because it involves doing runtime things
         # (e.g., checking amount of shared memory on current device)
         self.module = None
         self.function = None
+        self._run = None

     def _init_handles(self):
         if self.module is not None:
             return
+
+        def raise_(err):
+            # clone the exception object so that the one saved in the closure
+            # of the partial function below doesn't get assigned a stack trace
+            # after the subsequent raise. otherwise, the CompiledKernel instance
+            # saved in the (global) kernel cache will keep references to all the
+            # locals in the traceback via the exception instance in the closure.
+            cloned_err = copy.deepcopy(err)
+            self._run = functools.partial(_raise_error, cloned_err)
+            raise err
+
         device = driver.active.get_current_device()
         # create launcher
-        self.run = driver.active.launcher_cls(self.src, self.metadata)
+        self._run = driver.active.launcher_cls(self.src, self.metadata)
         # not enough shared memory to run the kernel
         max_shared = max_shared_mem(device)
         if self.metadata.shared > max_shared:
-            raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
+            raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
         if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
             # Use blackwell max tmem size for now, this should be moved in device properties
             max_tmem_size = 512  # tmem size in number of columns
             if self.metadata.tmem_size > max_tmem_size:
-                raise OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory")
+                raise_(OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory"))
+        if knobs.runtime.kernel_load_start_hook is not None:
+            knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
         # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
         self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
             self.name, self.kernel, self.metadata.shared, device)
         warp_size = driver.active.get_current_target().warp_size
         if self.metadata.num_warps * warp_size > self.n_max_threads:
-            raise OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads")
+            raise_(OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads"))
+        if knobs.runtime.kernel_load_end_hook is not None:
+            knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash)

-    def __getattribute__(self, name):
-        if name == 'run':
+    @property
+    def run(self):
+        if self._run is None:
             self._init_handles()
-        return super().__getattribute__(name)
+        return self._run

     def launch_metadata(self, grid, stream, *args):
         if knobs.runtime.launch_enter_hook is None:
             return None
+        self._init_handles()
         ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
         if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
             return ret
-        arg_dict = {}
-        arg_idx = 0
-        for i, arg_name in enumerate(self.src.fn.arg_names):
-            arg_dict[arg_name] = args[arg_idx]
-            arg_idx += 1
+        arg_dict = {name: arg for name, arg in zip(self.src.fn.arg_names, args)}
         ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
         return ret
triton/experimental/gluon/__init__.py

@@ -1,4 +1,5 @@
 from . import nvidia
-from ._runtime import jit
+from ._runtime import constexpr_function, jit
+from triton.language.core import must_use_result

-__all__ = ["jit", "nvidia"]
+__all__ = ["constexpr_function", "jit", "must_use_result", "nvidia"]
triton/experimental/gluon/_runtime.py

@@ -1,13 +1,14 @@
 from __future__ import annotations
-import triton
 from triton.compiler.compiler import ASTSource
 from triton.backends.compiler import Language
-from triton.runtime.jit import JITFunction
+from triton.runtime.jit import JITFunction, constexpr_function
 from typing import TypeVar, Optional, Callable, Iterable, Union
 from triton._C.libtriton import ir

 T = TypeVar("T")

+__all__ = ["constexpr_function", "jit"]
+

 class GluonASTSource(ASTSource):

@@ -16,7 +17,7 @@ class GluonASTSource(ASTSource):
         self.language = Language.GLUON
         self.ext = "ttgir"

-    def make_ir(self, options, codegen_fns, module_map, context):
+    def make_ir(self, target, options, codegen_fns, module_map, context):
         from triton.compiler.compiler import make_backend
         from triton.compiler.code_generator import ast_to_ttir

@@ -24,14 +25,16 @@ class GluonASTSource(ASTSource):
         module = builder.create_module()

         # Assign module attributes eagerly, as they are needed to verify layouts
-        target = triton.runtime.driver.active.get_current_target()
         backend = make_backend(target)
         target = backend.get_target_name(options)
+
         module.set_attr("ttg.target", builder.get_string_attr(target))
         module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
         module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
-        module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(32))
-        if options.maxnreg is not None:
+        module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(options.warp_size))
+
+        is_cuda = options.backend_name == "cuda"
+        if is_cuda and options.maxnreg is not None:
             module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))

         module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
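
Note: three fixes land in this hunk. The target now arrives through make_ir instead of a query of the active driver (so Gluon compilation no longer depends on ambient state), ttg.threads-per-warp honors options.warp_size rather than a hard-coded 32, and ttg.maxnreg is emitted only on CUDA. Illustrative only, with target values assumed:

    from triton.backends.compiler import GPUTarget

    cuda = GPUTarget("cuda", 90, 32)        # sm_90, 32 threads per warp
    cdna3 = GPUTarget("hip", "gfx942", 64)  # CDNA3, wave64

    # Under 3.4 both targets got ttg.threads-per-warp = 32 and both honored
    # maxnreg; under 3.5 the attribute follows warp_size (32 vs. 64) and
    # ttg.maxnreg is set only when options.backend_name == "cuda".
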
triton/experimental/gluon/language/__init__.py

@@ -1,18 +1,119 @@
-from ._core import *  # NOQA: F403
-from ._core import __all__ as __core_all
-from ._layouts import *  # NOQA: F403
-from ._layouts import __all__ as __layouts_all
-from ._math import *  # NOQA: F403
-from ._math import __all__ as __math_all
-from ._standard import *  # NOQA: F403
-from ._standard import __all__ as __standard_all
+from ._core import (
+    base_value,
+    base_type,
+    block_type,
+    broadcast,
+    constexpr,
+    dtype,
+    void,
+    int1,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+    float8e5,
+    float8e5b16,
+    float8e4nv,
+    float8e4b8,
+    float8e4b15,
+    float16,
+    bfloat16,
+    float32,
+    float64,
+    pointer_type,
+    shared_memory_descriptor,
+    tensor,
+    tuple,
+    tuple_type,
+    _unwrap_if_constexpr,
+    # API Functions
+    allocate_shared_memory,
+    arange,
+    associative_scan,
+    atomic_add,
+    atomic_and,
+    atomic_cas,
+    atomic_max,
+    atomic_min,
+    atomic_or,
+    atomic_xchg,
+    atomic_xor,
+    convert_layout,
+    device_assert,
+    expand_dims,
+    full,
+    histogram,
+    inline_asm_elementwise,
+    join,
+    load,
+    map_elementwise,
+    max_constancy,
+    max_contiguous,
+    maximum,
+    minimum,
+    multiple_of,
+    num_programs,
+    permute,
+    program_id,
+    reduce,
+    reshape,
+    set_auto_layout,
+    split,
+    static_assert,
+    static_print,
+    static_range,
+    store,
+    thread_barrier,
+    to_tensor,
+    warp_specialize,
+    where,
+)
+from ._layouts import (
+    AutoLayout,
+    BlockedLayout,
+    SliceLayout,
+    DistributedLinearLayout,
+    DotOperandLayout,
+    NVMMADistributedLayout,
+    NVMMASharedLayout,
+    SwizzledSharedLayout,
+    PaddedSharedLayout,
+)
+from ._math import (
+    umulhi,
+    exp,
+    exp2,
+    fma,
+    log,
+    log2,
+    cos,
+    rsqrt,
+    sin,
+    sqrt,
+    sqrt_rn,
+    abs,
+    fdiv,
+    div_rn,
+    erf,
+    floor,
+    ceil,
+)
+from ._standard import (
+    cdiv,
+    full_like,
+    max,
+    min,
+    reduce_or,
+    sum,
+    xor_sum,
+    zeros,
+    zeros_like,
+)

 from . import nvidia
-
-__all__ = [
-    *__core_all,
-    *__layouts_all,
-    *__math_all,
-    *__standard_all,
-    "nvidia",
-]
+from . import amd
+from . import extra
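
Note: replacing the star-imports with an explicit list makes the gluon.language surface visible to linters and IDEs, and the new amd and extra namespaces are exported alongside nvidia. A quick smoke check of the re-exported names (assumes triton-windows 3.5 is installed):

    from triton.experimental.gluon import language as gl

    # Each of these resolves to a real attribute after the rewrite above:
    print(gl.BlockedLayout, gl.AutoLayout)  # layouts from _layouts
    print(gl.float16, gl.constexpr)         # core types from _core
    print(gl.cdiv, gl.zeros_like)           # helpers from _standard
    print(gl.amd, gl.nvidia, gl.extra)      # backend namespaces; amd/extra are new in 3.5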