triton-windows 3.2.0.post11__cp39-cp39-win_amd64.whl → 3.3.0a0.post11__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (68):
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
  67. triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0
@@ -37,7 +37,7 @@ def _find_already_mmapped_dylib_on_linux(lib_name):
37
37
  # Load libc and get the dl_iterate_phdr symbol.
38
38
  try:
39
39
  dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr
40
- except:
40
+ except Exception:
41
41
  return None
42
42
  # argtypes must use c_char_p to accept create_string_buffer.
43
43
  dl_iterate_phdr.argtypes = [callback_t, c_char_p]
@@ -185,35 +185,32 @@ def ty_to_cpp(ty):
185
185
  }[ty]
186
186
 
187
187
 
188
- def make_launcher(constants, signature, ids, warp_size):
189
- start_desc = len(signature)
190
- #signature = generate_cu_signature(constants, signature, ids)
191
- arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
188
+ def make_launcher(constants, signature, warp_size):
189
+
190
+ def _serialize_signature(sig):
191
+ if isinstance(sig, tuple):
192
+ return ','.join(map(_serialize_signature, sig))
193
+ return sig
192
194
 
193
195
  def _extracted_type(ty):
196
+ if isinstance(ty, tuple):
197
+ val = ','.join(map(_extracted_type, ty))
198
+ return f"[{val}]"
194
199
  if ty[0] == '*':
195
200
  return "PyObject*"
196
- return {
197
- 'i1': 'int32_t',
198
- 'i8': 'int8_t',
199
- 'i16': 'int16_t',
200
- 'i32': 'int32_t',
201
- 'i64': 'int64_t',
202
- 'u1': 'uint32_t',
203
- 'u8': 'uint8_t',
204
- 'u16': 'uint16_t',
205
- 'u32': 'uint32_t',
206
- 'u64': 'uint64_t',
207
- 'fp16': 'float',
208
- 'bf16': 'float',
209
- 'fp32': 'float',
210
- 'f32': 'float',
211
- 'fp64': 'double',
212
- }[ty]
201
+ if ty in ("constexpr"):
202
+ return "PyObject*"
203
+ return ty_to_cpp(ty)
213
204
 
214
205
  def format_of(ty):
206
+ if isinstance(ty, tuple):
207
+ val = ''.join(map(format_of, ty))
208
+ return f"({val})"
209
+ if ty[0] == '*':
210
+ return "O"
211
+ if ty in ("constexpr"):
212
+ return "O"
215
213
  return {
216
- "PyObject*": "O",
217
214
  "float": "f",
218
215
  "double": "d",
219
216
  "long": "l",
@@ -225,16 +222,29 @@ def make_launcher(constants, signature, ids, warp_size):
225
222
  "uint16_t": "H",
226
223
  "uint32_t": "I",
227
224
  "uint64_t": "K",
228
- }[ty]
225
+ }[ty_to_cpp(ty)]
229
226
 
230
- args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
231
- format = "iiiKKOOOO" + args_format
227
+ args_format = ''.join([format_of(ty) for ty in signature.values()])
228
+ format = "piiiKKOOOO" + args_format
229
+ signature = ','.join(map(_serialize_signature, signature.values()))
230
+ signature = list(filter(bool, signature.split(',')))
231
+ signature = {i: s for i, s in enumerate(signature)}
232
232
  args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
233
-
233
+ # Record the end of regular arguments;
234
+ # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
235
+ arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
236
+ internal_args_list = []
237
+ for i, ty in signature.items():
238
+ if ty[0] == "*":
239
+ internal_args_list.append(f"ptr_info{i}.dev_ptr")
240
+ elif ty != "constexpr":
241
+ internal_args_list.append(f"_arg{i}")
234
242
  libhip_path = _get_path_to_hip_runtime_dylib()
235
243
 
236
244
  # generate glue code
237
- params = [i for i in signature.keys() if i not in constants]
245
+ params = list(range(len(signature)))
246
+ params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
247
+ params.append("&global_scratch")
238
248
  src = f"""
239
249
  #define __HIP_PLATFORM_AMD__
240
250
  #include <hip/hip_runtime.h>
@@ -257,6 +267,12 @@ static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};
257
267
  unsigned int blockDimY, unsigned int blockDimZ, \\
258
268
  unsigned int sharedMemBytes, hipStream_t stream, \\
259
269
  void **kernelParams, void **extra) \\
270
+ FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, hipFunction_t f, \\
271
+ unsigned int gridDimX, unsigned int gridDimY, \\
272
+ unsigned int gridDimZ, unsigned int blockDimX, \\
273
+ unsigned int blockDimY, unsigned int blockDimZ, \\
274
+ unsigned int sharedMemBytes, hipStream_t stream, \\
275
+ void **kernelParams, void **extra) \\
260
276
  FOR_EACH_ERR_FN(hipPointerGetAttribute, void *data, \\
261
277
  hipPointer_attribute attribute, hipDeviceptr_t ptr)
262
278
 
@@ -328,13 +344,18 @@ static inline void gpuAssert(hipError_t code, const char *file, int line)
328
344
 
329
345
  #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
330
346
 
331
- static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
347
+ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
332
348
  // printf("_launch hip kernel\\n");
333
- void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
349
+ hipDeviceptr_t global_scratch = 0;
350
+ void *params[] = {{ {', '.join(params)} }};
351
+ if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
352
+ HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
353
+ return;
354
+ }}
334
355
  if (gridX*gridY*gridZ > 0) {{
335
- HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
336
- }}
356
+ HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
337
357
  }}
358
+ }}
338
359
 
339
360
  typedef struct _DevicePtrInfo {{
340
361
  hipDeviceptr_t dev_ptr;
@@ -387,12 +408,14 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
387
408
  int gridX, gridY, gridZ;
388
409
  uint64_t _stream;
389
410
  uint64_t _function;
411
+ int launch_cooperative_grid;
390
412
  PyObject *launch_enter_hook = NULL;
391
413
  PyObject *launch_exit_hook = NULL;
392
414
  PyObject *kernel_metadata = NULL;
393
415
  PyObject *launch_metadata = NULL;
394
416
  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
395
- if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function,
417
+ if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid,
418
+ &gridX, &gridY, &gridZ, &_stream, &_function,
396
419
  &kernel_metadata, &launch_metadata,
397
420
  &launch_enter_hook, &launch_exit_hook {args_list})) {{
398
421
  return NULL;
@@ -415,7 +438,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
415
438
 
416
439
  // raise exception asap
417
440
  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
418
- _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
441
+ _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
419
442
 
420
443
  if(launch_exit_hook != Py_None){{
421
444
  PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -464,17 +487,17 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
464
487
  class HIPLauncher(object):
465
488
 
466
489
  def __init__(self, src, metadata):
467
- ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
468
490
  constants = src.constants if hasattr(src, "constants") else dict()
469
- cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
470
- constants = {cst_key(key): value for key, value in constants.items()}
471
- signature = {cst_key(key): value for key, value in src.signature.items()}
472
- src = make_launcher(constants, signature, ids, metadata.warp_size)
491
+ arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
492
+ constants = {arg_idx(idx): value for idx, value in constants.items()}
493
+ signature = {idx: value for idx, value in src.signature.items()}
494
+ src = make_launcher(constants, signature, metadata.warp_size)
473
495
  mod = compile_module_from_src(src, "__triton_launcher")
474
496
  self.launch = mod.launch
497
+ self.launch_cooperative_grid = metadata.launch_cooperative_grid
475
498
 
476
- def __call__(self, *args, **kwargs):
477
- self.launch(*args, **kwargs)
499
+ def __call__(self, *args):
500
+ self.launch(self.launch_cooperative_grid, *args)
478
501
 
479
502
 
480
503
  class HIPDriver(GPUDriver):
@@ -490,8 +513,11 @@ class HIPDriver(GPUDriver):
490
513
 
491
514
  @staticmethod
492
515
  def is_active():
493
- import torch
494
- return torch.version.hip is not None
516
+ try:
517
+ import torch
518
+ return torch.version.hip is not None
519
+ except ImportError:
520
+ return False
495
521
 
496
522
  def get_current_target(self):
497
523
  device = self.get_current_device()
@@ -500,6 +526,11 @@ class HIPDriver(GPUDriver):
500
526
  warp_size = device_properties['warpSize']
501
527
  return GPUTarget("hip", arch.split(':')[0], warp_size)
502
528
 
529
+ def get_active_torch_device(self):
530
+ import torch
531
+ # when using hip devices, the device string in pytorch is "cuda"
532
+ return torch.device("cuda", self.get_current_device())
533
+
503
534
  def get_benchmarker(self):
504
535
  from triton.testing import do_bench
505
536
  return do_bench
@@ -510,3 +541,6 @@ class HIPDriver(GPUDriver):
510
541
  # It's the same as the Nvidia backend.
511
542
  cache_size = 256 * 1024 * 1024
512
543
  return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
544
+
545
+ def clear_cache(self, cache):
546
+ cache.zero_()
@@ -266,14 +266,14 @@ __device__ static inline int __mul24(int x, int y) {
266
266
  }
267
267
 
268
268
  __device__ static inline long long __mul64hi(long long int x, long long int y) {
269
- ulong x0 = (ulong)x & 0xffffffffUL;
270
- long x1 = x >> 32;
271
- ulong y0 = (ulong)y & 0xffffffffUL;
272
- long y1 = y >> 32;
273
- ulong z0 = x0*y0;
274
- long t = x1*y0 + (z0 >> 32);
275
- long z1 = t & 0xffffffffL;
276
- long z2 = t >> 32;
269
+ unsigned long long x0 = (unsigned long long)x & 0xffffffffUL;
270
+ long long x1 = x >> 32;
271
+ unsigned long long y0 = (unsigned long long)y & 0xffffffffUL;
272
+ long long y1 = y >> 32;
273
+ unsigned long long z0 = x0*y0;
274
+ long long t = x1*y0 + (z0 >> 32);
275
+ long long z1 = t & 0xffffffffL;
276
+ long long z2 = t >> 32;
277
277
  z1 = x0*y1 + z1;
278
278
  return x1*y1 + z2 + (z1 >> 32);
279
279
  }
@@ -300,14 +300,14 @@ __device__ static inline int __umul24(unsigned int x, unsigned int y) {
300
300
 
301
301
  __device__
302
302
  static inline unsigned long long __umul64hi(unsigned long long int x, unsigned long long int y) {
303
- ulong x0 = x & 0xffffffffUL;
304
- ulong x1 = x >> 32;
305
- ulong y0 = y & 0xffffffffUL;
306
- ulong y1 = y >> 32;
307
- ulong z0 = x0*y0;
308
- ulong t = x1*y0 + (z0 >> 32);
309
- ulong z1 = t & 0xffffffffUL;
310
- ulong z2 = t >> 32;
303
+ unsigned long long x0 = x & 0xffffffffUL;
304
+ unsigned long long x1 = x >> 32;
305
+ unsigned long long y0 = y & 0xffffffffUL;
306
+ unsigned long long y1 = y >> 32;
307
+ unsigned long long z0 = x0*y0;
308
+ unsigned long long t = x1*y0 + (z0 >> 32);
309
+ unsigned long long z1 = t & 0xffffffffUL;
310
+ unsigned long long z2 = t >> 32;
311
311
  z1 = x0*y1 + z1;
312
312
  return x1*y1 + z2 + (z1 >> 32);
313
313
  }
@@ -322,11 +322,6 @@ __device__ static inline unsigned int __usad(unsigned int x, unsigned int y, uns
322
322
  return __ockl_sadd_u32(x, y, z);
323
323
  }
324
324
 
325
- __device__ static inline unsigned int __lane_id() {
326
- return __builtin_amdgcn_mbcnt_hi(
327
- -1, __builtin_amdgcn_mbcnt_lo(-1, 0));
328
- }
329
-
330
325
  __device__
331
326
  static inline unsigned int __mbcnt_lo(unsigned int x, unsigned int y) {return __builtin_amdgcn_mbcnt_lo(x,y);};
332
327
 
@@ -339,6 +334,7 @@ HIP specific device functions
339
334
 
340
335
  #if !defined(__HIPCC_RTC__)
341
336
  #include "amd_warp_functions.h"
337
+ #include "amd_warp_sync_functions.h"
342
338
  #endif
343
339
 
344
340
  #define MASK1 0x00ff00ff
@@ -687,34 +683,6 @@ void __named_sync() { __builtin_amdgcn_s_barrier(); }
687
683
 
688
684
  #endif // __HIP_DEVICE_COMPILE__
689
685
 
690
- // warp vote function __all __any __ballot
691
- __device__
692
- inline
693
- int __all(int predicate) {
694
- return __ockl_wfall_i32(predicate);
695
- }
696
-
697
- __device__
698
- inline
699
- int __any(int predicate) {
700
- return __ockl_wfany_i32(predicate);
701
- }
702
-
703
- // XXX from llvm/include/llvm/IR/InstrTypes.h
704
- #define ICMP_NE 33
705
-
706
- __device__
707
- inline
708
- unsigned long long int __ballot(int predicate) {
709
- return __builtin_amdgcn_uicmp(predicate, 0, ICMP_NE);
710
- }
711
-
712
- __device__
713
- inline
714
- unsigned long long int __ballot64(int predicate) {
715
- return __builtin_amdgcn_uicmp(predicate, 0, ICMP_NE);
716
- }
717
-
718
686
  // hip.amdgcn.bc - lanemask
719
687
  __device__
720
688
  inline
@@ -877,6 +845,10 @@ int __syncthreads_or(int predicate)
877
845
  #if (defined(__GFX10__) || defined(__GFX11__))
878
846
  #define HW_ID_WGP_ID_SIZE 4
879
847
  #define HW_ID_WGP_ID_OFFSET 10
848
+ #if (defined(__AMDGCN_CUMODE__))
849
+ #define HW_ID_CU_ID_SIZE 1
850
+ #define HW_ID_CU_ID_OFFSET 8
851
+ #endif
880
852
  #else
881
853
  #define HW_ID_CU_ID_SIZE 4
882
854
  #define HW_ID_CU_ID_OFFSET 8
@@ -933,6 +905,10 @@ unsigned __smid(void)
933
905
  GETREG_IMMED(HW_ID_WGP_ID_SIZE - 1, HW_ID_WGP_ID_OFFSET, HW_ID));
934
906
  unsigned sa_id = __builtin_amdgcn_s_getreg(
935
907
  GETREG_IMMED(HW_ID_SA_ID_SIZE - 1, HW_ID_SA_ID_OFFSET, HW_ID));
908
+ #if (defined(__AMDGCN_CUMODE__))
909
+ unsigned cu_id = __builtin_amdgcn_s_getreg(
910
+ GETREG_IMMED(HW_ID_CU_ID_SIZE - 1, HW_ID_CU_ID_OFFSET, HW_ID));
911
+ #endif
936
912
  #else
937
913
  #if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
938
914
  unsigned xcc_id = __builtin_amdgcn_s_getreg(
@@ -945,6 +921,9 @@ unsigned __smid(void)
945
921
  unsigned temp = se_id;
946
922
  temp = (temp << HW_ID_SA_ID_SIZE) | sa_id;
947
923
  temp = (temp << HW_ID_WGP_ID_SIZE) | wgp_id;
924
+ #if (defined(__AMDGCN_CUMODE__))
925
+ temp = (temp << HW_ID_CU_ID_SIZE) | cu_id;
926
+ #endif
948
927
  return temp;
949
928
  //TODO : CU Mode impl
950
929
  #elif (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
@@ -612,11 +612,17 @@ float atomicMin(float* addr, float val) {
612
612
  #if defined(__AMDGCN_UNSAFE_FP_ATOMICS__)
613
613
  return unsafeAtomicMin(addr, val);
614
614
  #else
615
+ typedef union u_hold {
616
+ float a;
617
+ unsigned int b;
618
+ } u_hold_t;
619
+ u_hold_t u{val};
620
+ bool neg_zero = 0x80000000U == u.b;
615
621
  #if __has_builtin(__hip_atomic_load) && \
616
622
  __has_builtin(__hip_atomic_compare_exchange_strong)
617
623
  float value = __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
618
624
  bool done = false;
619
- while (!done && value > val) {
625
+ while (!done && (value > val || (neg_zero && value == 0.0f))) {
620
626
  done = __hip_atomic_compare_exchange_strong(addr, &value, val,
621
627
  __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
622
628
  }
@@ -625,7 +631,7 @@ float atomicMin(float* addr, float val) {
625
631
  unsigned int *uaddr = (unsigned int *)addr;
626
632
  unsigned int value = __atomic_load_n(uaddr, __ATOMIC_RELAXED);
627
633
  bool done = false;
628
- while (!done && __uint_as_float(value) > val) {
634
+ while (!done && (__uint_as_float(value) > val || (neg_zero && __uint_as_float(value) == 0.0f))) {
629
635
  done = __atomic_compare_exchange_n(uaddr, &value, __float_as_uint(val), false,
630
636
  __ATOMIC_RELAXED, __ATOMIC_RELAXED);
631
637
  }
@@ -658,11 +664,17 @@ double atomicMin(double* addr, double val) {
658
664
  #if defined(__AMDGCN_UNSAFE_FP_ATOMICS__)
659
665
  return unsafeAtomicMin(addr, val);
660
666
  #else
667
+ typedef union u_hold {
668
+ double a;
669
+ unsigned long long b;
670
+ } u_hold_t;
671
+ u_hold_t u{val};
672
+ bool neg_zero = 0x8000000000000000ULL == u.b;
661
673
  #if __has_builtin(__hip_atomic_load) && \
662
674
  __has_builtin(__hip_atomic_compare_exchange_strong)
663
675
  double value = __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
664
676
  bool done = false;
665
- while (!done && value > val) {
677
+ while (!done && (value > val || (neg_zero && value == 0.0))) {
666
678
  done = __hip_atomic_compare_exchange_strong(addr, &value, val,
667
679
  __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
668
680
  }
@@ -671,7 +683,8 @@ double atomicMin(double* addr, double val) {
671
683
  unsigned long long *uaddr = (unsigned long long *)addr;
672
684
  unsigned long long value = __atomic_load_n(uaddr, __ATOMIC_RELAXED);
673
685
  bool done = false;
674
- while (!done && __longlong_as_double(value) > val) {
686
+ while (!done &&
687
+ (__longlong_as_double(value) > val || (neg_zero && __longlong_as_double(value) == 0.0))) {
675
688
  done = __atomic_compare_exchange_n(uaddr, &value, __double_as_longlong(val), false,
676
689
  __ATOMIC_RELAXED, __ATOMIC_RELAXED);
677
690
  }
@@ -856,11 +869,17 @@ float atomicMax(float* addr, float val) {
856
869
  #if defined(__AMDGCN_UNSAFE_FP_ATOMICS__)
857
870
  return unsafeAtomicMax(addr, val);
858
871
  #else
872
+ typedef union u_hold {
873
+ float a;
874
+ unsigned int b;
875
+ } u_hold_t;
876
+ u_hold_t u{val};
877
+ bool neg_zero = 0x80000000U == u.b;
859
878
  #if __has_builtin(__hip_atomic_load) && \
860
879
  __has_builtin(__hip_atomic_compare_exchange_strong)
861
880
  float value = __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
862
881
  bool done = false;
863
- while (!done && value < val) {
882
+ while (!done && (value < val || (neg_zero && value == 0.0f))) {
864
883
  done = __hip_atomic_compare_exchange_strong(addr, &value, val,
865
884
  __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
866
885
  }
@@ -869,7 +888,7 @@ float atomicMax(float* addr, float val) {
869
888
  unsigned int *uaddr = (unsigned int *)addr;
870
889
  unsigned int value = __atomic_load_n(uaddr, __ATOMIC_RELAXED);
871
890
  bool done = false;
872
- while (!done && __uint_as_float(value) < val) {
891
+ while (!done && (__uint_as_float(value) < val || (neg_zero && __uint_as_float(value) == 0.0f))) {
873
892
  done = __atomic_compare_exchange_n(uaddr, &value, __float_as_uint(val), false,
874
893
  __ATOMIC_RELAXED, __ATOMIC_RELAXED);
875
894
  }
@@ -902,11 +921,17 @@ double atomicMax(double* addr, double val) {
902
921
  #if defined(__AMDGCN_UNSAFE_FP_ATOMICS__)
903
922
  return unsafeAtomicMax(addr, val);
904
923
  #else
924
+ typedef union u_hold {
925
+ double a;
926
+ unsigned long long b;
927
+ } u_hold_t;
928
+ u_hold_t u{val};
929
+ bool neg_zero = 0x8000000000000000ULL == u.b;
905
930
  #if __has_builtin(__hip_atomic_load) && \
906
931
  __has_builtin(__hip_atomic_compare_exchange_strong)
907
932
  double value = __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
908
933
  bool done = false;
909
- while (!done && value < val) {
934
+ while (!done && (value < val || (neg_zero && value == 0.0))) {
910
935
  done = __hip_atomic_compare_exchange_strong(addr, &value, val,
911
936
  __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
912
937
  }
@@ -915,7 +940,8 @@ double atomicMax(double* addr, double val) {
915
940
  unsigned long long *uaddr = (unsigned long long *)addr;
916
941
  unsigned long long value = __atomic_load_n(uaddr, __ATOMIC_RELAXED);
917
942
  bool done = false;
918
- while (!done && __longlong_as_double(value) < val) {
943
+ while (!done &&
944
+ (__longlong_as_double(value) < val || (neg_zero && __longlong_as_double(value) == 0.0))) {
919
945
  done = __atomic_compare_exchange_n(uaddr, &value, __double_as_longlong(val), false,
920
946
  __ATOMIC_RELAXED, __ATOMIC_RELAXED);
921
947
  }
@@ -977,7 +1003,7 @@ unsigned int atomicDec(unsigned int* address, unsigned int val)
977
1003
  #else
978
1004
  return __builtin_amdgcn_atomic_dec32(address, val, __ATOMIC_RELAXED, "agent");
979
1005
  #endif // __gfx941__
980
-
1006
+
981
1007
  }
982
1008
 
983
1009
  __device__