warp-lang 1.9.0-py3-none-win_amd64.whl → 1.9.1-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of warp-lang has been flagged as potentially problematic.

warp/codegen.py CHANGED
@@ -1244,6 +1244,11 @@ class Adjoint:
             A line directive for the given statement, or None if no line directive is needed.
         """
 
+        if adj.filename == "unknown source file" or adj.fun_lineno == 0:
+            # Early return if function is not associated with a source file or is otherwise invalid
+            # TODO: Get line directives working with wp.map() functions
+            return None
+
         # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
         # emit line directives in generated code if it's not being compiled with line information
         build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp.config.mode
warp/config.py CHANGED
@@ -15,7 +15,7 @@
 
 from typing import Optional
 
-version: str = "1.9.0"
+version: str = "1.9.1"
 """Warp version string"""
 
 verify_fp: bool = False
warp/context.py CHANGED
@@ -2244,21 +2244,7 @@ class Module:
         return self.hashers[block_dim].get_module_hash()
 
     def _use_ptx(self, device) -> bool:
-        # determine whether to use PTX or CUBIN
-        if device.is_cubin_supported:
-            # get user preference specified either per module or globally
-            preferred_cuda_output = self.options.get("cuda_output") or warp.config.cuda_output
-            if preferred_cuda_output is not None:
-                use_ptx = preferred_cuda_output == "ptx"
-            else:
-                # determine automatically: older drivers may not be able to handle PTX generated using newer
-                # CUDA Toolkits, in which case we fall back on generating CUBIN modules
-                use_ptx = runtime.driver_version >= runtime.toolkit_version
-        else:
-            # CUBIN not an option, must use PTX (e.g. CUDA Toolkit too old)
-            use_ptx = True
-
-        return use_ptx
+        return device.get_cuda_output_format(self.options.get("cuda_output")) == "ptx"
 
     def get_module_identifier(self) -> str:
         """Get an abbreviated module name to use for directories and files in the cache.
@@ -2278,19 +2264,7 @@ class Module:
         if device is None:
             device = runtime.get_device()
 
-        if device.is_cpu:
-            return None
-
-        if self._use_ptx(device):
-            # use the default PTX arch if the device supports it
-            if warp.config.ptx_target_arch is not None:
-                output_arch = min(device.arch, warp.config.ptx_target_arch)
-            else:
-                output_arch = min(device.arch, runtime.default_ptx_arch)
-        else:
-            output_arch = device.arch
-
-        return output_arch
+        return device.get_cuda_compile_arch()
 
     def get_compile_output_name(
         self, device: Device | None, output_arch: int | None = None, use_ptx: bool | None = None
@@ -3327,6 +3301,78 @@ class Device:
         else:
             return False
 
+    def get_cuda_output_format(self, preferred_cuda_output: str | None = None) -> str | None:
+        """Determine the CUDA output format to use for this device.
+
+        This method is intended for internal use by Warp's compilation system.
+        External users should not need to call this method directly.
+
+        It determines whether to use PTX or CUBIN output based on device capabilities,
+        caller preferences, and runtime constraints.
+
+        Args:
+            preferred_cuda_output: Caller's preferred format (``"ptx"``, ``"cubin"``, or ``None``).
+                If ``None``, falls back to global config or automatic determination.
+
+        Returns:
+            The output format to use: ``"ptx"``, ``"cubin"``, or ``None`` for CPU devices.
+        """
+
+        if self.is_cpu:
+            # CPU devices don't use CUDA compilation
+            return None
+
+        if not self.is_cubin_supported:
+            return "ptx"
+
+        # Use provided preference or fall back to global config
+        if preferred_cuda_output is None:
+            preferred_cuda_output = warp.config.cuda_output
+
+        if preferred_cuda_output is not None:
+            # Caller specified a preference, use it if supported
+            if preferred_cuda_output in ("ptx", "cubin"):
+                return preferred_cuda_output
+            else:
+                # Invalid preference, fall back to automatic determination
+                pass
+
+        # Determine automatically: Older drivers may not be able to handle PTX generated using newer CUDA Toolkits,
+        # in which case we fall back on generating CUBIN modules
+        return "ptx" if self.runtime.driver_version >= self.runtime.toolkit_version else "cubin"
+
+    def get_cuda_compile_arch(self) -> int | None:
+        """Get the CUDA architecture to use when compiling code for this device.
+
+        This method is intended for internal use by Warp's compilation system.
+        External users should not need to call this method directly.
+
+        Determines the appropriate compute capability version to use when compiling
+        CUDA kernels for this device. The architecture depends on the device's
+        CUDA output format preference and available target architectures.
+
+        For PTX output format, uses the minimum of the device's architecture and
+        the configured PTX target architecture to ensure compatibility.
+        For CUBIN output format, uses the device's exact architecture.
+
+        Returns:
+            The compute capability version (e.g., 75 for ``sm_75``) to use for compilation,
+            or ``None`` for CPU devices which don't use CUDA compilation.
+        """
+        if self.is_cpu:
+            return None
+
+        if self.get_cuda_output_format() == "ptx":
+            # use the default PTX arch if the device supports it
+            if warp.config.ptx_target_arch is not None:
+                output_arch = min(self.arch, warp.config.ptx_target_arch)
+            else:
+                output_arch = min(self.arch, runtime.default_ptx_arch)
+        else:
+            output_arch = self.arch
+
+        return output_arch
+
 
 """ Meta-type for arguments that can be resolved to a concrete Device.
 """
@@ -4036,6 +4082,8 @@ class Runtime:
         self.core.wp_cuda_graph_insert_if_else.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
+            ctypes.c_int,
+            ctypes.c_bool,
             ctypes.POINTER(ctypes.c_int),
             ctypes.POINTER(ctypes.c_void_p),
             ctypes.POINTER(ctypes.c_void_p),
@@ -4045,6 +4093,8 @@ class Runtime:
         self.core.wp_cuda_graph_insert_while.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
+            ctypes.c_int,
+            ctypes.c_bool,
             ctypes.POINTER(ctypes.c_int),
             ctypes.POINTER(ctypes.c_void_p),
             ctypes.POINTER(ctypes.c_uint64),
@@ -4054,6 +4104,8 @@ class Runtime:
        self.core.wp_cuda_graph_set_condition.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
+            ctypes.c_int,
+            ctypes.c_bool,
             ctypes.POINTER(ctypes.c_int),
             ctypes.c_uint64,
         ]
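Together, these hunks wire the new Device helpers into the CUDA conditional-node bindings: the two extra ctypes arguments (an int architecture and a bool PTX flag) are exactly the values produced by Device.get_cuda_compile_arch() and Device.get_cuda_output_format(). Both helpers are documented as internal, but the decision flow they expose can be illustrated with a minimal sketch (assuming an initialized CUDA device and default configuration):

import warp as wp

wp.init()
device = wp.get_device("cuda:0")

# Internal helpers added in 1.9.1, shown here only to illustrate the decision flow.
fmt = device.get_cuda_output_format()   # "ptx" or "cubin"; "ptx" when the driver is at least as new as the toolkit
arch = device.get_cuda_compile_arch()   # e.g. 75 for sm_75; capped by warp.config.ptx_target_arch for PTX output
print(fmt, arch)

# A per-module preference can still be threaded through, as Module._use_ptx() now does:
assert device.get_cuda_output_format("cubin") in ("ptx", "cubin")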
@@ -7053,6 +7105,8 @@ def capture_if(
     if not runtime.core.wp_cuda_graph_insert_if_else(
         device.context,
         stream.cuda_stream,
+        device.get_cuda_compile_arch(),
+        device.get_cuda_output_format() == "ptx",
         ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
         None if on_true is None else ctypes.byref(graph_on_true),
         None if on_false is None else ctypes.byref(graph_on_false),
@@ -7117,7 +7171,9 @@ def capture_if(
         capture_resume(main_graph, stream=stream)
 
 
-def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph, stream: Stream = None, **kwargs):
+def capture_while(
+    condition: warp.array(dtype=int), while_body: Callable | Graph, stream: Stream | None = None, **kwargs
+):
     """Create a dynamic loop based on a condition.
 
     The condition value is retrieved from the first element of the ``condition`` array.
@@ -7185,6 +7241,8 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph
     if not runtime.core.wp_cuda_graph_insert_while(
         device.context,
         stream.cuda_stream,
+        device.get_cuda_compile_arch(),
+        device.get_cuda_output_format() == "ptx",
         ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
         ctypes.byref(body_graph),
         ctypes.byref(cond_handle),
@@ -7218,6 +7276,8 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph
     if not runtime.core.wp_cuda_graph_set_condition(
         device.context,
         stream.cuda_stream,
+        device.get_cuda_compile_arch(),
+        device.get_cuda_output_format() == "ptx",
         ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
         cond_handle,
     ):
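The extra arguments passed to wp_cuda_graph_insert_while() and wp_cuda_graph_set_condition() forward the device's compile architecture and output format to the native conditional-node code; the Python-facing API is unchanged. A minimal capture_while() sketch, assuming a driver recent enough to support CUDA conditional graph nodes:

import warp as wp

wp.init()

@wp.kernel
def countdown(counter: wp.array(dtype=int)):
    # decrement until zero; the captured loop exits once counter[0] == 0
    if counter[0] > 0:
        counter[0] = counter[0] - 1

counter = wp.array([10], dtype=int, device="cuda:0")

with wp.ScopedCapture("cuda:0") as capture:
    wp.capture_while(counter, lambda: wp.launch(countdown, dim=1, inputs=[counter], device="cuda:0"))

wp.capture_launch(capture.graph)
wp.synchronize_device("cuda:0")
print(counter.numpy())  # expected: [0]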
@@ -7748,6 +7808,7 @@ def export_stubs(file): # pragma: no cover
     print("from typing import Callable", file=file)
     print("from typing import TypeVar", file=file)
     print("from typing import Generic", file=file)
+    print("from typing import Sequence", file=file)
     print("from typing import overload as over", file=file)
     print(file=file)
 
@@ -7776,7 +7837,7 @@ def export_stubs(file): # pragma: no cover
     print(header, file=file)
     print(file=file)
 
-    def add_stub(f):
+    def add_builtin_function_stub(f):
         args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())
 
         return_str = ""
@@ -7796,12 +7857,162 @@ def export_stubs(file): # pragma: no cover
         print('    """', file=file)
         print("    ...\n\n", file=file)
 
+    def add_vector_type_stub(cls, label):
+        cls_name = cls.__name__
+        scalar_type_name = cls._wp_scalar_type_.__name__
+
+        print(f"class {cls_name}:", file=file)
+
+        print("    @over", file=file)
+        print("    def __init__(self) -> None:", file=file)
+        print(f'        """Construct a zero-initialized {label}."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, other: {cls_name}) -> None:", file=file)
+        print(f'        """Construct a {label} by copy."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ", ".join(f"{x}: {scalar_type_name}" for x in "xyzw"[: cls._length_])
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its component values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, args: Sequence[{scalar_type_name}]) -> None:", file=file)
+        print(f'        """Construct a {label} from a sequence of values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, value: {scalar_type_name}) -> None:", file=file)
+        print(f'        """Construct a {label} filled with a value."""', file=file)
+        print("        ...\n\n", file=file)
+
+    def add_matrix_type_stub(cls, label):
+        cls_name = cls.__name__
+        scalar_type_name = cls._wp_scalar_type_.__name__
+        scalar_short_name = warp.types.scalar_short_name(cls._wp_scalar_type_)
+
+        print(f"class {cls_name}:", file=file)
+
+        print("    @over", file=file)
+        print("    def __init__(self) -> None:", file=file)
+        print(f'        """Construct a zero-initialized {label}."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, other: {cls_name}) -> None:", file=file)
+        print(f'        """Construct a {label} by copy."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ", ".join(f"m{i}{j}: {scalar_type_name}" for i in range(cls._shape_[0]) for j in range(cls._shape_[1]))
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its component values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ", ".join(f"v{i}: vec{cls._shape_[0]}{scalar_short_name}" for i in range(cls._shape_[0]))
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its row vectors."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, args: Sequence[{scalar_type_name}]) -> None:", file=file)
+        print(f'        """Construct a {label} from a sequence of values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, value: {scalar_type_name}) -> None:", file=file)
+        print(f'        """Construct a {label} filled with a value."""', file=file)
+        print("        ...\n\n", file=file)
+
+    def add_transform_type_stub(cls, label):
+        cls_name = cls.__name__
+        scalar_type_name = cls._wp_scalar_type_.__name__
+        scalar_short_name = warp.types.scalar_short_name(cls._wp_scalar_type_)
+
+        print(f"class {cls_name}:", file=file)
+
+        print("    @over", file=file)
+        print("    def __init__(self) -> None:", file=file)
+        print(f'        """Construct a zero-initialized {label}."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, other: {cls_name}) -> None:", file=file)
+        print(f'        """Construct a {label} by copy."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, p: vec3{scalar_short_name}, q: quat{scalar_short_name}) -> None:", file=file)
+        print(f'        """Construct a {label} from its p and q components."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ()
+        args += tuple(f"p{x}: {scalar_type_name}" for x in "xyz")
+        args += tuple(f"q{x}: {scalar_type_name}" for x in "xyzw")
+        args = ", ".join(args)
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its component values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(
+            f"    def __init__(self, p: Sequence[{scalar_type_name}], q: Sequence[{scalar_type_name}]) -> None:",
+            file=file,
+        )
+        print(f'        """Construct a {label} from two sequences of values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, value: {scalar_type_name}) -> None:", file=file)
+        print(f'        """Construct a {label} filled with a value."""', file=file)
+        print("        ...\n\n", file=file)
+
+    # Vector types.
+    suffixes = ("h", "f", "d", "b", "ub", "s", "us", "i", "ui", "l", "ul")
+    for length in (2, 3, 4):
+        for suffix in suffixes:
+            cls = getattr(warp.types, f"vec{length}{suffix}")
+            add_vector_type_stub(cls, "vector")
+
+        print(f"vec{length} = vec{length}f", file=file)
+
+    # Matrix types.
+    suffixes = ("h", "f", "d")
+    for length in (2, 3, 4):
+        shape = f"{length}{length}"
+        for suffix in suffixes:
+            cls = getattr(warp.types, f"mat{shape}{suffix}")
+            add_matrix_type_stub(cls, "matrix")
+
+        print(f"mat{shape} = mat{shape}f", file=file)
+
+    # Quaternion types.
+    suffixes = ("h", "f", "d")
+    for suffix in suffixes:
+        cls = getattr(warp.types, f"quat{suffix}")
+        add_vector_type_stub(cls, "quaternion")
+
+    print("quat = quatf", file=file)
+
+    # Transformation types.
+    suffixes = ("h", "f", "d")
+    for suffix in suffixes:
+        cls = getattr(warp.types, f"transform{suffix}")
+        add_transform_type_stub(cls, "transformation")
+
+    print("transform = transformf", file=file)
+
     for g in builtin_functions.values():
         if hasattr(g, "overloads"):
             for f in g.overloads:
-                add_stub(f)
+                add_builtin_function_stub(f)
         elif isinstance(g, Function):
-            add_stub(g)
+            add_builtin_function_stub(g)
 
 
 def export_builtins(file: io.TextIOBase): # pragma: no cover
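The stub generation above only affects the generated stubs file used by IDEs and type checkers; the constructor overloads it documents already exist at runtime. A few of the documented forms, for illustration:

import warp as wp

v0 = wp.vec3()                  # zero-initialized
v1 = wp.vec3(1.0, 2.0, 3.0)     # from component values
v2 = wp.vec3([1.0, 2.0, 3.0])   # from a sequence of values
v3 = wp.vec3(7.0)               # filled with a value

m = wp.mat33(
    wp.vec3(1.0, 0.0, 0.0),
    wp.vec3(0.0, 1.0, 0.0),
    wp.vec3(0.0, 0.0, 1.0),
)  # from row vectors

t = wp.transform(wp.vec3(0.0, 1.0, 0.0), wp.quat_identity())  # from p and q components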
@@ -45,7 +45,8 @@ def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float),
 @wp.kernel
 def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
     tid = wp.tid()
-    output[tid] = wp.mat33(1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0)
+    d = float(tid + 1)
+    output[tid] = wp.mat33(d, 0.0, 0.0, 0.0, d * 2.0, 0.0, 0.0, 0.0, d * 3.0)
 
 
 @wp.kernel
@@ -19,6 +19,7 @@ import warp as wp
 from warp.context import type_str
 from warp.jax import get_jax_device
 from warp.types import array_t, launch_bounds_t, strides_from_shape
+from warp.utils import warn
 
 _jax_warp_p = None
 
@@ -28,7 +29,7 @@ _registered_kernels = [None]
 _registered_kernel_to_id = {}
 
 
-def jax_kernel(kernel, launch_dims=None):
+def jax_kernel(kernel, launch_dims=None, quiet=False):
     """Create a Jax primitive from a Warp kernel.
 
     NOTE: This is an experimental feature under development.
@@ -38,6 +39,7 @@ def jax_kernel(kernel, launch_dims=None):
         launch_dims: Optional. Specify the kernel launch dimensions. If None,
             dimensions are inferred from the shape of the first argument.
             This option when set will specify the output dimensions.
+        quiet: Optional. If True, suppress deprecation warnings with newer JAX versions.
 
     Limitations:
         - All kernel arguments must be contiguous arrays.
@@ -46,6 +48,27 @@ def jax_kernel(kernel, launch_dims=None):
         - Only the CUDA backend is supported.
     """
 
+    import jax
+
+    # check if JAX version supports this
+    if jax.__version_info__ < (0, 4, 25) or jax.__version_info__ >= (0, 8, 0):
+        msg = (
+            "This version of jax_kernel() requires JAX version 0.4.25 - 0.7.x, "
+            f"but installed JAX version is {jax.__version_info__}."
+        )
+        if jax.__version_info__ >= (0, 8, 0):
+            msg += " Please use warp.jax_experimental.ffi.jax_kernel instead."
+        raise RuntimeError(msg)
+
+    # deprecation warning
+    if jax.__version_info__ >= (0, 5, 0) and not quiet:
+        warn(
+            "This version of jax_kernel() is deprecated and will not be supported with newer JAX versions. "
+            "Please use the newer FFI version instead (warp.jax_experimental.ffi.jax_kernel). "
+            "In Warp release 1.10, the FFI version will become the default implementation of jax_kernel().",
+            DeprecationWarning,
+        )
+
     if _jax_warp_p is None:
         # Create and register the primitive
         _create_jax_warp_primitive()
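The new quiet flag only silences the deprecation warning; the version guard still raises for JAX 0.8.0 and newer. A minimal sketch of the deprecated custom_call path (assuming a JAX install in the supported 0.4.25 - 0.7.x range and a CUDA device):

import jax.numpy as jnp
import warp as wp
from warp.jax_experimental import jax_kernel

@wp.kernel
def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = 2.0 * x[tid]

# quiet=True suppresses the DeprecationWarning emitted for JAX >= 0.5.0
jax_scale = jax_kernel(scale, quiet=True)
out = jax_scale(jnp.arange(16, dtype=jnp.float32))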
@@ -29,6 +29,18 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_war
 from .xla_ffi import *
 
 
+def check_jax_version():
+    # check if JAX version supports this
+    if jax.__version_info__ < (0, 5, 0):
+        msg = (
+            "This version of jax_kernel() requires JAX version 0.5.0 or higher, "
+            f"but installed JAX version is {jax.__version_info__}."
+        )
+        if jax.__version_info__ >= (0, 4, 25):
+            msg += " Please use warp.jax_experimental.custom_call.jax_kernel instead."
+        raise RuntimeError(msg)
+
+
 class GraphMode(IntEnum):
     NONE = 0  # don't capture a graph
     JAX = 1  # let JAX capture a graph
@@ -668,8 +680,12 @@ def jax_kernel(
         - There must be at least one output or input-output argument.
         - Only the CUDA backend is supported.
     """
+
+    check_jax_version()
+
     key = (
         kernel.func,
+        kernel.sig,
         num_outputs,
         vmap_method,
         tuple(launch_dims) if launch_dims else launch_dims,
@@ -726,6 +742,8 @@ def jax_callable(
         - Only the CUDA backend is supported.
     """
 
+    check_jax_version()
+
     if graph_compatible is not None:
         wp.utils.warn(
             "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
@@ -772,6 +790,8 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr
         graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
     """
 
+    check_jax_version()
+
     # TODO check that the name is not already registered
 
     def ffi_callback(call_frame):
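check_jax_version() guards the FFI entry points that the deprecation message above steers users toward. For reference, a minimal sketch of the recommended FFI path (assuming JAX 0.5.0 or newer and a CUDA device; by default the trailing kernel argument is treated as the single output):

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def add_vec(a: wp.array(dtype=float), b: wp.array(dtype=float), out: wp.array(dtype=float)):
    tid = wp.tid()
    out[tid] = a[tid] + b[tid]

jax_add = jax_kernel(add_vec)  # num_outputs defaults to 1

@jax.jit
def f(a, b):
    return jax_add(a, b)

a = jnp.arange(16, dtype=jnp.float32)
print(f(a, a))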
@@ -475,17 +475,26 @@ _xla_data_type_to_constructor = {
     XLA_FFI_DataType.C64: jnp.complex64,
     XLA_FFI_DataType.C128: jnp.complex128,
     # XLA_FFI_DataType.TOKEN
-    XLA_FFI_DataType.F8E5M2: jnp.float8_e5m2,
-    XLA_FFI_DataType.F8E3M4: jnp.float8_e3m4,
-    XLA_FFI_DataType.F8E4M3: jnp.float8_e4m3,
-    XLA_FFI_DataType.F8E4M3FN: jnp.float8_e4m3fn,
-    XLA_FFI_DataType.F8E4M3B11FNUZ: jnp.float8_e4m3b11fnuz,
-    XLA_FFI_DataType.F8E5M2FNUZ: jnp.float8_e5m2fnuz,
-    XLA_FFI_DataType.F8E4M3FNUZ: jnp.float8_e4m3fnuz,
     # XLA_FFI_DataType.F4E2M1FN: jnp.float4_e2m1fn.dtype,
     # XLA_FFI_DataType.F8E8M0FNU: jnp.float8_e8m0fnu.dtype,
 }
 
+# newer types not supported by older versions
+if hasattr(jnp, "float8_e5m2"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E5M2] = jnp.float8_e5m2
+if hasattr(jnp, "float8_e3m4"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E3M4] = jnp.float8_e3m4
+if hasattr(jnp, "float8_e4m3"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3] = jnp.float8_e4m3
+if hasattr(jnp, "float8_e4m3fn"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3FN] = jnp.float8_e4m3fn
+if hasattr(jnp, "float8_e4m3b11fnuz"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3B11FNUZ] = jnp.float8_e4m3b11fnuz
+if hasattr(jnp, "float8_e5m2fnuz"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E5M2FNUZ] = jnp.float8_e5m2fnuz
+if hasattr(jnp, "float8_e4m3fnuz"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3FNUZ] = jnp.float8_e4m3fnuz
+
 
 ########################################################################
 # Helpers for translating between ctypes and python types
warp/native/builtin.h CHANGED
@@ -1093,8 +1093,8 @@ CUDA_CALLABLE inline T select(const C& cond, const T& a, const T& b)
     return (!!cond) ? b : a;
 }
 
-template <typename C, typename T>
-CUDA_CALLABLE inline void adj_select(const C& cond, const T& a, const T& b, C& adj_cond, T& adj_a, T& adj_b, const T& adj_ret)
+template <typename C, typename TA, typename TB, typename TRet>
+CUDA_CALLABLE inline void adj_select(const C& cond, const TA& a, const TB& b, C& adj_cond, TA& adj_a, TB& adj_b, const TRet& adj_ret)
 {
     // The double NOT operator !! casts to bool without compiler warnings.
     if (!!cond)
@@ -1110,8 +1110,8 @@ CUDA_CALLABLE inline T where(const C& cond, const T& a, const T& b)
     return (!!cond) ? a : b;
 }
 
-template <typename C, typename T>
-CUDA_CALLABLE inline void adj_where(const C& cond, const T& a, const T& b, C& adj_cond, T& adj_a, T& adj_b, const T& adj_ret)
+template <typename C, typename TA, typename TB, typename TRet>
+CUDA_CALLABLE inline void adj_where(const C& cond, const TA& a, const TB& b, C& adj_cond, TA& adj_a, TB& adj_b, const TRet& adj_ret)
 {
     // The double NOT operator !! casts to bool without compiler warnings.
     if (!!cond)
warp/native/sort.cu CHANGED
@@ -23,7 +23,7 @@
 
 #include <cub/cub.cuh>
 
-#include <map>
+#include <unordered_map>
 
 // temporary buffer for radix sort
 struct RadixSortTemp
@@ -32,8 +32,8 @@ struct RadixSortTemp
     size_t size = 0;
 };
 
-// map temp buffers to CUDA contexts
-static std::map<void*, RadixSortTemp> g_radix_sort_temp_map;
+// use unique temp buffers per CUDA stream to avoid race conditions
+static std::unordered_map<void*, RadixSortTemp> g_radix_sort_temp_map;
 
 
 template <typename KeyType>
@@ -44,6 +44,8 @@ void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* s
     cub::DoubleBuffer<KeyType> d_keys;
     cub::DoubleBuffer<int> d_values;
 
+    CUstream stream = static_cast<CUstream>(wp_cuda_stream_get_current());
+
     // compute temporary memory required
     size_t sort_temp_size;
     check_cuda(cub::DeviceRadixSort::SortPairs(
@@ -52,12 +54,9 @@ void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* s
         d_keys,
         d_values,
         n, 0, sizeof(KeyType)*8,
-        (cudaStream_t)wp_cuda_stream_get_current()));
-
-    if (!context)
-        context = wp_cuda_context_get_current();
+        stream));
 
-    RadixSortTemp& temp = g_radix_sort_temp_map[context];
+    RadixSortTemp& temp = g_radix_sort_temp_map[stream];
 
     if (sort_temp_size > temp.size)
     {
@@ -77,6 +76,17 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
     radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
 }
 
+void radix_sort_release(void* context, void* stream)
+{
+    // release temporary buffer for the given stream, if it exists
+    auto it = g_radix_sort_temp_map.find(stream);
+    if (it != g_radix_sort_temp_map.end())
+    {
+        wp_free_device(context, it->second.mem);
+        g_radix_sort_temp_map.erase(it);
+    }
+}
+
 template <typename KeyType>
 void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
 {
@@ -153,6 +163,8 @@ void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_o
     int* start_indices = NULL;
     int* end_indices = NULL;
 
+    CUstream stream = static_cast<CUstream>(wp_cuda_stream_get_current());
+
     // compute temporary memory required
     size_t sort_temp_size;
     check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
@@ -166,12 +178,9 @@ void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_o
         end_indices,
         0,
         32,
-        (cudaStream_t)wp_cuda_stream_get_current()));
-
-    if (!context)
-        context = wp_cuda_context_get_current();
+        stream));
 
-    RadixSortTemp& temp = g_radix_sort_temp_map[context];
+    RadixSortTemp& temp = g_radix_sort_temp_map[stream];
 
     if (sort_temp_size > temp.size)
     {
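Keying the scratch buffer by stream instead of by context means two sorts issued on different streams of the same device no longer share (and race on) a single temporary allocation, and radix_sort_release() presumably lets a stream's buffer be freed when that stream is destroyed. From Python, the situation this protects is roughly the following hedged sketch (warp.utils.radix_sort_pairs requires the arrays to have 2*n capacity for its double buffer):

import numpy as np
import warp as wp
import warp.utils

wp.init()

device = wp.get_device("cuda:0")
s1, s2 = wp.Stream(device), wp.Stream(device)
n = 1 << 16

def make_pairs():
    # allocate 2*n entries: the second half is scratch space for the sort
    keys = wp.array(np.random.randint(0, 1 << 20, 2 * n).astype(np.int32), dtype=int, device=device)
    vals = wp.array(np.arange(2 * n, dtype=np.int32), dtype=int, device=device)
    return keys, vals

k1, v1 = make_pairs()
k2, v2 = make_pairs()

# Each sort now reserves temp memory keyed by its stream, so these two calls
# no longer contend for one shared scratch buffer.
with wp.ScopedStream(s1):
    wp.utils.radix_sort_pairs(k1, v1, n)
with wp.ScopedStream(s2):
    wp.utils.radix_sort_pairs(k2, v2, n)

wp.synchronize_device(device)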
warp/native/sort.h CHANGED
@@ -20,6 +20,8 @@
 #include <stddef.h>
 
 void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
+void radix_sort_release(void* context, void* stream);
+
 void radix_sort_pairs_host(int* keys, int* values, int n);
 void radix_sort_pairs_host(float* keys, int* values, int n);
 void radix_sort_pairs_host(int64_t* keys, int* values, int n);