triton-windows 3.2.0.post11__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (154)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
@@ -0,0 +1,1235 @@
1
+ import ast
2
+ import textwrap
3
+ import inspect
4
+ from typing import Tuple
5
+
6
+ import math
7
+ import numpy as np
8
+
9
+ import triton
10
+ import triton.language as tl
11
+ from dataclasses import dataclass
12
+ from .errors import InterpreterError
13
+ from functools import partial
14
+ from .._C.libtriton import interpreter as _interpreter
15
+ from .._C.libtriton import ir as _ir
16
+
17
+
18
+ class TensorHandle:
19
+
20
+ def __init__(self, data, dtype):
21
+ '''
22
+ data: numpy array
23
+ dtype: triton type, either pointer_type or scalar_type.
24
+ we don't store block_type here because the shape information is already available in the data field
25
+ attr: a dictionary of attributes
26
+ '''
27
+ self.data = data
28
+ self.dtype = dtype
29
+ self.attr = {}
30
+
31
+ def __bool__(self):
32
+ return bool(self.data.all())
33
+
34
+ def get_element_ty(self):
35
+ dtype = self.dtype
36
+ while hasattr(dtype, "element_ty"):
37
+ dtype = dtype.element_ty
38
+ return dtype
39
+
40
+ def clone(self):
41
+ return TensorHandle(self.data.copy(), self.dtype)
42
+
43
+ def set_attr(self, key, value):
44
+ self.attr[key] = value
45
+
46
+
47
+ class BlockPointerHandle:
48
+
49
+ def __init__(self, base, shape, strides, offsets, tensor_shape, order):
50
+ self.base = base
51
+ self.shape = shape
52
+ self.strides = strides
53
+ self.offsets = offsets
54
+ self.tensor_shape = tensor_shape
55
+ self.order = order
56
+
57
+ def materialize_pointers(self, boundary_check):
58
+ dtype_tt = self.base.get_element_ty()
59
+ n_bytes = dtype_tt.primitive_bitwidth // 8
60
+ tensor_shape = self.tensor_shape
61
+ ptrs = np.broadcast_to(self.base.data, self.tensor_shape)
62
+ masks = np.ones(self.tensor_shape, dtype=bool)
63
+ for dim in range(len(tensor_shape)):
64
+ bcast_dims = [1] * len(tensor_shape)
65
+ bcast_dims[dim] = tensor_shape[dim]
66
+ off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
67
+ ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
68
+ if dim in boundary_check:
69
+ masks = np.logical_and(masks, off < self.shape[dim].data)
70
+ ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
71
+ return ptrs, masks
72
+
73
+
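# --- Illustrative sketch (editorial, not part of the packaged file) ----------
# Materializing the flat pointers and bounds mask for a 2x4 float32 tile that
# starts at element offset (1, 0) of a row-major 4x8 tensor whose storage is
# assumed (hypothetically) to begin at byte address 0x1000.
_i32 = lambda v: TensorHandle(np.array([v], dtype=np.int32), tl.int32)
_base = TensorHandle(np.array([0x1000], dtype=np.uint64), tl.pointer_type(tl.float32))
_bp = BlockPointerHandle(base=_base,
                         shape=[_i32(4), _i32(8)], strides=[_i32(8), _i32(1)],
                         offsets=[_i32(1), _i32(0)], tensor_shape=(2, 4), order=(1, 0))
_ptrs, _masks = _bp.materialize_pointers(boundary_check=(0, 1))
# _ptrs.data[i, j] == 0x1000 + 4 * ((1 + i) * 8 + j); _masks is all True here
# because the 2x4 tile stays inside the 4x8 bounds.
# -----------------------------------------------------------------------------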
74
+ @dataclass(frozen=True)
75
+ class InterpreterOptions:
76
+ extern_libs: dict = None
77
+ debug: bool = False
78
+ sanitize_overflow: bool = True
79
+ arch: str = None
80
+ supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
81
+ deprecated_fp8_dtypes: Tuple[str] = ()
82
+ default_dot_input_precision: str = "tf32"
83
+ allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
84
+ max_num_imprecise_acc_default: int = 0
85
+ backend_name: str = "interpreter"
86
+
87
+
88
+ def _get_signed_np_dtype(dtype):
89
+ if dtype == np.uint8:
90
+ return np.int8
91
+ if dtype == np.uint16:
92
+ return np.int16
93
+ if dtype == np.uint32:
94
+ return np.int32
95
+ if dtype == np.uint64:
96
+ return np.int64
97
+ return dtype
98
+
99
+
100
+ def _get_np_dtype(tt_dtype):
101
+ if isinstance(tt_dtype, tl.pointer_type):
102
+ return np.dtype(np.uint64)
103
+ np_types = {
104
+ tl.int1: np.dtype(bool),
105
+ tl.float16: np.dtype(np.float16),
106
+ tl.float32: np.dtype(np.float32),
107
+ tl.float64: np.dtype(np.float64),
108
+ tl.int8: np.dtype(np.int8),
109
+ tl.uint8: np.dtype(np.uint8),
110
+ tl.int16: np.dtype(np.int16),
111
+ tl.uint16: np.dtype(np.uint16),
112
+ tl.int32: np.dtype(np.int32),
113
+ tl.uint32: np.dtype(np.uint32),
114
+ tl.int64: np.dtype(np.int64),
115
+ tl.uint64: np.dtype(np.uint64),
116
+ # bfloat16 types are stored as uint16
117
+ tl.bfloat16: np.dtype(np.uint16),
118
+ # float8 types are stored as uint8
119
+ tl.float8e5: np.dtype(np.uint8),
120
+ tl.float8e5b16: np.dtype(np.uint8),
121
+ tl.float8e4nv: np.dtype(np.uint8),
122
+ tl.float8e4b8: np.dtype(np.uint8),
123
+ tl.float8e4b15: np.dtype(np.uint8),
124
+ }
125
+ if isinstance(tt_dtype, tl.block_type):
126
+ if isinstance(tt_dtype.element_ty, tl.pointer_type):
127
+ return np.dtype(np.uint64)
128
+ return np_types[tt_dtype.element_ty]
129
+ return np_types[tt_dtype]
130
+
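# --- Illustrative note (editorial, not part of the packaged file) ------------
# Types with no native NumPy equivalent are carried as raw bit patterns, per
# the mapping above:
#   _get_np_dtype(tl.bfloat16)                  -> dtype('uint16')
#   _get_np_dtype(tl.float8e4nv)                -> dtype('uint8')
#   _get_np_dtype(tl.pointer_type(tl.float32))  -> dtype('uint64')  # addresses
# -----------------------------------------------------------------------------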
131
+
132
+ def _convert_float(input, input_dtype, output_dtype, rounding_mode):
133
+ input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}")
134
+ output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}")
135
+ input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype)
136
+ sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01
137
+ input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1
138
+ output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1
139
+ significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1)
140
+ bias_input = input_dtype.exponent_bias
141
+ bias_output = output_dtype.exponent_bias
142
+ exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
143
+ subnormal_index = exponent == 0
144
+ if np.any(subnormal_index):
145
+ # Credit to Phil: phil@openai.com
146
+ # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn))
147
+ # where m0, m1, ..., mn are the 1-bit of the mantissa
148
+ # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0))
149
+ bit_pos = np.zeros_like(input_bin, dtype=np.int32)
150
+ # Find the most significant bit of the mantissa in the significand
151
+ for i in range(input_dtype.fp_mantissa_width):
152
+ bit_index = ((significand >> i) & 0x01)
153
+ # pos should be >= 1
154
+ bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i
155
+ zero_significand_index = significand == 0
156
+ exponent[subnormal_index] = 1 - bit_pos[subnormal_index]
157
+ # 0 significand and subnormal should be treated as 0
158
+ exponent[zero_significand_index & subnormal_index] = bias_input - bias_output
159
+ significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & (
160
+ (1 << input_dtype.fp_mantissa_width) - 1)
161
+ # Prevent overflow and underflow
162
+ exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1))
163
+ exponent_output = exponent_output.astype(output_unint_dtype)
164
+ sign_output = sign.astype(output_unint_dtype)
165
+ if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast
166
+ significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & (
167
+ (1 << output_dtype.fp_mantissa_width) - 1)
168
+ if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearest even
169
+ # find the cut-off bit
170
+ cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1))
171
+ significand_output = significand_output + (cut_off > 0)
172
+ significand_output = significand_output.astype(output_unint_dtype)
173
+ else: # Upcast
174
+ significand_output = (significand.astype(output_unint_dtype) <<
175
+ (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & (
176
+ (1 << output_dtype.fp_mantissa_width) - 1)
177
+ subnormal_index = exponent_output == 0
178
+ if np.any(subnormal_index): # underflow
179
+ # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn))
180
+ # where m0, m1, ..., mn are the 1-bit of the mantissa
181
+ # shift = (1 - exp_bias_output) - (exp - exp_bias_input)
182
+ # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift))
183
+ exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
184
+ non_zero_exponent_index = exponent != 0
185
+ # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa
186
+ subnormal_index = subnormal_index & non_zero_exponent_index
187
+ shift = np.zeros_like(input_bin, dtype=np.int32)
188
+ shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input)
189
+ significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | (
190
+ 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index]))
191
+ output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | (
192
+ exponent_output << output_dtype.fp_mantissa_width) | significand_output
193
+ return output.reshape(input.shape)
194
+
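# --- Illustrative sketch (editorial, not part of the packaged file) ----------
# Round-tripping float32 values through the bit-level bfloat16 conversion
# above. bfloat16 payloads are stored as uint16; the upcast result is re-viewed
# as float32, mirroring what the builder's cast_impl() does for bf16 <-> fp32.
_x = np.array([1.5, -2.25, 3.1415927], dtype=np.float32)
_bf16 = _convert_float(_x, tl.float32, tl.bfloat16, None)  # uint16 bit patterns
_back = _convert_float(_bf16, tl.bfloat16, tl.float32, None).view(np.float32)
# 1.5 and -2.25 survive exactly; 3.1415927 is truncated to roughly 3.140625
# because no rounding mode is requested here.
# -----------------------------------------------------------------------------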
195
+
196
+ def _erf(x):
197
+ # Numpy does not support erf
198
+ return math.erf(x)
199
+
200
+
201
+ def _umulhi_64(a, b):
202
+ # Numpy does not support 128-bit multiplication
203
+ # So we have to implement it manually
204
+ return (int(a) * int(b)) >> 64
205
+
206
+
207
+ np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32])
208
+ np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64])
209
+ np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64])
210
+
211
+
212
+ class ExtraFunctions:
213
+
214
+ @staticmethod
215
+ def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _builder):
216
+ return tl.tensor(_builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty)
217
+
218
+
219
+ class InterpreterBuilder:
220
+ ir_sem_to_interpreter_sem = {
221
+ _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE,
222
+ _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE,
223
+ _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED,
224
+ _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE,
225
+ }
226
+
227
+ ir_rmw_op_to_interpreter_rmw_op = {
228
+ _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD,
229
+ _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD,
230
+ _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN,
231
+ _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN,
232
+ _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX,
233
+ _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX,
234
+ _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND,
235
+ _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR,
236
+ _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR,
237
+ _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG,
238
+ }
239
+
240
+ def __init__(self) -> None:
241
+ self.arch = None
242
+ self.options = InterpreterOptions()
243
+ self.codegen_fns = {}
244
+ self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types
245
+ self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (16, 16, 16)
246
+
247
+ def set_grid_idx(self, x, y, z):
248
+ if not x < self.grid_dim[0]:
249
+ raise ValueError("x >= grid_dim[0]")
250
+ if not y < self.grid_dim[1]:
251
+ raise ValueError("y >= grid_dim[1]")
252
+ if not z < self.grid_dim[2]:
253
+ raise ValueError("z >= grid_dim[2]")
254
+ self.grid_idx = (x, y, z)
255
+
256
+ def set_grid_dim(self, nx, ny, nz):
257
+ self.grid_dim = (nx, ny, nz)
258
+
259
+ # constants
260
+
261
+ def get_half_ty(self):
262
+ return tl.float16
263
+
264
+ def get_bf16_ty(self):
265
+ return tl.bfloat16
266
+
267
+ def get_float_ty(self):
268
+ return tl.float32
269
+
270
+ def get_double_ty(self):
271
+ return tl.float64
272
+
273
+ def get_int8_ty(self):
274
+ return tl.int8
275
+
276
+ def get_uint8_ty(self):
277
+ return tl.uint8
278
+
279
+ def get_int16_ty(self):
280
+ return tl.int16
281
+
282
+ def get_uint16_ty(self):
283
+ return tl.uint16
284
+
285
+ def get_int32_ty(self):
286
+ return tl.int32
287
+
288
+ def get_uint32_ty(self):
289
+ return tl.uint32
290
+
291
+ def get_int64_ty(self):
292
+ return tl.int64
293
+
294
+ def get_uint64_ty(self):
295
+ return tl.uint64
296
+
297
+ def get_fp8e4nv_ty(self):
298
+ return tl.float8e4nv
299
+
300
+ def get_fp8e4b15_ty(self):
301
+ return tl.float8e4b15
302
+
303
+ def get_fp8e4b8_ty(self):
304
+ return tl.float8e4b8
305
+
306
+ def get_fp8e5_ty(self):
307
+ return tl.float8e5
308
+
309
+ def get_fp8e5b16_ty(self):
310
+ return tl.float8e5b16
311
+
312
+ def get_ptr_ty(self, elt_ty, addr_space):
313
+ return tl.pointer_type(elt_ty, addr_space)
314
+
315
+ def get_block_ty(self, dtype, shape):
316
+ return tl.block_type(dtype, shape)
317
+
318
+ def get_int1(self, value):
319
+ return TensorHandle(np.array([value], dtype=np.bool_), tl.int1)
320
+
321
+ def get_uint8(self, value):
322
+ return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8)
323
+
324
+ def get_int8(self, value):
325
+ return TensorHandle(np.array([value], dtype=np.int8), tl.int8)
326
+
327
+ def get_uint16(self, value):
328
+ return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16)
329
+
330
+ def get_int16(self, value):
331
+ return TensorHandle(np.array([value], dtype=np.int16), tl.int16)
332
+
333
+ def get_uint32(self, value):
334
+ return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32)
335
+
336
+ def get_int32(self, value):
337
+ return TensorHandle(np.array([value], dtype=np.int32), tl.int32)
338
+
339
+ def get_uint64(self, value):
340
+ return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64)
341
+
342
+ def get_int64(self, value):
343
+ return TensorHandle(np.array([value], dtype=np.int64), tl.int64)
344
+
345
+ def get_fp16(self, value):
346
+ return TensorHandle(np.array([value], dtype=np.float16), tl.float16)
347
+
348
+ def get_fp32(self, value):
349
+ return TensorHandle(np.array([value], dtype=np.float32), tl.float32)
350
+
351
+ def get_fp64(self, value):
352
+ return TensorHandle(np.array([value], dtype=np.float64), tl.float64)
353
+
354
+ def get_null_value(self, type):
355
+ return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type)
356
+
357
+ # programming model
358
+ def create_get_program_id(self, axis):
359
+ if self.grid_idx is None:
360
+ raise ValueError("grid_idx is None")
361
+ return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32)
362
+
363
+ def create_get_num_programs(self, axis):
364
+ return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32)
365
+
366
+ # memory ops
367
+ def create_load(self, ptr, _0, _1, is_volatile):
368
+ mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
369
+ other = None
370
+ return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile)
371
+
372
+ def create_store(self, ptr, val, _0, _1):
373
+ mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
374
+ return self.create_masked_store(ptr, val, mask, None, None)
375
+
376
+ def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
377
+ dtype_tt = ptrs.get_element_ty()
378
+ dtype_np = _get_np_dtype(dtype_tt)
379
+ if other is None:
380
+ other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
381
+ ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np)
382
+ return TensorHandle(ret, dtype_tt)
383
+
384
+ def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy):
385
+ return _interpreter.store(ptrs.data, value.data, mask.data)
386
+
387
+ # casting ops
388
+ def cast_impl(self, src, dst_type):
389
+ src_element_type = src.dtype.scalar
390
+ dst_element_type = dst_type.scalar
391
+ if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \
392
+ (src_element_type == tl.float32 and dst_element_type == tl.bfloat16):
393
+ data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type))
394
+ return TensorHandle(data, dst_type.scalar)
395
+ else:
396
+ return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar)
397
+
398
+ create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
399
+ create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
400
+ create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type)
401
+ create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type)
402
+ create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type)
403
+ create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type)
404
+ create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type)
405
+
406
+ def create_fp_to_fp(self, src, dst_type, rounding_mode):
407
+ src_element_type = src.dtype.scalar
408
+ dst_element_type = dst_type.scalar
409
+ data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type))
410
+ return TensorHandle(data, dst_type.scalar)
411
+
412
+ def create_bitcast(self, src, dst_type):
413
+ return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar)
414
+
415
+ # binary operators
416
+ def binary_op(self, lhs, rhs, op):
417
+ return TensorHandle(op(lhs.data, rhs.data), lhs.dtype.scalar)
418
+
419
+ create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
420
+ create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
421
+ create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
422
+ create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
423
+ create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
424
+ create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
425
+ create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
426
+ create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs)
427
+ create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs)
428
+ # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
429
+ create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
430
+ create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
431
+ create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
432
+ create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
433
+ create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift)
434
+ create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
435
+ create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
436
+ create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
437
+ create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
438
+ create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
439
+ create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
440
+ create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
441
+ create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
442
+ create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
443
+ create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
444
+ create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
445
+ create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
446
+ create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
447
+ create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
448
+ create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
449
+ create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
450
+ create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
451
+ create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
452
+ create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
453
+ create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
454
+ create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
455
+ create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
456
+ create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
457
+ create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
458
+ create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
459
+ create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
460
+ create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
461
+ create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
462
+ create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
463
+ create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
464
+ create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
465
+ create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and)
466
+ create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor)
467
+ create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or)
468
+ create_int_to_ptr = create_bitcast
469
+ create_ptr_to_int = create_bitcast
470
+
471
+ def create_idiv(self, lhs, rhs):
472
+ # Triton has IEEE, not numpy/torch, semantics for %, and those carry
473
+ # through to //, so we have to use a nonstandard expression to get a
474
+ # reference result for //.
475
+ return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar)
476
+
477
+ def create_ashr(self, lhs, rhs):
478
+ # Triton's rshift operator depends on the signedness of the left operand
479
+ lhs_dtype = _get_signed_np_dtype(lhs.data.dtype)
480
+ rhs_dtype = _get_signed_np_dtype(rhs.data.dtype)
481
+ lhs.data = lhs.data.astype(lhs_dtype)
482
+ rhs.data = rhs.data.astype(rhs_dtype)
483
+ return self.binary_op(lhs, rhs, np.right_shift)
484
+
485
+ def create_umulhi(self, lhs, rhs):
486
+ dtype = lhs.data.dtype
487
+ if dtype == np.int64 or dtype == np.uint64:
488
+ return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar)
489
+ else:
490
+ compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}")
491
+ lhs_data = lhs.data.astype(compute_dtype)
492
+ rhs_data = rhs.data.astype(compute_dtype)
493
+ ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8)
494
+ return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar)
495
+
496
+ # ternary functions
497
+ def ternary_op(self, lhs, rhs, other, op):
498
+ return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype.scalar)
499
+
500
+ create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip)
501
+ create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
502
+
503
+ def create_fma(self, x, y, z):
504
+ return TensorHandle(x.data * y.data + z.data, z.dtype.scalar)
505
+
506
+ # unary functions
507
+ def unary_op(self, arg, op):
508
+ return TensorHandle(op(arg.data), arg.dtype.scalar)
509
+
510
+ def create_fabs(self, arg):
511
+ # Mask out the sign bit based on the primitive length
512
+ dtype_tt = arg.dtype
513
+ mask_bitwidth = dtype_tt.primitive_bitwidth - 1
514
+ np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}")
515
+ data = arg.data.view(np_uint_dtype)
516
+ mask = (1 << mask_bitwidth) - 1
517
+ ret = (data & mask).view(_get_np_dtype(dtype_tt))
518
+ return TensorHandle(ret, arg.dtype.scalar)
519
+
520
+ create_cos = lambda self, arg: self.unary_op(arg, np.cos)
521
+ create_exp = lambda self, arg: self.unary_op(arg, np.exp)
522
+ create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2)
523
+ create_iabs = lambda self, arg: self.unary_op(arg, np.abs)
524
+ create_floor = lambda self, arg: self.unary_op(arg, np.floor)
525
+ create_ceil = lambda self, arg: self.unary_op(arg, np.ceil)
526
+ create_log = lambda self, arg: self.unary_op(arg, np.log)
527
+ create_log2 = lambda self, arg: self.unary_op(arg, np.log2)
528
+ create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
529
+ create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
530
+ create_sin = lambda self, arg: self.unary_op(arg, np.sin)
531
+
532
+ def create_erf(self, arg):
533
+ ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data)
534
+ return TensorHandle(ret, arg.dtype.scalar)
535
+
536
+ def create_rsqrt(self, arg):
537
+ return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar)
538
+
539
+ # tensor operators
540
+ create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar)
541
+
542
+ def create_trans(self, arg, perm):
543
+ return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar)
544
+
545
+ def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc):
546
+ a_data = a.data
547
+ b_data = b.data
548
+ if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \
549
+ (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()):
550
+ a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16)
551
+ b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16)
552
+ return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar)
553
+
554
+ def create_make_range(self, start, stop):
555
+ return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)
556
+
557
+ def create_histogram(self, data, bins):
558
+ return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32)
559
+
560
+ # pointer arithmetic
561
+
562
+ def create_addptr(self, ptr, offset):
563
+ dtype_tt = ptr.get_element_ty()
564
+ element_bitwidth = dtype_tt.primitive_bitwidth
565
+ # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic
566
+ element_bytewidth = max(1, element_bitwidth // 8)
567
+ return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype)
568
+
569
+ def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
570
+ is_volatile):
571
+ ptrs, masks = ptr.materialize_pointers(boundary_check)
572
+ dtype_tt = ptrs.get_element_ty()
573
+ dtype_np = _get_np_dtype(dtype_tt)
574
+ if padding_option is None:
575
+ other = None
576
+ elif padding_option == _ir.PADDING_OPTION.PAD_ZERO:
577
+ other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
578
+ elif padding_option == _ir.PADDING_OPTION.PAD_NAN:
579
+ other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
580
+ else:
581
+ raise ValueError(f"unsupported padding option {padding_option}")
582
+ return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile)
583
+
584
+ def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy):
585
+ ptrs, masks = ptr.materialize_pointers(boundary_check)
586
+ return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy)
587
+
588
+ def create_expand_dims(self, arg, axis):
589
+ return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar)
590
+
591
+ def create_broadcast(self, arg, shape):
592
+ return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar)
593
+
594
+ def create_cat(self, lhs, rhs):
595
+ return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar)
596
+
597
+ def create_join(self, lhs, rhs):
598
+ # Triton only supports joining two original tensors into a new one along the last axis
599
+ return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar)
600
+
601
+ def create_split(self, val):
602
+ # Triton only supports splitting the original tensor into two along the last axis
603
+ return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar))
604
+
605
+ def create_splat(self, arg, shape):
606
+ if isinstance(arg.dtype, tl.block_type):
607
+ return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
608
+ else: # scalar
609
+ return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
610
+
611
+ def create_atomic_cas(self, ptr, cmp, val, sem, scope):
612
+ if sem not in self.ir_sem_to_interpreter_sem:
613
+ raise ValueError(f"unsupported semantic {sem}")
614
+ sem = self.ir_sem_to_interpreter_sem[sem]
615
+ return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar)
616
+
617
+ def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope):
618
+ if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op:
619
+ raise ValueError(f"unsupported rmwOp {rmwOp}")
620
+ if sem not in self.ir_sem_to_interpreter_sem:
621
+ raise ValueError(f"unsupported semantic {sem}")
622
+ rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp]
623
+ sem = self.ir_sem_to_interpreter_sem[sem]
624
+ return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar)
625
+
626
+ def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure):
627
+ raise NotImplementedError("extern_elementwise not supported in interpreter mode")
628
+
629
+ def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack):
630
+ raise NotImplementedError("inline_asm not supported in interpreter mode")
631
+
632
+ def create_print(self, prefix, hex, values, isSigned):
633
+ # NOTE: the `isSigned` variable is not really used here, because signedness is already known
634
+ # by the `values` themselves in the Python interpreter;
635
+ # it is only used by Triton's PrintOpToLLVM to correctly construct the format specifier.
636
+ # Interpreter's device_print function has a different format than Triton's device_print
637
+ msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})"
638
+ if prefix:
639
+ msg += f" {prefix}"
640
+ if hex:
641
+ np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"})
642
+ for value in values:
643
+ print(msg + f" {value.data}")
644
+ if hex:
645
+ np.set_printoptions(formatter=None)
646
+
647
+ def create_assert(self, condition, message):
648
+ # Interpreter's device_assert function has a different format than Triton's device_assert
649
+ assert condition, f"{message}"
650
+
651
+ def create_assume(self, condition):
652
+ assert condition, "Assume failed"
653
+
654
+ def create_barrier(self):
655
+ # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
656
+ pass
657
+
658
+ def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
659
+ # Create new offsets to avoid modifying the original
660
+ new_offsets = [offset.clone() for offset in offsets]
661
+ return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order)
662
+
663
+ def create_advance(self, ptr, offsets):
664
+ if len(ptr.offsets) != len(offsets):
665
+ raise ValueError("len(ptr.offsets) != len(offsets)")
666
+ # Create new offsets to avoid modifying the original
667
+ new_offsets = [offset.clone() for offset in ptr.offsets]
668
+ ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order)
669
+ for i in range(len(offsets)):
670
+ ret.offsets[i].data += offsets[i].data
671
+ return ret
672
+
673
+ def get_all_ones_value(self, type):
674
+ np_type = _get_np_dtype(type)
675
+ if "int" in np_type.name:
676
+ return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar)
677
+ else:
678
+ raise TypeError(f"unsupported type {type}")
679
+
680
+
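# --- Illustrative sketch (editorial, not part of the packaged file) ----------
# The builder's "IR construction" methods execute eagerly on NumPy-backed
# TensorHandles, so a few of them compose like ordinary array code:
_b = InterpreterBuilder()
_rng = _b.create_make_range(0, 8)               # int32 TensorHandle [0, 1, ..., 7]
_ten = _b.create_splat(_b.get_int32(10), [8])   # splat the scalar 10 to shape [8]
_out = _b.create_add(_rng, _ten)                # elementwise np.add
# _out.data == np.array([10, 11, ..., 17], dtype=np.int32)
# -----------------------------------------------------------------------------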
681
+ def _patch_attr(obj, name, member, builder):
682
+ new_member = lambda *args, member=member, **kwargs: (member(*args, **
683
+ {k: v
684
+ for k, v in kwargs.items()
685
+ if k != "_builder"}, _builder=builder))
686
+ setattr(obj, name, new_member)
687
+
688
+
689
+ def _patch_builtin(pkg, builder):
690
+ for name, member in inspect.getmembers(pkg):
691
+ if tl.core.is_builtin(member):
692
+ _patch_attr(pkg, name, member, builder)
693
+
694
+
695
+ def _patch_lang_tensor(tensor):
696
+
697
+ def _get_bool(self):
698
+ data = self.handle.data
699
+ # in triton, only scalars can be converted to booleans
700
+ # here we need this hack because all scalars are tensors
701
+ return bool(data) if data.size == 1 else True
702
+
703
+ def _get_transpose(self):
704
+ return tl.core.tensor(TensorHandle(np.transpose(self.handle.data), self.handle.dtype), self.dtype.scalar)
705
+
706
+ tensor.__index__ = lambda self: int(self.handle.data)
707
+ tensor.__bool__ = lambda self: _get_bool(self)
708
+ tensor.__repr__ = lambda self: repr(self.handle.data)
709
+ tensor.__str__ = lambda self: str(self.handle.data)
710
+ tensor.T = property(_get_transpose)
711
+
712
+
713
+ class ReduceScanOpIneterface:
714
+
715
+ def __init__(self, axis, combine_fn):
716
+ self.axis = axis
717
+ self.combine_fn = combine_fn
718
+
719
+ def check_axis(self, shape, axis):
720
+ if axis is not None and axis >= len(shape):
721
+ raise ValueError(f"axis {axis} out of bounds for shape {shape}")
722
+
723
+ def check_tensor(self, input):
724
+ for arg in input:
725
+ if not isinstance(arg, tl.core.tensor):
726
+ raise ValueError(f"input must be a tensor, got {type(arg)}")
727
+ self.check_axis(arg.shape, self.axis)
728
+
729
+ def to_tensor(self, ret, dtype):
730
+ if hasattr(ret, "shape") and ret.shape:
731
+ ret_type = tl.block_type(dtype, ret.shape)
732
+ else:
733
+ ret = np.array([ret]).astype(_get_np_dtype(dtype))
734
+ ret_type = dtype
735
+ return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type)
736
+
737
+ def apply(self, input):
738
+ if not isinstance(input, tuple):
739
+ input = (input, )
740
+ self.check_tensor(input)
741
+ return self.apply_impl(input)
742
+
743
+ def apply_impl(self, input):
744
+ raise NotImplementedError("apply_impl not implemented")
745
+
746
+
747
+ class ReduceOps(ReduceScanOpIneterface):
748
+
749
+ def __init__(self, axis, combine_fn, keep_dims):
750
+ super().__init__(axis, combine_fn)
751
+ self.keep_dims = keep_dims
752
+
753
+ def unravel(self, input, axis):
754
+ ret = []
755
+ for data in input:
756
+ if axis is not None:
757
+ ret.append(data)
758
+ else:
759
+ axis = 0
760
+ ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype))
761
+ return tuple(ret), axis
762
+
763
+ def generic_reduce(self, input):
764
+ original_axis = self.axis
765
+ input, axis = self.unravel(input, self.axis)
766
+ input_data = []
767
+ output_data = []
768
+ input_shape = input[0].handle.data.shape
769
+ output_shape = input_shape[0:axis] + input_shape[axis + 1:]
770
+ for arg in input:
771
+ input_data.append(arg.handle.data)
772
+ output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype))
773
+ # Reduce on axis
774
+ for i in range(input_data[0].size):
775
+ # Recover input_index from i using input_shape
776
+ input_index = np.unravel_index(i, input_shape)
777
+ output_index = input_index[0:axis] + input_index[axis + 1:]
778
+ input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data))
779
+ if input_index[axis] == 0:
780
+ # First element
781
+ for j in range(len(output_data)):
782
+ output_data[j][output_index] = input_tuple[j].handle.data.item()
783
+ else:
784
+ acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data))
785
+ combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple)
786
+ acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret
787
+ for j in range(len(output_data)):
788
+ output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance(
789
+ acc_tuple[j], tl.core.tensor) else acc_tuple[j]
790
+ # Pack output
791
+ ret = []
792
+ for i, data in enumerate(output_data):
793
+ if self.keep_dims:
794
+ if original_axis is not None:
795
+ data = np.expand_dims(data, axis)
796
+ else:
797
+ for _ in range(len(input_shape)):
798
+ data = np.expand_dims(data, 0)
799
+
800
+ elif original_axis is None:
801
+ # Take a scalar
802
+ data = data.item()
803
+ ret.append(self.to_tensor(data, input[i].dtype))
804
+ return ret[0] if len(ret) == 1 else tuple(ret)
805
+
806
+ def min_max(self, input, val_reduce_op, idx_reduce_op=None):
807
+ # If input is a tuple, it must be (val, index), and we only take val
808
+ input = input[0] if isinstance(input, tuple) else input
809
+ val = None
810
+ idx = None
811
+ if val_reduce_op:
812
+ val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype)
813
+ if idx_reduce_op:
814
+ idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32)
815
+ if val is not None and idx is not None:
816
+ return val, idx
817
+ elif val is not None:
818
+ return val
819
+ elif idx is not None:
820
+ return idx
821
+ else:
822
+ raise ValueError("val_reduce_op and idx_reduce_op are both None")
823
+
824
+ def sum(self, input):
825
+ return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype)
826
+
827
+ def apply_impl(self, input):
828
+ if self.combine_fn == tl.standard._argmin_combine_tie_break_left:
829
+ return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin)
830
+ elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
831
+ return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
832
+ elif self.combine_fn == tl.standard._elementwise_max:
833
+ return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None)
834
+ elif self.combine_fn == tl.standard._elementwise_min:
835
+ return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None)
836
+ elif self.combine_fn == tl.standard._sum_combine:
837
+ return self.sum(input[0])
838
+ else:
839
+ # Fall back to the slow mode
840
+ return self.generic_reduce(input)
841
+
842
+
843
+ class ScanOps(ReduceScanOpIneterface):
844
+
845
+ def __init__(self, axis, combine_fn, reverse):
846
+ super().__init__(axis, combine_fn)
847
+ self.reverse = reverse
848
+
849
+ def cumsum(self, input):
850
+ return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)]
851
+
852
+ def cumprod(self, input):
853
+ return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)]
854
+
855
+ def generic_scan(self, input):
856
+ input_data = []
857
+ output_data = []
858
+ shape = input[0].handle.data.shape
859
+ for arg in input:
860
+ input_data.append(arg.handle.data)
861
+ output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype))
862
+ # Scan on axis
863
+ for i in range(input_data[0].size):
864
+ # Recover index from i using shape
865
+ index = np.unravel_index(i, shape)
866
+ data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data))
867
+ if index[self.axis] == 0:
868
+ # First element
869
+ for j in range(len(output_data)):
870
+ output_data[j][index] = data[j].handle.data.item()
871
+ else:
872
+ prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index)))
873
+ acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data))
874
+ combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data)
875
+ acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret
876
+ for j in range(len(output_data)):
877
+ output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance(
878
+ acc_tuple[j], tl.core.tensor) else acc_tuple[j]
879
+ # Pack output
880
+ ret = []
881
+ for i, data in enumerate(output_data):
882
+ ret.append(self.to_tensor(data, input[i].dtype))
883
+ return ret
884
+
885
+ def apply_impl(self, input):
886
+ new_input = []
887
+ if self.reverse:
888
+ for arg in input:
889
+ new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype))
890
+ else:
891
+ new_input = input
892
+ if self.combine_fn == tl.standard._sum_combine:
893
+ ret = self.cumsum(new_input[0])
894
+ elif self.combine_fn == tl.standard._prod_combine:
895
+ ret = self.cumprod(new_input[0])
896
+ else:
897
+ # Fall back to the slow mode
898
+ ret = self.generic_scan(new_input)
899
+ if self.reverse:
900
+ for arg in ret:
901
+ arg.handle.data = np.flip(arg.handle.data, axis=self.axis)
902
+ return len(ret) == 1 and ret[0] or tuple(ret)
903
+
904
+
905
+ def _patch_reduce_scan():
906
+ # Because the interpreter doesn't support region_builder_fn, we cannot patch the builder
907
+ # to use the new reduce and scan functions.
908
+ # Instead, we need to patch the reduce and scan functions in tl and tl.core
909
+ def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs):
910
+ return ReduceOps(axis, combine_fn, keep_dims).apply(input)
911
+
912
+ def _new_scan(input, axis, combine_fn, reverse=False, **kwargs):
913
+ return ScanOps(axis, combine_fn, reverse).apply(input)
914
+
915
+ tl.reduce = _new_reduce
916
+ tl.associative_scan = _new_scan
917
+ tl.core.reduce = _new_reduce
918
+ tl.core.associative_scan = _new_scan
919
+
920
+
921
+ def _patch_lang_core(lang):
922
+
923
+ def _new_to_ir(self, builder):
924
+ # We need to specify signedness for integer types in the numpy mode
925
+ if self.name == 'void':
926
+ return builder.get_void_ty()
927
+ elif self.name == 'int1':
928
+ return builder.get_int1_ty()
929
+ elif self.name == 'int8':
930
+ return builder.get_int8_ty()
931
+ elif self.name == 'uint8':
932
+ return builder.get_uint8_ty()
933
+ elif self.name == 'int16':
934
+ return builder.get_int16_ty()
935
+ elif self.name == 'uint16':
936
+ return builder.get_uint16_ty()
937
+ elif self.name == 'int32':
938
+ return builder.get_int32_ty()
939
+ elif self.name == 'uint32':
940
+ return builder.get_uint32_ty()
941
+ elif self.name == 'int64':
942
+ return builder.get_int64_ty()
943
+ elif self.name == 'uint64':
944
+ return builder.get_uint64_ty()
945
+ elif self.name == 'fp8e5':
946
+ return builder.get_fp8e5_ty()
947
+ elif self.name == 'fp8e4nv':
948
+ return builder.get_fp8e4nv_ty()
949
+ elif self.name == 'fp8e4b15':
950
+ return builder.get_fp8e4b15_ty()
951
+ elif self.name == 'fp16':
952
+ return builder.get_half_ty()
953
+ elif self.name == 'bf16':
954
+ return builder.get_bf16_ty()
955
+ elif self.name == 'fp32':
956
+ return builder.get_float_ty()
957
+ elif self.name == 'fp64':
958
+ return builder.get_double_ty()
959
+ raise ValueError(f'fail to convert {self} to ir type')
960
+
961
+ # can't just map lang.static_range to `range`, because `tl.static_range`
962
+ # can get `step` passed by keyword
963
+ def _new_range(arg1, arg2=None, step=None, **kwargs):
964
+ if step is None:
965
+ step = 1
966
+ if arg2 is None:
967
+ start, end = 0, arg1
968
+ else:
969
+ start, end = arg1, arg2
970
+ return range(start, end, step)
971
+
972
+ def _new_static_assert(cond, msg=""):
973
+ assert cond, msg
974
+
975
+ def _set_attr(input, values, name):
976
+ # Skip non-tensor types; this may happen for induction variables.
977
+ if not isinstance(input, tl.tensor):
978
+ return input
979
+ # Unwrap constexpr
980
+ values = [values] if not isinstance(values, (list, tuple)) else values
981
+ values = [v.value if isinstance(v, tl.constexpr) else v for v in values]
982
+ if len(values) != max(1, len(input.shape)):
983
+ raise ValueError(f"len(values) != len(input.shape) for {name}")
984
+ input.handle.set_attr(name, values)
985
+ return input
986
+
987
+ lang.range = _new_range
988
+ lang.static_range = _new_range
989
+ lang.static_assert = _new_static_assert
990
+ lang.static_print = print
991
+ lang.dtype.to_ir = _new_to_ir
992
+ lang.multiple_of = partial(_set_attr, name="tt.divisiblity")
993
+ lang.max_contiguous = partial(_set_attr, name="tt.contiguity")
994
+ lang.max_constancy = partial(_set_attr, name="tt.constancy")
995
+
996
+ _patch_reduce_scan()
997
+
998
+
999
+ def _patch_lang(fn):
1000
+ langs = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]]
1001
+ assert len(langs) >= 1, "triton.language must be visible from within jit'd function"
1002
+ for lang in langs:
1003
+ _patch_builtin(lang, interpreter_builder)
1004
+ _patch_builtin(lang.tensor, interpreter_builder)
1005
+ if lang == tl:
1006
+ _patch_builtin(lang.math, interpreter_builder)
1007
+ _patch_lang_tensor(lang.tensor)
1008
+ _patch_lang_core(lang)
1009
+
1010
+
1011
+ # TODO: wrap everything in triton tensors
1012
+ def _implicit_cvt(arg):
1013
+ if isinstance(arg, int):
1014
+ ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
1015
+ dtype = np.int32
1016
+ if -2**31 <= arg < 2**31:
1017
+ dtype = np.int32
1018
+ elif 2**31 <= arg < 2**32:
1019
+ dtype = np.uint32
1020
+ elif -2**63 <= arg < 2**63:
1021
+ dtype = np.int64
1022
+ elif 2**63 <= arg < 2**64:
1023
+ dtype = np.uint64
1024
+ else:
1025
+ raise ValueError(f"Unsupported integer value {arg}")
1026
+ handle = TensorHandle(np.array([arg], dtype=dtype), ty)
1027
+ return tl.tensor(handle, ty)
1028
+ if hasattr(arg, "data_ptr"):
1029
+ ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
1030
+ handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
1031
+ return tl.tensor(handle, ty)
1032
+ return arg
1033
+
1034
+
1035
+ interpreter_builder = InterpreterBuilder()
1036
+
1037
+ # These keywords are not supported by the interpreter
1038
+ RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"]
1039
+
1040
+
1041
+ class GridExecutor:
1042
+
1043
+ def __init__(self, fn, arg_names, grid):
1044
+ from .jit import _normalize_ty # TODO: modularize
1045
+
1046
+ self.fn = fn
1047
+ self.arg_names = arg_names
1048
+ self.grid = grid
1049
+ __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
1050
+ self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
1051
+
1052
+ def _init_args_hst(self, args_dev, kwargs):
1053
+ args_hst = []
1054
+ for arg in args_dev:
1055
+ if hasattr(arg, "data_ptr"):
1056
+ args_hst.append(arg.cpu())
1057
+ else:
1058
+ args_hst.append(arg)
1059
+ # Process keyword arguments
1060
+ kwargs_hst = {}
1061
+ for key, value in kwargs.items():
1062
+ if hasattr(value, "data_ptr"):
1063
+ kwargs_hst[key] = value.cpu()
1064
+ else:
1065
+ kwargs_hst[key] = value
1066
+ return args_hst, kwargs_hst
1067
+
1068
+ def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
1069
+ for arg_dev, arg_hst in zip(args_dev, args_hst):
1070
+ if hasattr(arg_dev, "data_ptr"):
1071
+ arg_dev.data.copy_(arg_hst.to(arg_dev.device).data)
1072
+
1073
+ # Restore keyword arguments
1074
+ for key, kwarg_dev in kwargs.items():
1075
+ kwarg_hst = kwargs_hst[key]
1076
+ if hasattr(kwarg_dev, "data_ptr"):
1077
+ kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data)
1078
+
1079
+ def __call__(self, *args_dev, **kwargs):
1080
+ # removes reserved keywords from kwargs
1081
+ kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
1082
+ if kwargs.pop("warmup", False):
1083
+ return
1084
+ # copy arguments to the host
1085
+ args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
1086
+ # remaps core language functions to interpreted ones
1087
+ _patch_lang(self.fn)
1088
+ # we need to copy arguments to the host for the interpreter
1089
+ # implicitly convert tensor arguments to their base pointers
1090
+ args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
1091
+ args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()}
1092
+ # iterate through grid
1093
+ grid = self.grid(args) if callable(self.grid) else self.grid
1094
+ assert len(grid) <= 3, "grid must have at most 3 dimensions"
1095
+ grid = grid + (1, ) * (3 - len(grid))
1096
+ interpreter_builder.set_grid_dim(*grid)
1097
+ try:
1098
+ for x in range(grid[0]):
1099
+ for y in range(grid[1]):
1100
+ for z in range(grid[2]):
1101
+ interpreter_builder.set_grid_idx(x, y, z)
1102
+ self.fn(**args)
1103
+ except Exception as e:
1104
+ raise InterpreterError(repr(e)) from e
1105
+ # copy arguments back to propagate side-effects
1106
+ self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
1107
+
1108
+
1109
+ class ASTTransformer(ast.NodeTransformer):
1110
+
1111
+ def visit_Assign(self, node):
1112
+ names = []
1113
+ for target in node.targets:
1114
+ names += [self.visit(target)]
1115
+ if len(names) > 1:
1116
+ raise ValueError("Multiple assignments are not supported")
1117
+ # Modify the assignment x = value to
1118
+ # triton.language.semantic.to_tensor(value, interpreter_builder, False)
1119
+ node.value = ast.Call(
1120
+ func=ast.Attribute(
1121
+ value=ast.Attribute(
1122
+ value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()),
1123
+ attr='semantic', ctx=ast.Load()), attr='to_tensor', ctx=ast.Load()),
1124
+ args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()),
1125
+ ast.Constant(value=False)], keywords=[])
1126
+ return node
1127
+
1128
+
1129
+ class FunctionRewriter:
1130
+ ast_transformer = ASTTransformer()
1131
+
1132
+ def __init__(self, fn, **kwargs):
1133
+ self.fn = fn
1134
+ self.kwargs = kwargs
1135
+ self.filename: str = ""
1136
+ # Absolute line number in the file
1137
+ self.def_file_lineno: int = 0
1138
+
1139
+ def rewrite_ast(self):
1140
+ # If an exception is raised, it means the function does not have source code available,
1141
+ # e.g., dynamically generated functions; we cannot rewrite it, so just return the original function
1142
+ try:
1143
+ lines, _ = inspect.getsourcelines(self.fn)
1144
+ except Exception:
1145
+ return self.fn
1146
+
1147
+ # truncate lines before def
1148
+ # @triton.autotune(...)
1149
+ # ...
1150
+ # @triton.jit
1151
+ # ...
1152
+ # def foo(...): <- this line is the function definition
1153
+ self.filename, self.def_file_lineno = self._get_jit_fn_file_line()
1154
+ self.def_lineno = self._find_def(lines)
1155
+ src = self._prepare_source(lines)
1156
+ transformed_ast = self._transform_ast(src)
1157
+ return self._compile_and_exec(transformed_ast)
1158
+
1159
+ def _get_jit_fn_file_line(self):
1160
+ from .jit import get_jit_fn_file_line, JITFunction
1161
+ return get_jit_fn_file_line(JITFunction(self.fn))
1162
+
1163
+ def _find_def(self, lines):
1164
+ def_lineno = 0
1165
+ # Line numbers start from 1
1166
+ for i, line in enumerate(lines):
1167
+ if line.strip().startswith("def "):
1168
+ def_lineno = i + 1
1169
+ return def_lineno
1170
+
1171
+ def _prepare_source(self, lines):
1172
+ lines = lines[self.def_lineno - 1:]
1173
+ src = ''.join(lines)
1174
+ return textwrap.dedent(src)
1175
+
1176
+ def _transform_ast(self, src):
1177
+ # src is like:
1178
+ # 1: def foo(...):
1179
+ # 2: ...
1180
+ parsed_ast = ast.parse(src)
1181
+ transformed_ast = self.ast_transformer.visit(parsed_ast)
1182
+ ast.fix_missing_locations(transformed_ast)
1183
+ inc_lineno = self.def_file_lineno - 1
1184
+ ast.increment_lineno(transformed_ast, inc_lineno)
1185
+ return transformed_ast
1186
+
1187
+ def _compile_and_exec(self, transformed_ast):
1188
+ compiled_code = compile(transformed_ast, filename=self.filename, mode='exec')
1189
+ local_namespace = {**self.kwargs}
1190
+ fn_globals = self.fn.__globals__
1191
+ for key, value in globals().items():
1192
+ if key not in fn_globals:
1193
+ fn_globals[key] = value
1194
+ exec(compiled_code, fn_globals, local_namespace)
1195
+ return local_namespace[self.fn.__name__]
1196
+
1197
+
1198
+ class InterpretedFunction:
1199
+ # Cache all rewritten functions
1200
+ rewritten_fn = {}
1201
+
1202
+ def __init__(self, fn, **kwargs) -> None:
1203
+ self.fn = fn
1204
+ self.rewriter = FunctionRewriter(fn, **kwargs)
1205
+
1206
+ def run(*args, **kwargs):
1207
+ grid = kwargs["grid"]
1208
+ fn = self.rewrite()
1209
+ return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs)
1210
+
1211
+ self.run = run
1212
+ signature = inspect.signature(fn)
1213
+ self.arg_names = [v.name for v in signature.parameters.values()]
1214
+
1215
+ def rewrite(self):
1216
+ if self.fn not in self.rewritten_fn:
1217
+ self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast()
1218
+ return self.rewritten_fn[self.fn]
1219
+
1220
+ @property
1221
+ def __name__(self):
1222
+ return self.fn.__name__
1223
+
1224
+ def __getitem__(self, grid):
1225
+ fn = self.rewrite()
1226
+ return GridExecutor(fn, self.arg_names, grid)
1227
+
1228
+ def __call__(self, *args, **kwargs):
1229
+ # This is a device function call
1230
+ _patch_lang(self.fn)
1231
+ fn = self.rewrite()
1232
+ try:
1233
+ return fn(*args, **kwargs)
1234
+ except Exception as e:
1235
+ raise InterpreterError(repr(e)) from e
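
If this wheel behaves like upstream Triton, setting TRITON_INTERPRET=1 before importing triton makes @triton.jit return the InterpretedFunction defined above, so a kernel launch runs eagerly on NumPy copies of the tensors via GridExecutor. A minimal usage sketch follows; the kernel, sizes, and names are made up for illustration:

    import os
    os.environ["TRITON_INTERPRET"] = "1"   # must be set before importing triton

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    x, y = torch.randn(1024), torch.randn(1024)
    out = torch.empty_like(x)
    # The launch is executed per grid index on host copies, then results are
    # copied back to the original tensors to preserve side effects.
    add_kernel[(triton.cdiv(1024, 256),)](x, y, out, 1024, BLOCK=256)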