triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.4.0.post20__cp310-cp310-win_amd64.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.

This release of triton-windows has been flagged as potentially problematic.

Files changed (166)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +14 -6
  59. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  60. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  61. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  64. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  65. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  66. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  67. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  68. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  69. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  70. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  71. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  72. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  73. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  80. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  81. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  82. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  83. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  84. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  85. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  86. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  87. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  88. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  89. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  90. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  91. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  92. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  93. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  94. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  95. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  96. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  97. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  98. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  99. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  100. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  101. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  102. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  103. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  104. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  105. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  106. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  107. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  108. triton/backends/amd/include/hip/device_functions.h +0 -38
  109. triton/backends/amd/include/hip/driver_types.h +0 -468
  110. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  111. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  112. triton/backends/amd/include/hip/hip_common.h +0 -100
  113. triton/backends/amd/include/hip/hip_complex.h +0 -38
  114. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  115. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  116. triton/backends/amd/include/hip/hip_ext.h +0 -161
  117. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  118. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  119. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  120. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  121. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  122. triton/backends/amd/include/hip/hip_profile.h +0 -27
  123. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  124. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  125. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  126. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  127. triton/backends/amd/include/hip/hip_version.h +0 -17
  128. triton/backends/amd/include/hip/hiprtc.h +0 -421
  129. triton/backends/amd/include/hip/library_types.h +0 -78
  130. triton/backends/amd/include/hip/math_functions.h +0 -42
  131. triton/backends/amd/include/hip/surface_types.h +0 -63
  132. triton/backends/amd/include/hip/texture_types.h +0 -194
  133. triton/backends/amd/include/hsa/Brig.h +0 -1131
  134. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  135. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  136. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  137. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  138. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  139. triton/backends/amd/include/hsa/hsa.h +0 -5738
  140. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  141. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  142. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  143. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  144. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  145. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  146. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  147. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  148. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  149. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  150. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  151. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  152. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  153. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  154. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  155. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  156. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  157. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  158. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  159. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  160. triton/backends/amd/include/roctracer/roctx.h +0 -229
  161. triton/language/_utils.py +0 -21
  162. triton/language/extra/cuda/_experimental_tma.py +0 -106
  163. triton/tools/experimental_descriptor.py +0 -32
  164. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  165. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  166. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
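What follows is the per-file diff for triton/compiler/code_generator.py (+212 -111). The dominant themes are a per-dialect "semantic" object threaded through the AST code generator (replacing the module-level semantic helpers and most _builder= plumbing), wiring for the experimental Gluon front end, and the migration of environment-variable switches to the new triton.knobs module.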
@@ -1,18 +1,19 @@
  import ast
+ import copy
  import inspect
  import re
  import warnings
- import os
  import textwrap
  import itertools
+ from dataclasses import dataclass
  from types import ModuleType
  from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List

- from .. import language
- from .._C.libtriton import ir
- from ..language import constexpr, semantic, str_to_ty, tensor
- from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, base_value, base_type
- from ..runtime.jit import get_jit_fn_file_line
+ from .. import knobs, language
+ from .._C.libtriton import ir, gluon_ir
+ from ..language import constexpr, str_to_ty, tensor
+ from ..language.core import _unwrap_if_constexpr, base_value, base_type
+ from ..runtime.jit import get_jit_fn_file_line, get_full_name
  # ideally we wouldn't need any runtime component
  from ..runtime import JITFunction
  from .._utils import find_paths_if, get_iterable_path, set_iterable_path
@@ -27,29 +28,9 @@ def check_identifier_legality(name, type):
      return name


- def mangle_ty(ty):
-     if ty.is_tuple():
-         return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T'
-     if ty.is_ptr():
-         return 'P' + mangle_ty(ty.element_ty)
-     if ty.is_int():
-         SIGNED = language.dtype.SIGNEDNESS.SIGNED
-         prefix = 'i' if ty.int_signedness == SIGNED else 'u'
-         return prefix + str(ty.int_bitwidth)
-     if ty.is_floating():
-         return str(ty)
-     if ty.is_block():
-         elt = mangle_ty(ty.scalar)
-         shape = '_'.join(map(str, ty.shape))
-         return f'{elt}S{shape}S'
-     if ty.is_void():
-         return 'V'
-     raise TypeError(f'Unsupported type {ty}')
-
-
  def mangle_fn(name, arg_tys, constants):
      # doesn't mangle ret type, which must be a function of arg tys
-     mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
+     mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys])
      mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
      mangled_constants = mangled_constants.replace('.', '_d_')
      mangled_constants = mangled_constants.replace("'", '_sq_')
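This hunk moves name mangling from the deleted free function mangle_ty() onto each type's mangle() method, presumably with the same encoding. For reference, a self-contained sketch of the scheme the deleted code above encodes; the Ty dataclass is a hypothetical stand-in for Triton's real type objects:

    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass
    class Ty:
        kind: str                     # "int" | "float" | "ptr" | "block" | "void"
        bits: int = 0                 # integer bit width
        signed: bool = True           # integer signedness
        name: str = ""                # float spelling, e.g. "fp32"
        elem: Optional["Ty"] = None   # pointee / block element type
        shape: Tuple[int, ...] = ()   # block shape

    def mangle(ty: Ty) -> str:
        if ty.kind == "ptr":
            return "P" + mangle(ty.elem)
        if ty.kind == "int":
            return ("i" if ty.signed else "u") + str(ty.bits)
        if ty.kind == "float":
            return ty.name
        if ty.kind == "block":
            return mangle(ty.elem) + "S" + "_".join(map(str, ty.shape)) + "S"
        if ty.kind == "void":
            return "V"
        raise TypeError(f"Unsupported type {ty}")

    print(mangle(Ty("ptr", elem=Ty("float", name="fp16"))))                    # Pfp16
    print(mangle(Ty("block", elem=Ty("float", name="fp32"), shape=(16, 16))))  # fp32S16_16S
    print(mangle(Ty("int", bits=32, signed=False)))                            # u32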
@@ -68,11 +49,11 @@ def _is_triton_tensor(o: Any) -> bool:


  def _is_constexpr(o: Any) -> bool:
-     return o is None or isinstance(o, (constexpr, language.core.dtype))
+     return o is None or isinstance(o, (constexpr, language.core.dtype, JITFunction))


- def _is_triton_scalar(o: Any) -> bool:
-     return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1)
+ def _is_non_scalar_tensor(o: Any) -> bool:
+     return _is_triton_tensor(o) and (o.type.is_block() and o.type.numel != 1)


  def _is_list_like(o: Any) -> bool:
@@ -82,7 +63,7 @@ def _is_list_like(o: Any) -> bool:
  def _check_fn_args(node, fn, args):
      if fn.noinline:
          for idx, arg in enumerate(args):
-             if not _is_constexpr(arg) and not _is_triton_scalar(arg):
+             if not _is_constexpr(arg) and _is_non_scalar_tensor(arg):
                  raise UnsupportedLanguageConstruct(
                      fn.src, node,
                      f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
@@ -102,6 +83,7 @@ def _apply_to_tuple_values(value, fn):
          assert False, f"Unsupported type {type(value)}"

      vals = [fn(v) for v in value]
+     vals = [constexpr(v) if v is None else v for v in vals]
      types = [v.type for v in vals]
      return language.tuple(vals, language.tuple_type(types, fields))

@@ -154,10 +136,9 @@ class ContainsReturnChecker(ast.NodeVisitor):
          return any(self.visit(s) for s in body)

      def _visit_function(self, fn) -> bool:
-         # Currently we only support JITFunctions defined in the global scope
-         if isinstance(fn, JITFunction) and not fn.noinline:
-             fn_node = fn.parse()
-             return ContainsReturnChecker(self.gscope).visit(fn_node)
+         # no need to check within the function as it won't cause an early return.
+         # If the function itself has unstructured control flow we may not be able to inline it causing poor performance.
+         # We should check for this and fail or emit a warning.
          return False

      def generic_visit(self, node) -> bool:
@@ -241,26 +222,26 @@ class ASTFunction:
          self.constants = constants
          self.attrs = attrs

-     def return_types_ir(self, builder: ir.builder):
-         ret_types = []
-         for ret_ty in self.ret_types:
-             if ret_ty is None:
+     def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]:
+         ir_types = []
+         for ty in types:
+             if ty is None:
                  continue
-             ir_ty = ret_ty.to_ir(builder)
-             if isinstance(ir_ty, list):
-                 ret_types.extend(ir_ty)
-             else:
-                 ret_types.append(ir_ty)
-         return ret_types
+             ty._flatten_ir_types(builder, ir_types)
+         return ir_types
+
+     def return_types_ir(self, builder: ir.builder) -> List[ir.type]:
+         return self.flatten_ir_types(builder, self.ret_types)

      def serialize(self, builder: ir.builder):
          # fill up IR values in template
          # > build function
          is_val = lambda path, _: path not in self.constants and _ is not None
          val_paths = list(find_paths_if(self.arg_types, is_val))
-         arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths]
-         ret_types = self.return_types_ir(builder)
-         return builder.get_function_ty(arg_types, ret_types)
+         arg_types = [get_iterable_path(self.arg_types, path) for path in val_paths]
+         arg_types_ir = self.flatten_ir_types(builder, arg_types)
+         ret_types_ir = self.return_types_ir(builder)
+         return builder.get_function_ty(arg_types_ir, ret_types_ir)

      def deserialize(self, fn):
          # create "template"
@@ -272,19 +253,18 @@ class ASTFunction:
          vals = make_template(self.arg_types)
          is_val = lambda path, _: path not in self.constants and _ is not None
          val_paths = list(find_paths_if(self.arg_types, is_val))
-         # > set attributes
-         for attr_path, attr_specs in self.attrs.items():
-             for attr_name, attr_val in attr_specs:
-                 if attr_path in val_paths:
-                     fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val)
-         for i, path in enumerate(val_paths):
-             ty = get_iterable_path(self.arg_types, path)
-             if isinstance(ty, nv_tma_desc_type):
-                 fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
          # > add IR values to the template
-         for i, path in enumerate(val_paths):
+         cursor = 0
+         handles = [fn.args(i) for i in range(fn.get_num_args())]
+         for path in val_paths:
              ty = get_iterable_path(self.arg_types, path)
-             set_iterable_path(vals, path, language.tensor(fn.args(i), ty))
+             # > set attributes
+             attr_specs = self.attrs.get(path, [])
+             for attr_name, attr_val in attr_specs:
+                 fn.set_arg_attr(cursor, attr_name, attr_val)
+             # > build frontend value
+             val, cursor = ty._unflatten_ir(handles, cursor)
+             set_iterable_path(vals, path, val)
          # > add constexpr values to the template
          constants = self.constants
          for path, val in constants.items():
@@ -292,13 +272,26 @@ class ASTFunction:
          return vals


+ @dataclass(frozen=True)
+ class BoundJITMethod:
+     __self__: base_value
+     __func__: JITFunction
+
+
  class CodeGenerator(ast.NodeVisitor):

      def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map,
                   module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False,
                   file_name: Optional[str] = None, begin_line=0):
          self.context = context
-         self.builder = ir.builder(context)
+         if jit_fn.is_gluon():
+             from triton.experimental.gluon.language._semantic import GluonSemantic
+             self.builder = gluon_ir.GluonOpBuilder(context)
+             self.semantic = GluonSemantic(self.builder)
+         else:
+             from triton.language.semantic import TritonSemantic
+             self.builder = ir.builder(context)
+             self.semantic = TritonSemantic(self.builder)
          self.file_name = file_name
          # node.lineno starts from 1, so we need to subtract 1
          self.begin_line = begin_line - 1
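The constructor change above establishes the pattern used throughout the rest of this file: the generator selects one semantics object up front and routes all lowering through self.semantic (rather than each visitor calling module-level semantic.foo(..., self.builder)), which lets the same CodeGenerator drive either the Triton or the Gluon front end. A minimal sketch of the dispatch, with placeholder classes rather than the real TritonSemantic/GluonSemantic:

    class TritonSemanticSketch:
        # placeholder for triton.language.semantic.TritonSemantic
        def __init__(self, builder):
            self.builder = builder

        def to_tensor(self, value):
            # would lower a Python scalar to an IR tensor via self.builder
            return ("triton_tensor", value)

    class GluonSemanticSketch(TritonSemanticSketch):
        # placeholder: overrides individual lowerings, inherits the rest
        def to_tensor(self, value):
            return ("gluon_tensor", value)

    def make_semantic(is_gluon: bool, builder):
        return GluonSemanticSketch(builder) if is_gluon else TritonSemanticSketch(builder)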
@@ -306,7 +299,7 @@ class CodeGenerator(ast.NodeVisitor):
          self.builder.options = options
          # dict of functions provided by the backend. Below are the list of possible functions:
          # Convert custom types not natively supported on HW.
-         # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None)
+         # convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None)
          self.builder.codegen_fns = codegen_fns
          self.builder.module_map = {} if module_map is None else module_map
          self.module = self.builder.create_module() if module is None else module
@@ -329,6 +322,7 @@ class CodeGenerator(ast.NodeVisitor):
          self.jit_fn = jit_fn
          # TODO: we currently generate illegal names for non-kernel functions involving constexprs!
          if is_kernel:
+             function_name = function_name[function_name.rfind('.') + 1:]
              function_name = check_identifier_legality(function_name, "function")
          self.function_name = function_name
          self.is_kernel = is_kernel
@@ -345,7 +339,10 @@ class CodeGenerator(ast.NodeVisitor):
          # special handling.
          self.visiting_arg_default_value = False

-         builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)}
+         builtin_namespace: Dict[str, Any] = {
+             _.__name__: _
+             for _ in (len, list, range, float, int, isinstance, getattr, hasattr)
+         }
          builtin_namespace.update((
              ('print', language.core.device_print),
              ('min', language.minimum),
@@ -378,11 +375,14 @@ class CodeGenerator(ast.NodeVisitor):
          # But actually a bunch of other things, such as module imports, are
          # technically Python globals. We have to allow these too!
          if any([
-             val is absent, name in self.builtin_namespace, #
+             val is absent,
+             name in self.builtin_namespace, #
              type(val) is ModuleType, #
              isinstance(val, JITFunction), #
              getattr(val, "__triton_builtin__", False), #
+             getattr(val, "__triton_aggregate__", False), #
              getattr(val, "__module__", "").startswith("triton.language"), #
+             getattr(val, "__module__", "").startswith("triton.experimental.gluon.language"), #
              isinstance(val, language.dtype), #
              _is_namedtuple(val),
              self._is_constexpr_global(name), #
@@ -390,7 +390,7 @@ class CodeGenerator(ast.NodeVisitor):
              # because you should be able to do
              # @triton.jit def fn(x: tl.constexpr = GLOBAL): ...
              self.visiting_arg_default_value, #
-             os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"
+             knobs.compilation.allow_non_constexpr_globals,
          ]):
              return val
          raise NameError(
@@ -467,7 +467,7 @@ class CodeGenerator(ast.NodeVisitor):
              if isinstance(value, language.tuple):
                  return _apply_to_tuple_values(value, decay)
              elif isinstance(value, (language.constexpr, int, float)):
-                 return semantic.to_tensor(value, self.builder)
+                 return self.semantic.to_tensor(value)
              return value

          ret_value = decay(ret_value)
@@ -575,13 +575,16 @@ class CodeGenerator(ast.NodeVisitor):
          return self.visit_Assign(node)

      def assignTarget(self, target, value):
+         assert isinstance(target.ctx, ast.Store)
          if isinstance(target, ast.Subscript):
-             assert target.ctx.__class__.__name__ == "Store"
              return self.visit_Subscript_Store(target, value)
          if isinstance(target, ast.Tuple):
-             assert target.ctx.__class__.__name__ == "Store"
-             for i, name in enumerate(target.elts):
-                 self.set_value(self.visit(name), value.values[i])
+             for i, target in enumerate(target.elts):
+                 self.assignTarget(target, value.values[i])
+             return
+         if isinstance(target, ast.Attribute):
+             base = self.visit(target.value)
+             setattr(base, target.attr, value)
              return
          assert isinstance(target, ast.Name)
          self.set_value(self.visit(target), value)
@@ -596,7 +599,7 @@ class CodeGenerator(ast.NodeVisitor):
              if value is not None and \
                      not _is_triton_value(value) and \
                      not isinstance(value, native_nontensor_types):
-                 value = semantic.to_tensor(value, self.builder)
+                 value = self.semantic.to_tensor(value)
              return value

          values = _sanitize_value(self.visit(node.value))
@@ -605,12 +608,12 @@ class CodeGenerator(ast.NodeVisitor):
          self.assignTarget(targets[0], values)

      def visit_AugAssign(self, node):
-         name = node.target.id
-         lhs = ast.Name(id=name, ctx=ast.Load())
+         lhs = copy.deepcopy(node.target)
+         lhs.ctx = ast.Load()
          rhs = ast.BinOp(lhs, node.op, node.value)
          assign = ast.Assign(targets=[node.target], value=rhs)
          self.visit(assign)
-         return self.dereference_name(name)
+         return self.visit(lhs)

      def visit_Name(self, node):
          if type(node.ctx) is ast.Store:
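The old visit_AugAssign assumed a bare-name target (node.target.id), so x += 1 compiled but obj.attr += 1 or t[0] += 1 could not. Deep-copying the target and flipping its ctx to Load makes the usual "t op= v" to "t = t op v" desugaring work for every target form the new assignTarget accepts. A standalone illustration of the rewrite:

    import ast
    import copy

    def desugar_augassign(node: ast.AugAssign) -> ast.Assign:
        lhs = copy.deepcopy(node.target)   # reuse the target expression...
        lhs.ctx = ast.Load()               # ...but read from it this time
        rhs = ast.BinOp(lhs, node.op, node.value)
        return ast.Assign(targets=[node.target], value=rhs)

    node = ast.parse("acc.total += x").body[0]
    print(ast.unparse(desugar_augassign(node)))   # acc.total = acc.total + x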
@@ -630,10 +633,12 @@ class CodeGenerator(ast.NodeVisitor):
      def _apply_binary_method(self, method_name, lhs, rhs):
          # TODO: raise something meaningful if getattr fails below, esp for reverse method
          if _is_triton_tensor(lhs):
-             return getattr(lhs, method_name)(rhs, _builder=self.builder)
+             return getattr(lhs, method_name)(rhs, _semantic=self.semantic)
          if _is_triton_tensor(rhs):
              reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
-             return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder)
+             return getattr(rhs, reverse_method_name)(lhs, _semantic=self.semantic)
+         if not isinstance(lhs, (constexpr, language.tuple)) and isinstance(rhs, constexpr):
+             lhs = constexpr(lhs)
          return getattr(lhs, method_name)(rhs)

      def visit_BinOp(self, node):
@@ -786,7 +791,14 @@ class CodeGenerator(ast.NodeVisitor):
          cond = self.visit(node.test)

          if _is_triton_tensor(cond):
-             cond = cond.to(language.int1, _builder=self.builder)
+             if _is_non_scalar_tensor(cond):
+                 raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
+             if cond.type.is_block():
+                 warnings.warn(
+                     "If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead"
+                     % ast.unparse(node.test))
+                 cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
+             cond = cond.to(language.int1, _semantic=self.semantic)
          contains_return = ContainsReturnChecker(self.gscope).visit(node)
          if contains_return:
              if self.scf_stack:
@@ -812,21 +824,21 @@ class CodeGenerator(ast.NodeVisitor):
      def visit_IfExp(self, node):
          cond = self.visit(node.test)
          if _is_triton_tensor(cond):
-             cond = cond.to(language.int1, _builder=self.builder)
+             cond = cond.to(language.int1, _semantic=self.semantic)
          # TODO: Deal w/ more complicated return types (e.g tuple)
          with enter_sub_region(self):
              ip, last_loc = self._get_insertion_point_and_loc()

              then_block = self.builder.create_block()
              self.builder.set_insertion_point_to_start(then_block)
-             then_val = semantic.to_tensor(self.visit(node.body), self.builder)
+             then_val = self.semantic.to_tensor(self.visit(node.body))
              then_block = self.builder.get_insertion_block()

              else_block = self.builder.create_block()
              self.builder.set_insertion_point_to_start(else_block)
              # do not need to reset lscope since
              # ternary expressions cannot define new variables
-             else_val = semantic.to_tensor(self.visit(node.orelse), self.builder)
+             else_val = self.semantic.to_tensor(self.visit(node.orelse))
              else_block = self.builder.get_insertion_block()

              self._set_insertion_point_and_loc(ip, last_loc)
@@ -892,10 +904,12 @@ class CodeGenerator(ast.NodeVisitor):
          if fn is None:
              raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.")
          if _is_triton_tensor(operand):
-             return getattr(operand, fn)(_builder=self.builder)
+             return getattr(operand, fn)(_semantic=self.semantic)
          try:
              return getattr(operand, fn)()
          except AttributeError:
+             if fn == "__not__":
+                 return constexpr(not operand)
              raise self._unsupported(
                  node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")

@@ -912,6 +926,20 @@ class CodeGenerator(ast.NodeVisitor):
              f'but is re-assigned to {loop_val.type} in loop! '\
              f'Please make sure that the type stays consistent.'

+     def visit_withitem(self, node):
+         return self.visit(node.context_expr)
+
+     def visit_With(self, node):
+         assert len(node.items) == 1
+         context = node.items[0].context_expr
+         withitemClass = self.visit(context.func)
+         if withitemClass == language.async_task:
+             args = [self.visit(arg) for arg in context.args]
+             with withitemClass(*args, _builder=self.builder):
+                 self.visit_compound_statement(node.body)
+         else:
+             self.visit_compound_statement(node.body)
+
      def visit_While(self, node):
          with enter_sub_region(self) as sr:
              liveins, insert_block = sr
@@ -991,15 +1019,15 @@ class CodeGenerator(ast.NodeVisitor):
              ast.NodeVisitor.generic_visit(self, stmt)

      def visit_Subscript_Load(self, node):
-         assert node.ctx.__class__.__name__ == "Load"
+         assert isinstance(node.ctx, ast.Load)
          lhs = self.visit(node.value)
          slices = self.visit(node.slice)
          if _is_triton_tensor(lhs):
-             return lhs.__getitem__(slices, _builder=self.builder)
+             return lhs.__getitem__(slices, _semantic=self.semantic)
          return lhs[slices]

      def visit_Subscript_Store(self, node, value):
-         assert node.ctx.__class__.__name__ == "Store"
+         assert isinstance(node.ctx, ast.Store)
          lhs = self.visit(node.value)
          slices = self.visit(node.slice)
          assert isinstance(lhs, language.tuple)
@@ -1028,6 +1056,7 @@ class CodeGenerator(ast.NodeVisitor):
          loop_unroll_factor = None
          disallow_acc_multi_buffer = False
          flatten = False
+         warp_specialize = False
          if IteratorClass is language.range:
              iterator = IteratorClass(*iter_args, **iter_kwargs)
              # visit iterator arguments
@@ -1040,6 +1069,7 @@ class CodeGenerator(ast.NodeVisitor):
              loop_unroll_factor = iterator.loop_unroll_factor
              disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
              flatten = iterator.flatten
+             warp_specialize = iterator.warp_specialize
          elif IteratorClass is range:
              # visit iterator arguments
              # note: only `range` iterator is supported now
@@ -1055,14 +1085,14 @@ class CodeGenerator(ast.NodeVisitor):
              step = constexpr(-step.value)
              negative_step = True
              lb, ub = ub, lb
-         lb = semantic.to_tensor(lb, self.builder)
-         ub = semantic.to_tensor(ub, self.builder)
-         step = semantic.to_tensor(step, self.builder)
+         lb = self.semantic.to_tensor(lb)
+         ub = self.semantic.to_tensor(ub)
+         step = self.semantic.to_tensor(step)
          # induction variable type
          if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
              raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
-         iv_type = semantic.integer_promote_impl(lb.dtype, ub.dtype)
-         iv_type = semantic.integer_promote_impl(iv_type, step.dtype)
+         iv_type = self.semantic.integer_promote_impl(lb.dtype, ub.dtype)
+         iv_type = self.semantic.integer_promote_impl(iv_type, step.dtype)
          iv_ir_type = iv_type.to_ir(self.builder)
          iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
          # lb/ub/step might be constexpr, we need to cast them to tensor
@@ -1118,6 +1148,8 @@ class CodeGenerator(ast.NodeVisitor):
              for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
          if flatten:
              for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
+         if warp_specialize:
+             for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr())

          self.scf_stack.append(node)
          for_op_body = for_op.get_body(0)
@@ -1136,7 +1168,7 @@ class CodeGenerator(ast.NodeVisitor):
                  if name in liveins:
                      local = self.local_defs[name]
                      if isinstance(local, constexpr):
-                         local = semantic.to_tensor(local, self.builder)
+                         local = self.semantic.to_tensor(local)
                      yields.append(local)

              # create YieldOp
@@ -1180,7 +1212,7 @@ class CodeGenerator(ast.NodeVisitor):
      def visit_Assert(self, node) -> Any:
          test = self.visit(node.test)
          msg = self.visit(node.msg) if node.msg is not None else ""
-         return language.core.device_assert(test, msg, _builder=self.builder)
+         return language.core.device_assert(test, msg, _semantic=self.semantic)

      def call_JitFunction(self, fn: JITFunction, args, kwargs):
          args = inspect.getcallargs(fn.fn, *args, **kwargs)
@@ -1193,10 +1225,9 @@ class CodeGenerator(ast.NodeVisitor):
          args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
          args_val = [get_iterable_path(args, path) for path in args_path]
          # mangle
-         fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst)
+         fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst)
          # generate function def if necessary
          if not self.module.has_function(fn_name):
-             gscope = fn.__globals__
              # If the callee is not set, we use the same debug setting as the caller
              file_name, begin_line = get_jit_fn_file_line(fn)
              arg_types = [
@@ -1205,7 +1236,7 @@ class CodeGenerator(ast.NodeVisitor):
                  for arg in args
              ]
              prototype = ASTFunction([], arg_types, args_cst, dict())
-             generator = CodeGenerator(self.context, prototype, gscope, module=self.module, jit_fn=fn,
+             generator = CodeGenerator(self.context, prototype, fn.get_capture_scope(), module=self.module, jit_fn=fn,
                                        function_name=fn_name, function_types=self.function_ret_types,
                                        noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
                                        options=self.builder.options, codegen_fns=self.builder.codegen_fns,
@@ -1214,6 +1245,8 @@ class CodeGenerator(ast.NodeVisitor):
                  generator.visit(fn.parse())
              except Exception as e:
                  # Wrap the error in the callee with the location of the call.
+                 if knobs.compilation.front_end_debugging:
+                     raise
                  raise CompilationError(self.jit_fn.src, self.cur_node, None) from e

              callee_ret_type = generator.ret_type
@@ -1221,7 +1254,7 @@ class CodeGenerator(ast.NodeVisitor):
          else:
              callee_ret_type = self.function_ret_types[fn_name]
          symbol = self.module.get_function(fn_name)
-         args_val = [arg.handle for arg in args_val]
+         args_val = flatten_values_to_ir(args_val)
          call_op = self.builder.call(symbol, args_val)
          if callee_ret_type == language.void:
              return None
@@ -1230,18 +1263,29 @@ class CodeGenerator(ast.NodeVisitor):

      def visit_Call(self, node):
          fn = _unwrap_if_constexpr(self.visit(node.func))
-         static_implementation = self.statically_implemented_functions.get(fn)
-         if static_implementation is not None:
-             return static_implementation(self, node)
+         if not isinstance(fn, BoundJITMethod):
+             static_implementation = self.statically_implemented_functions.get(fn)
+             if static_implementation is not None:
+                 return static_implementation(self, node)
+
+         mur = getattr(fn, '_must_use_result', False)
+         if mur and getattr(node, '_is_unused', False):
+             error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
+             if isinstance(mur, str):
+                 error_message.append(mur)
+             raise CompilationError(self.jit_fn.src, node, " ".join(error_message))

          kws = dict(self.visit(keyword) for keyword in node.keywords)
          args = [self.visit(arg) for arg in node.args]
          args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
+         if isinstance(fn, BoundJITMethod):
+             args.insert(0, fn.__self__)
+             fn = fn.__func__
          if isinstance(fn, JITFunction):
              _check_fn_args(node, fn, args)
              return self.call_JitFunction(fn, args, kws)
          if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn):
-             extra_kwargs = {"_builder": self.builder}
+             extra_kwargs = {"_semantic": self.semantic}
              sig = inspect.signature(fn)
              if '_generator' in sig.parameters:
                  extra_kwargs['_generator'] = self
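The mur check above pairs with the visit_Expr change later in this diff: visit_Expr marks a call that appears as a bare statement with _is_unused, and visit_Call then rejects it when the callee carries _must_use_result (a truthy value; if it is a string, it is appended to the error). The diff only shows the check, not how builtins set the attribute; a hypothetical sketch of one way a function might opt in:

    def must_use_result(msg=True):
        # hypothetical decorator; name and shape are illustrative only
        def wrap(fn):
            fn._must_use_result = msg   # True, or a string with extra guidance
            return fn
        return wrap

    @must_use_result("Assign the returned token to a variable.")
    def arrive():   # hypothetical builtin
        return "token"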
@@ -1252,6 +1296,8 @@ class CodeGenerator(ast.NodeVisitor):
                      ret = language.tuple(ret)
                  return ret
              except Exception as e:
+                 if knobs.compilation.front_end_debugging:
+                     raise
                  # Normally when we raise a CompilationError, we raise it as
                  # `from None`, because the original fileline from the exception
                  # is not relevant (and often points into code_generator.py
@@ -1269,26 +1315,73 @@ class CodeGenerator(ast.NodeVisitor):
          return constexpr(node.value)

      def visit_BoolOp(self, node: ast.BoolOp):
-         if len(node.values) != 2:
-             raise self._unsupported(
-                 node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
-         lhs = self.visit(node.values[0])
-         rhs = self.visit(node.values[1])
          method_name = self._method_name_for_bool_op.get(type(node.op))
          if method_name is None:
              raise self._unsupported(
                  node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
-         return self._apply_binary_method(method_name, lhs, rhs)
+
+         nontrivial_values = []
+
+         for subnode in node.values:
+             # we visit the values in order, executing their side-effects
+             # and possibly early-exiting:
+             value = self.visit(subnode)
+             if not _is_triton_tensor(value):
+                 # this is a constexpr, so we might be able to short-circuit:
+                 bv = bool(value)
+                 if (bv is False) and (method_name == "logical_and"):
+                     # value is falsey so return that:
+                     return value
+                 if (bv is True) and (method_name == "logical_or"):
+                     # value is truthy so return that:
+                     return value
+                 # otherwise, our constexpr has no effect on the output of the
+                 # expression so we do not append it to nontrivial_values.
+             else:
+                 if value.type.is_block():
+                     lineno = getattr(node, "lineno", None)
+                     if lineno is not None:
+                         lineno += self.begin_line
+                     warnings.warn_explicit(
+                         "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead",
+                         category=UserWarning,
+                         filename=self.file_name,
+                         lineno=lineno,
+                         source=ast.unparse(node),
+                     )
+                 # not a constexpr so we must append it:
+                 nontrivial_values.append(value)
+
+         if len(nontrivial_values) == 0:
+             # the semantics of a disjunction of falsey values or conjunction
+             # of truthy values is to return the final value:
+             nontrivial_values.append(value)
+
+         while len(nontrivial_values) >= 2:
+             rhs = nontrivial_values.pop()
+             lhs = nontrivial_values.pop()
+             res = self._apply_binary_method(method_name, lhs, rhs)
+             nontrivial_values.append(res)
+
+         assert len(nontrivial_values) == 1
+         return nontrivial_values[0]

      _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}

      def visit_Attribute(self, node):
          lhs = self.visit(node.value)
          if _is_triton_tensor(lhs) and node.attr == "T":
-             return semantic.permute(lhs, (1, 0), builder=self.builder)
-         return getattr(lhs, node.attr)
+             return self.semantic.permute(lhs, (1, 0))
+         # NOTE: special case ".value" for BC
+         if isinstance(lhs, constexpr) and node.attr != "value":
+             lhs = lhs.value
+         attr = getattr(lhs, node.attr)
+         if _is_triton_value(lhs) and isinstance(attr, JITFunction):
+             return BoundJITMethod(lhs, attr)
+         return attr

      def visit_Expr(self, node):
+         node.value._is_unused = True
          ast.NodeVisitor.generic_visit(self, node)

      def visit_NoneType(self, node):
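The visit_BoolOp rewrite lifts the old two-operand restriction: constexpr operands now short-circuit the chain at compile time, and whatever runtime operands remain are folded pairwise from the right. A pure-Python sketch of the folding, with plain bools standing in for constexpr values and strings for runtime tensors:

    def fold_bool_chain(values, is_and):
        nontrivial = []
        for value in values:
            if isinstance(value, (bool, int, float)):   # stand-in for constexpr
                if is_and and not value:
                    return value                        # falsey decides an 'and' chain
                if not is_and and value:
                    return value                        # truthy decides an 'or' chain
            else:
                nontrivial.append(value)                # runtime value: keep it
        if not nontrivial:
            return value                                # all-constexpr chain: last value wins
        while len(nontrivial) >= 2:
            rhs, lhs = nontrivial.pop(), nontrivial.pop()
            nontrivial.append(("logical_and" if is_and else "logical_or", lhs, rhs))
        return nontrivial[0]

    print(fold_bool_chain([True, "t1", "t2"], is_and=True))   # ('logical_and', 't1', 't2')
    print(fold_bool_chain([0, "t1"], is_and=True))            # 0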
@@ -1331,6 +1424,8 @@ class CodeGenerator(ast.NodeVisitor):
          except CompilationError:
              raise
          except Exception as e:
+             if knobs.compilation.front_end_debugging:
+                 raise
              # Wrap the error in a CompilationError which contains the source
              # of the @jit function.
              raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None
@@ -1378,16 +1473,22 @@ class CodeGenerator(ast.NodeVisitor):

          return ret

+     from ..experimental.gluon import language as ttgl
      statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
          language.core.static_assert: execute_static_assert,
          language.core.static_print: static_executor(print),
+         ttgl.static_assert: execute_static_assert,
+         ttgl.static_print: static_executor(print),
          int: static_executor(int),
          len: static_executor(len),
      }


- def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
-     arg_types = list(map(str_to_ty, src.signature.values()))
+ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
+     arg_types = [None] * len(fn.arg_names)
+     for k, v in src.signature.items():
+         idx = fn.arg_names.index(k)
+         arg_types[idx] = str_to_ty(v)
      prototype = ASTFunction([], arg_types, src.constants, src.attrs)
      file_name, begin_line = get_jit_fn_file_line(fn)
      # query function representation
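ast_to_ttir now places argument types by name rather than by the ordering of src.signature: each signature entry is slotted into position via fn.arg_names, and any argument not named in the signature stays None, which ASTFunction's flatten_ir_types skips. With hypothetical values:

    arg_names = ["x_ptr", "y_ptr", "BLOCK"]           # fn.arg_names for some kernel
    signature = {"y_ptr": "*fp32", "x_ptr": "*fp32"}  # src.signature, keyed by name

    arg_types = [None] * len(arg_names)
    for k, v in signature.items():
        arg_types[arg_names.index(k)] = v             # str_to_ty(v) in the real code
    print(arg_types)   # ['*fp32', '*fp32', None]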
@@ -1396,9 +1497,9 @@ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
      constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves}
      signature = src.signature
      proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
-     generator = CodeGenerator(context, prototype, gscope=fn.__globals__.copy(), function_name=fn.repr(proxy), jit_fn=fn,
-                               is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
-                               codegen_fns=codegen_fns, module_map=module_map)
+     generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
+                               jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
+                               codegen_fns=codegen_fns, module_map=module_map, module=module)
      generator.visit(fn.parse())
      ret = generator.module
      # module takes ownership of the context