triton-windows 3.3.0.post19__cp39-cp39-win_amd64.whl → 3.4.0.post20__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. See the package's registry page for more details.

Files changed (173)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/runtime/tcc/lib/python310.def +1610 -0
  56. triton/runtime/tcc/lib/python311.def +1633 -0
  57. triton/runtime/tcc/lib/python312.def +1703 -0
  58. triton/runtime/tcc/lib/python313.def +1651 -0
  59. triton/runtime/tcc/lib/python313t.def +1656 -0
  60. triton/runtime/tcc/lib/python39.def +1644 -0
  61. triton/runtime/tcc/lib/python3t.def +905 -0
  62. triton/testing.py +16 -12
  63. triton/tools/disasm.py +3 -4
  64. triton/tools/tensor_descriptor.py +36 -0
  65. triton/windows_utils.py +14 -6
  66. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  67. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  68. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
  69. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  70. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  71. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  72. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  73. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  80. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  81. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  82. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  83. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  84. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  85. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  86. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  87. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  88. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  89. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  90. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  91. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  92. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  93. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  94. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  95. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  96. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  97. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  98. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  99. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  100. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  101. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  102. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  103. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  104. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  105. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  106. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  107. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  108. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  109. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  110. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  111. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  112. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  113. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  114. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  115. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  116. triton/backends/amd/include/hip/device_functions.h +0 -38
  117. triton/backends/amd/include/hip/driver_types.h +0 -468
  118. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  119. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  120. triton/backends/amd/include/hip/hip_common.h +0 -100
  121. triton/backends/amd/include/hip/hip_complex.h +0 -38
  122. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  123. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  124. triton/backends/amd/include/hip/hip_ext.h +0 -161
  125. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  126. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  127. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  128. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  129. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  130. triton/backends/amd/include/hip/hip_profile.h +0 -27
  131. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  132. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  133. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  134. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  135. triton/backends/amd/include/hip/hip_version.h +0 -17
  136. triton/backends/amd/include/hip/hiprtc.h +0 -421
  137. triton/backends/amd/include/hip/library_types.h +0 -78
  138. triton/backends/amd/include/hip/math_functions.h +0 -42
  139. triton/backends/amd/include/hip/surface_types.h +0 -63
  140. triton/backends/amd/include/hip/texture_types.h +0 -194
  141. triton/backends/amd/include/hsa/Brig.h +0 -1131
  142. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  143. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  144. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  145. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  146. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  147. triton/backends/amd/include/hsa/hsa.h +0 -5738
  148. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  149. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  150. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  151. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  152. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  153. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  154. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  155. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  156. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  157. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  158. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  159. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  160. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  161. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  162. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  163. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  164. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  165. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  166. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  167. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  168. triton/backends/amd/include/roctracer/roctx.h +0 -229
  169. triton/language/_utils.py +0 -21
  170. triton/language/extra/cuda/_experimental_tma.py +0 -106
  171. triton/tools/experimental_descriptor.py +0 -32
  172. triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
  173. triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
triton/language/math.py CHANGED
@@ -1,5 +1,4 @@
1
1
  from . import core
2
- from . import semantic
3
2
  from functools import wraps
4
3
  from typing import List
5
4
 
@@ -85,107 +84,107 @@ def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]:
85
84
  @core.builtin
86
85
  @_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"])
87
86
  @_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
88
- def umulhi(x, y, _builder=None):
89
- x = semantic.to_tensor(x, _builder)
90
- y = semantic.to_tensor(y, _builder)
91
- x, y = core.binary_op_type_legalization(x, y, _builder)
92
- return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type)
87
+ def umulhi(x, y, _semantic=None):
88
+ x = _semantic.to_tensor(x)
89
+ y = _semantic.to_tensor(y)
90
+ x, y = core.binary_op_type_legalization(x, y, _semantic)
91
+ return core.tensor(_semantic.builder.create_umulhi(x.handle, y.handle), x.type)
93
92
 
94
93
 
95
94
  @core.builtin
96
95
  @_check_dtype(dtypes=["fp32", "fp64"])
97
96
  @_add_math_1arg_docstr("exponential")
98
97
  @core._tensor_member_fn
99
- def exp(x, _builder=None):
100
- x = semantic.to_tensor(x, _builder)
101
- return core.tensor(_builder.create_exp(x.handle), x.type)
98
+ def exp(x, _semantic=None):
99
+ x = _semantic.to_tensor(x)
100
+ return core.tensor(_semantic.builder.create_exp(x.handle), x.type)
102
101
 
103
102
 
104
103
  @core.builtin
105
104
  @_check_dtype(dtypes=["fp32", "fp64"])
106
105
  @_add_math_1arg_docstr("exponential (base 2)")
107
106
  @core._tensor_member_fn
108
- def exp2(x, _builder=None):
109
- x = semantic.to_tensor(x, _builder)
110
- return core.tensor(_builder.create_exp2(x.handle), x.type)
107
+ def exp2(x, _semantic=None):
108
+ x = _semantic.to_tensor(x)
109
+ return core.tensor(_semantic.builder.create_exp2(x.handle), x.type)
111
110
 
112
111
 
113
112
  @core.builtin
114
113
  @_check_dtype(dtypes=["fp32", "fp64"])
115
114
  @_add_math_1arg_docstr("natural logarithm")
116
115
  @core._tensor_member_fn
117
- def log(x, _builder=None):
118
- x = semantic.to_tensor(x, _builder)
119
- return core.tensor(_builder.create_log(x.handle), x.type)
116
+ def log(x, _semantic=None):
117
+ x = _semantic.to_tensor(x)
118
+ return core.tensor(_semantic.builder.create_log(x.handle), x.type)
120
119
 
121
120
 
122
121
  @core.builtin
123
122
  @_check_dtype(dtypes=["fp32", "fp64"])
124
123
  @_add_math_1arg_docstr("logarithm (base 2)")
125
124
  @core._tensor_member_fn
126
- def log2(x, _builder=None):
127
- x = semantic.to_tensor(x, _builder)
128
- return core.tensor(_builder.create_log2(x.handle), x.type)
125
+ def log2(x, _semantic=None):
126
+ x = _semantic.to_tensor(x)
127
+ return core.tensor(_semantic.builder.create_log2(x.handle), x.type)
129
128
 
130
129
 
131
130
  @core.builtin
132
131
  @_check_dtype(dtypes=["fp32", "fp64"])
133
132
  @_add_math_1arg_docstr("cosine")
134
133
  @core._tensor_member_fn
135
- def cos(x, _builder=None):
136
- x = semantic.to_tensor(x, _builder)
137
- return core.tensor(_builder.create_cos(x.handle), x.type)
134
+ def cos(x, _semantic=None):
135
+ x = _semantic.to_tensor(x)
136
+ return core.tensor(_semantic.builder.create_cos(x.handle), x.type)
138
137
 
139
138
 
140
139
  @core.builtin
141
140
  @_check_dtype(dtypes=["fp32", "fp64"])
142
141
  @_add_math_1arg_docstr("sine")
143
142
  @core._tensor_member_fn
144
- def sin(x, _builder=None):
145
- x = semantic.to_tensor(x, _builder)
146
- return core.tensor(_builder.create_sin(x.handle), x.type)
143
+ def sin(x, _semantic=None):
144
+ x = _semantic.to_tensor(x)
145
+ return core.tensor(_semantic.builder.create_sin(x.handle), x.type)
147
146
 
148
147
 
149
148
  @core.builtin
150
149
  @_check_dtype(dtypes=["fp32", "fp64"])
151
150
  @_add_math_1arg_docstr("fast square root")
152
151
  @core._tensor_member_fn
153
- def sqrt(x, _builder=None):
154
- x = semantic.to_tensor(x, _builder)
155
- return core.tensor(_builder.create_sqrt(x.handle), x.type)
152
+ def sqrt(x, _semantic=None):
153
+ x = _semantic.to_tensor(x)
154
+ return core.tensor(_semantic.builder.create_sqrt(x.handle), x.type)
156
155
 
157
156
 
158
157
  @core.builtin
159
158
  @_check_dtype(dtypes=["fp32"])
160
159
  @_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)")
161
160
  @core._tensor_member_fn
162
- def sqrt_rn(x, _builder=None):
163
- x = semantic.to_tensor(x, _builder)
164
- return core.tensor(_builder.create_precise_sqrt(x.handle), x.type)
161
+ def sqrt_rn(x, _semantic=None):
162
+ x = _semantic.to_tensor(x)
163
+ return core.tensor(_semantic.builder.create_precise_sqrt(x.handle), x.type)
165
164
 
166
165
 
167
166
  @core.builtin
168
167
  @_check_dtype(dtypes=["fp32", "fp64"])
169
168
  @_add_math_1arg_docstr("inverse square root")
170
169
  @core._tensor_member_fn
171
- def rsqrt(x, _builder=None):
172
- x = semantic.to_tensor(x, _builder)
173
- return core.tensor(_builder.create_rsqrt(x.handle), x.type)
170
+ def rsqrt(x, _semantic=None):
171
+ x = _semantic.to_tensor(x)
172
+ return core.tensor(_semantic.builder.create_rsqrt(x.handle), x.type)
174
173
 
175
174
 
176
175
  @core._tensor_member_fn
177
176
  @core.builtin
178
177
  @_add_math_1arg_docstr("absolute value")
179
- def abs(x, _builder=None):
180
- x = semantic.to_tensor(x, _builder)
178
+ def abs(x, _semantic=None):
179
+ x = _semantic.to_tensor(x)
181
180
  dtype = x.dtype
182
181
  if dtype.is_fp8e4b15():
183
- mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder)
184
- return core.tensor(_builder.create_and(x.handle, mask.handle), x.type)
182
+ mask = core.full(x.shape, 0x7F, core.int8, _semantic=_semantic)
183
+ return core.tensor(_semantic.builder.create_and(x.handle, mask.handle), x.type)
185
184
  elif dtype.is_floating():
186
- return core.tensor(_builder.create_fabs(x.handle), x.type)
185
+ return core.tensor(_semantic.builder.create_fabs(x.handle), x.type)
187
186
  elif dtype.is_int_signed():
188
- return core.tensor(_builder.create_iabs(x.handle), x.type)
187
+ return core.tensor(_semantic.builder.create_iabs(x.handle), x.type)
189
188
  elif dtype.is_int_unsigned():
190
189
  return x # no-op
191
190
  else:
@@ -194,57 +193,57 @@ def abs(x, _builder=None):
194
193
 
195
194
  @core.builtin
196
195
  @_add_math_2arg_docstr("fast division")
197
- def fdiv(x, y, ieee_rounding=False, _builder=None):
198
- ieee_rounding = core._constexpr_to_value(ieee_rounding)
199
- x = semantic.to_tensor(x, _builder)
200
- y = semantic.to_tensor(y, _builder)
201
- return semantic.fdiv(x, y, ieee_rounding, _builder)
196
+ def fdiv(x, y, ieee_rounding=False, _semantic=None):
197
+ ieee_rounding = core._unwrap_if_constexpr(ieee_rounding)
198
+ x = _semantic.to_tensor(x)
199
+ y = _semantic.to_tensor(y)
200
+ return _semantic.fdiv(x, y, ieee_rounding)
202
201
 
203
202
 
204
203
  @core.builtin
205
204
  @_check_dtype(dtypes=["fp32"])
206
205
  @_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)")
207
- def div_rn(x, y, _builder=None):
208
- x = semantic.to_tensor(x, _builder)
209
- y = semantic.to_tensor(y, _builder)
210
- x, y = core.binary_op_type_legalization(x, y, _builder)
211
- return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type)
206
+ def div_rn(x, y, _semantic=None):
207
+ x = _semantic.to_tensor(x)
208
+ y = _semantic.to_tensor(y)
209
+ x, y = core.binary_op_type_legalization(x, y, _semantic)
210
+ return core.tensor(_semantic.builder.create_precise_divf(x.handle, y.handle), x.type)
212
211
 
213
212
 
214
213
  @core.builtin
215
214
  @_check_dtype(dtypes=["fp32", "fp64"])
216
215
  @_add_math_1arg_docstr("error function")
217
216
  @core._tensor_member_fn
218
- def erf(x, _builder=None):
219
- x = semantic.to_tensor(x, _builder)
220
- return core.tensor(_builder.create_erf(x.handle), x.type)
217
+ def erf(x, _semantic=None):
218
+ x = _semantic.to_tensor(x)
219
+ return core.tensor(_semantic.builder.create_erf(x.handle), x.type)
221
220
 
222
221
 
223
222
  @core.builtin
224
223
  @_check_dtype(dtypes=["fp32", "fp64"])
225
224
  @_add_math_1arg_docstr("floor")
226
225
  @core._tensor_member_fn
227
- def floor(x, _builder=None):
228
- x = semantic.to_tensor(x, _builder)
229
- return core.tensor(_builder.create_floor(x.handle), x.type)
226
+ def floor(x, _semantic=None):
227
+ x = _semantic.to_tensor(x)
228
+ return core.tensor(_semantic.builder.create_floor(x.handle), x.type)
230
229
 
231
230
 
232
231
  @core.builtin
233
232
  @_check_dtype(dtypes=["fp32", "fp64"])
234
233
  @_add_math_1arg_docstr("ceil")
235
234
  @core._tensor_member_fn
236
- def ceil(x, _builder=None):
237
- x = semantic.to_tensor(x, _builder)
238
- return core.tensor(_builder.create_ceil(x.handle), x.type)
235
+ def ceil(x, _semantic=None):
236
+ x = _semantic.to_tensor(x)
237
+ return core.tensor(_semantic.builder.create_ceil(x.handle), x.type)
239
238
 
240
239
 
241
240
  @core.builtin
242
241
  @_add_math_3arg_docstr("fused multiply-add")
243
- def fma(x, y, z, _builder=None):
244
- x = semantic.to_tensor(x, _builder)
245
- y = semantic.to_tensor(y, _builder)
246
- z = semantic.to_tensor(z, _builder)
247
- x, y = core.binary_op_type_legalization(x, y, _builder)
248
- z, x = core.binary_op_type_legalization(z, x, _builder)
249
- z, y = core.binary_op_type_legalization(z, y, _builder)
250
- return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type)
242
+ def fma(x, y, z, _semantic=None):
243
+ x = _semantic.to_tensor(x)
244
+ y = _semantic.to_tensor(y)
245
+ z = _semantic.to_tensor(z)
246
+ x, y = core.binary_op_type_legalization(x, y, _semantic)
247
+ z, x = core.binary_op_type_legalization(z, x, _semantic)
248
+ z, y = core.binary_op_type_legalization(z, y, _semantic)
249
+ return core.tensor(_semantic.builder.create_fma(x.handle, y.handle, z.handle), x.type)
triton/language/random.py CHANGED
@@ -51,6 +51,7 @@ def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
51
51
  c1 = tl.to_tensor(c1)
52
52
  c2 = tl.to_tensor(c2)
53
53
  c3 = tl.to_tensor(c3)
54
+
54
55
  if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
55
56
  int_dtype = tl.uint32
56
57
  seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
@@ -60,6 +61,7 @@ def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
60
61
  int_dtype = tl.uint64
61
62
  seed_hi = tl.full((1, ), 0, dtype=int_dtype)
62
63
  seed_lo = seed
64
+
63
65
  c0 = c0.to(int_dtype, bitcast=True)
64
66
  c1 = c1.to(int_dtype, bitcast=True)
65
67
  c2 = c2.to(int_dtype, bitcast=True)
@@ -96,8 +98,16 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
96
98
  :param offsets: The offsets to generate random numbers for.
97
99
  """
98
100
  # _0 = tl.zeros(offset.shape, offset.dtype)
99
- _0 = offset * 0
100
- return philox(seed, offset, _0, _0, _0, n_rounds)
101
+
102
+ offset_lo = offset.to(tl.uint32)
103
+ _0 = offset_lo * 0
104
+
105
+ if tl.constexpr(offset.dtype.primitive_bitwidth) > 32:
106
+ offset_hi = (offset >> 32).to(tl.uint32)
107
+ else:
108
+ offset_hi = _0
109
+
110
+ return philox(seed, offset_lo, offset_hi, _0, _0, n_rounds)
101
111
 
102
112
 
103
113
  # -------------------