triton-windows 3.3.1.post19__cp39-cp39-win_amd64.whl → 3.4.0.post20__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (166)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +14 -6
  59. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  60. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  61. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  64. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  65. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  66. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  67. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  68. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  69. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  70. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  71. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  72. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  73. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  80. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  81. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  82. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  83. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  84. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  85. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  86. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  87. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  88. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  89. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  90. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  91. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  92. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  93. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  94. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  95. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  96. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  97. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  98. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  99. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  100. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  101. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  102. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  103. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  104. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  105. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  106. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  107. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  108. triton/backends/amd/include/hip/device_functions.h +0 -38
  109. triton/backends/amd/include/hip/driver_types.h +0 -468
  110. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  111. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  112. triton/backends/amd/include/hip/hip_common.h +0 -100
  113. triton/backends/amd/include/hip/hip_complex.h +0 -38
  114. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  115. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  116. triton/backends/amd/include/hip/hip_ext.h +0 -161
  117. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  118. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  119. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  120. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  121. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  122. triton/backends/amd/include/hip/hip_profile.h +0 -27
  123. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  124. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  125. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  126. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  127. triton/backends/amd/include/hip/hip_version.h +0 -17
  128. triton/backends/amd/include/hip/hiprtc.h +0 -421
  129. triton/backends/amd/include/hip/library_types.h +0 -78
  130. triton/backends/amd/include/hip/math_functions.h +0 -42
  131. triton/backends/amd/include/hip/surface_types.h +0 -63
  132. triton/backends/amd/include/hip/texture_types.h +0 -194
  133. triton/backends/amd/include/hsa/Brig.h +0 -1131
  134. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  135. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  136. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  137. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  138. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  139. triton/backends/amd/include/hsa/hsa.h +0 -5738
  140. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  141. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  142. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  143. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  144. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  145. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  146. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  147. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  148. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  149. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  150. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  151. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  152. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  153. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  154. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  155. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  156. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  157. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  158. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  159. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  160. triton/backends/amd/include/roctracer/roctx.h +0 -229
  161. triton/language/_utils.py +0 -21
  162. triton/language/extra/cuda/_experimental_tma.py +0 -106
  163. triton/tools/experimental_descriptor.py +0 -32
  164. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  165. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  166. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
triton/language/math.py CHANGED
@@ -1,5 +1,4 @@
  from . import core
- from . import semantic
  from functools import wraps
  from typing import List

@@ -85,107 +84,107 @@ def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]:
  @core.builtin
  @_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"])
  @_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
- def umulhi(x, y, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     y = semantic.to_tensor(y, _builder)
-     x, y = core.binary_op_type_legalization(x, y, _builder)
-     return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type)
+ def umulhi(x, y, _semantic=None):
+     x = _semantic.to_tensor(x)
+     y = _semantic.to_tensor(y)
+     x, y = core.binary_op_type_legalization(x, y, _semantic)
+     return core.tensor(_semantic.builder.create_umulhi(x.handle, y.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("exponential")
  @core._tensor_member_fn
- def exp(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_exp(x.handle), x.type)
+ def exp(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_exp(x.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("exponential (base 2)")
  @core._tensor_member_fn
- def exp2(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_exp2(x.handle), x.type)
+ def exp2(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_exp2(x.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("natural logarithm")
  @core._tensor_member_fn
- def log(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_log(x.handle), x.type)
+ def log(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_log(x.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("logarithm (base 2)")
  @core._tensor_member_fn
- def log2(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_log2(x.handle), x.type)
+ def log2(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_log2(x.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("cosine")
  @core._tensor_member_fn
- def cos(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_cos(x.handle), x.type)
+ def cos(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_cos(x.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("sine")
  @core._tensor_member_fn
- def sin(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_sin(x.handle), x.type)
+ def sin(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_sin(x.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("fast square root")
  @core._tensor_member_fn
- def sqrt(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_sqrt(x.handle), x.type)
+ def sqrt(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_sqrt(x.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32"])
  @_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)")
  @core._tensor_member_fn
- def sqrt_rn(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_precise_sqrt(x.handle), x.type)
+ def sqrt_rn(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_precise_sqrt(x.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("inverse square root")
  @core._tensor_member_fn
- def rsqrt(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_rsqrt(x.handle), x.type)
+ def rsqrt(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_rsqrt(x.handle), x.type)


  @core._tensor_member_fn
  @core.builtin
  @_add_math_1arg_docstr("absolute value")
- def abs(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
+ def abs(x, _semantic=None):
+     x = _semantic.to_tensor(x)
      dtype = x.dtype
      if dtype.is_fp8e4b15():
-         mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder)
-         return core.tensor(_builder.create_and(x.handle, mask.handle), x.type)
+         mask = core.full(x.shape, 0x7F, core.int8, _semantic=_semantic)
+         return core.tensor(_semantic.builder.create_and(x.handle, mask.handle), x.type)
      elif dtype.is_floating():
-         return core.tensor(_builder.create_fabs(x.handle), x.type)
+         return core.tensor(_semantic.builder.create_fabs(x.handle), x.type)
      elif dtype.is_int_signed():
-         return core.tensor(_builder.create_iabs(x.handle), x.type)
+         return core.tensor(_semantic.builder.create_iabs(x.handle), x.type)
      elif dtype.is_int_unsigned():
          return x  # no-op
      else:
@@ -194,57 +193,57 @@ def abs(x, _builder=None):

  @core.builtin
  @_add_math_2arg_docstr("fast division")
- def fdiv(x, y, ieee_rounding=False, _builder=None):
-     ieee_rounding = core._constexpr_to_value(ieee_rounding)
-     x = semantic.to_tensor(x, _builder)
-     y = semantic.to_tensor(y, _builder)
-     return semantic.fdiv(x, y, ieee_rounding, _builder)
+ def fdiv(x, y, ieee_rounding=False, _semantic=None):
+     ieee_rounding = core._unwrap_if_constexpr(ieee_rounding)
+     x = _semantic.to_tensor(x)
+     y = _semantic.to_tensor(y)
+     return _semantic.fdiv(x, y, ieee_rounding)


  @core.builtin
  @_check_dtype(dtypes=["fp32"])
  @_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)")
- def div_rn(x, y, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     y = semantic.to_tensor(y, _builder)
-     x, y = core.binary_op_type_legalization(x, y, _builder)
-     return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type)
+ def div_rn(x, y, _semantic=None):
+     x = _semantic.to_tensor(x)
+     y = _semantic.to_tensor(y)
+     x, y = core.binary_op_type_legalization(x, y, _semantic)
+     return core.tensor(_semantic.builder.create_precise_divf(x.handle, y.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("error function")
  @core._tensor_member_fn
- def erf(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_erf(x.handle), x.type)
+ def erf(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_erf(x.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("floor")
  @core._tensor_member_fn
- def floor(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_floor(x.handle), x.type)
+ def floor(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_floor(x.handle), x.type)


  @core.builtin
  @_check_dtype(dtypes=["fp32", "fp64"])
  @_add_math_1arg_docstr("ceil")
  @core._tensor_member_fn
- def ceil(x, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     return core.tensor(_builder.create_ceil(x.handle), x.type)
+ def ceil(x, _semantic=None):
+     x = _semantic.to_tensor(x)
+     return core.tensor(_semantic.builder.create_ceil(x.handle), x.type)


  @core.builtin
  @_add_math_3arg_docstr("fused multiply-add")
- def fma(x, y, z, _builder=None):
-     x = semantic.to_tensor(x, _builder)
-     y = semantic.to_tensor(y, _builder)
-     z = semantic.to_tensor(z, _builder)
-     x, y = core.binary_op_type_legalization(x, y, _builder)
-     z, x = core.binary_op_type_legalization(z, x, _builder)
-     z, y = core.binary_op_type_legalization(z, y, _builder)
-     return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type)
+ def fma(x, y, z, _semantic=None):
+     x = _semantic.to_tensor(x)
+     y = _semantic.to_tensor(y)
+     z = _semantic.to_tensor(z)
+     x, y = core.binary_op_type_legalization(x, y, _semantic)
+     z, x = core.binary_op_type_legalization(z, x, _semantic)
+     z, y = core.binary_op_type_legalization(z, y, _semantic)
+     return core.tensor(_semantic.builder.create_fma(x.handle, y.handle, z.handle), x.type)
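Note: the math.py change is a rename of compiler-injected plumbing. Each builtin now receives a `_semantic` helper (which exposes the old IR builder as `_semantic.builder`) instead of a raw `_builder`. Kernel code never passes this argument itself; it is supplied during tracing, so user-facing call sites stay the same. A minimal sketch of a kernel exercising these builtins (the kernel name `exp_kernel` is illustrative, not part of the package):

import torch
import triton
import triton.language as tl


@triton.jit
def exp_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # tl.exp resolves to the builtin patched above; the _semantic
    # (formerly _builder) argument is injected by the compiler.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, tl.exp(x), mask=mask)


x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024), )
exp_kernel[grid](x, y, x.numel(), BLOCK=1024)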
triton/language/random.py CHANGED
@@ -51,6 +51,7 @@ def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
      c1 = tl.to_tensor(c1)
      c2 = tl.to_tensor(c2)
      c3 = tl.to_tensor(c3)
+
      if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
          int_dtype = tl.uint32
          seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
@@ -60,6 +61,7 @@ def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
          int_dtype = tl.uint64
          seed_hi = tl.full((1, ), 0, dtype=int_dtype)
          seed_lo = seed
+
      c0 = c0.to(int_dtype, bitcast=True)
      c1 = c1.to(int_dtype, bitcast=True)
      c2 = c2.to(int_dtype, bitcast=True)
@@ -96,8 +98,16 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
      :param offsets: The offsets to generate random numbers for.
      """
      # _0 = tl.zeros(offset.shape, offset.dtype)
-     _0 = offset * 0
-     return philox(seed, offset, _0, _0, _0, n_rounds)
+
+     offset_lo = offset.to(tl.uint32)
+     _0 = offset_lo * 0
+
+     if tl.constexpr(offset.dtype.primitive_bitwidth) > 32:
+         offset_hi = (offset >> 32).to(tl.uint32)
+     else:
+         offset_hi = _0
+
+     return philox(seed, offset_lo, offset_hi, _0, _0, n_rounds)


  # -------------------
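Note: with this change, `randint4x` splits a wider-than-32-bit offset into low and high 32-bit words and passes them as two separate philox counters, instead of feeding the whole offset as a single counter. A minimal usage sketch, assuming a CUDA device (the kernel name `rand_kernel` and the 2**33 base offset are illustrative only):

import torch
import triton
import triton.language as tl


@triton.jit
def rand_kernel(out_ptr, seed, base_offset, BLOCK: tl.constexpr):
    # base_offset exceeds 32 bits here, so randint4x takes the
    # lo/hi split path shown in the diff above.
    offs = base_offset + tl.arange(0, BLOCK)
    r0, r1, r2, r3 = tl.randint4x(seed, offs)
    # Bitcast to int32 so the values match the int32 output buffer.
    tl.store(out_ptr + tl.arange(0, BLOCK), r0.to(tl.int32, bitcast=True))


out = torch.empty(1024, dtype=torch.int32, device="cuda")
rand_kernel[(1, )](out, 42, 2**33, BLOCK=1024)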