warp-lang 1.9.0__py3-none-macosx_10_13_universal2.whl → 1.9.1__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.pyi +1420 -2
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build_dll.py +322 -72
- warp/builtins.py +289 -23
- warp/codegen.py +5 -0
- warp/config.py +1 -1
- warp/context.py +243 -32
- warp/examples/interop/example_jax_kernel.py +2 -1
- warp/jax_experimental/custom_call.py +24 -1
- warp/jax_experimental/ffi.py +20 -0
- warp/jax_experimental/xla_ffi.py +16 -7
- warp/native/builtin.h +4 -4
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/tile.h +188 -13
- warp/native/vec.h +0 -53
- warp/native/warp.cpp +3 -3
- warp/native/warp.cu +60 -30
- warp/native/warp.h +3 -3
- warp/render/render_opengl.py +14 -12
- warp/render/render_usd.py +1 -0
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/interop/test_jax.py +608 -28
- warp/tests/test_array.py +2 -0
- warp/tests/test_codegen.py +1 -1
- warp/tests/test_fem.py +4 -4
- warp/tests/test_map.py +14 -0
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +61 -0
- warp/tests/tile/test_tile.py +61 -0
- warp/types.py +17 -3
- {warp_lang-1.9.0.dist-info → warp_lang-1.9.1.dist-info}/METADATA +5 -8
- {warp_lang-1.9.0.dist-info → warp_lang-1.9.1.dist-info}/RECORD +37 -37
- {warp_lang-1.9.0.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
@@ -1244,6 +1244,11 @@ class Adjoint:
             A line directive for the given statement, or None if no line directive is needed.
         """

+        if adj.filename == "unknown source file" or adj.fun_lineno == 0:
+            # Early return if function is not associated with a source file or is otherwise invalid
+            # TODO: Get line directives working with wp.map() functions
+            return None
+
         # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
         # emit line directives in generated code if it's not being compiled with line information
         build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp.config.mode
warp/config.py
CHANGED
warp/context.py
CHANGED
@@ -2244,21 +2244,7 @@ class Module:
         return self.hashers[block_dim].get_module_hash()

     def _use_ptx(self, device) -> bool:
-
-        if device.is_cubin_supported:
-            # get user preference specified either per module or globally
-            preferred_cuda_output = self.options.get("cuda_output") or warp.config.cuda_output
-            if preferred_cuda_output is not None:
-                use_ptx = preferred_cuda_output == "ptx"
-            else:
-                # determine automatically: older drivers may not be able to handle PTX generated using newer
-                # CUDA Toolkits, in which case we fall back on generating CUBIN modules
-                use_ptx = runtime.driver_version >= runtime.toolkit_version
-        else:
-            # CUBIN not an option, must use PTX (e.g. CUDA Toolkit too old)
-            use_ptx = True
-
-        return use_ptx
+        return device.get_cuda_output_format(self.options.get("cuda_output")) == "ptx"

     def get_module_identifier(self) -> str:
         """Get an abbreviated module name to use for directories and files in the cache.
@@ -2278,19 +2264,7 @@ class Module:
         if device is None:
             device = runtime.get_device()

-
-            return None
-
-        if self._use_ptx(device):
-            # use the default PTX arch if the device supports it
-            if warp.config.ptx_target_arch is not None:
-                output_arch = min(device.arch, warp.config.ptx_target_arch)
-            else:
-                output_arch = min(device.arch, runtime.default_ptx_arch)
-        else:
-            output_arch = device.arch
-
-        return output_arch
+        return device.get_cuda_compile_arch()

     def get_compile_output_name(
         self, device: Device | None, output_arch: int | None = None, use_ptx: bool | None = None
@@ -3327,6 +3301,78 @@ class Device:
         else:
             return False

+    def get_cuda_output_format(self, preferred_cuda_output: str | None = None) -> str | None:
+        """Determine the CUDA output format to use for this device.
+
+        This method is intended for internal use by Warp's compilation system.
+        External users should not need to call this method directly.
+
+        It determines whether to use PTX or CUBIN output based on device capabilities,
+        caller preferences, and runtime constraints.
+
+        Args:
+            preferred_cuda_output: Caller's preferred format (``"ptx"``, ``"cubin"``, or ``None``).
+                If ``None``, falls back to global config or automatic determination.
+
+        Returns:
+            The output format to use: ``"ptx"``, ``"cubin"``, or ``None`` for CPU devices.
+        """
+
+        if self.is_cpu:
+            # CPU devices don't use CUDA compilation
+            return None
+
+        if not self.is_cubin_supported:
+            return "ptx"
+
+        # Use provided preference or fall back to global config
+        if preferred_cuda_output is None:
+            preferred_cuda_output = warp.config.cuda_output
+
+        if preferred_cuda_output is not None:
+            # Caller specified a preference, use it if supported
+            if preferred_cuda_output in ("ptx", "cubin"):
+                return preferred_cuda_output
+            else:
+                # Invalid preference, fall back to automatic determination
+                pass
+
+        # Determine automatically: Older drivers may not be able to handle PTX generated using newer CUDA Toolkits,
+        # in which case we fall back on generating CUBIN modules
+        return "ptx" if self.runtime.driver_version >= self.runtime.toolkit_version else "cubin"
+
+    def get_cuda_compile_arch(self) -> int | None:
+        """Get the CUDA architecture to use when compiling code for this device.
+
+        This method is intended for internal use by Warp's compilation system.
+        External users should not need to call this method directly.
+
+        Determines the appropriate compute capability version to use when compiling
+        CUDA kernels for this device. The architecture depends on the device's
+        CUDA output format preference and available target architectures.
+
+        For PTX output format, uses the minimum of the device's architecture and
+        the configured PTX target architecture to ensure compatibility.
+        For CUBIN output format, uses the device's exact architecture.
+
+        Returns:
+            The compute capability version (e.g., 75 for ``sm_75``) to use for compilation,
+            or ``None`` for CPU devices which don't use CUDA compilation.
+        """
+        if self.is_cpu:
+            return None
+
+        if self.get_cuda_output_format() == "ptx":
+            # use the default PTX arch if the device supports it
+            if warp.config.ptx_target_arch is not None:
+                output_arch = min(self.arch, warp.config.ptx_target_arch)
+            else:
+                output_arch = min(self.arch, runtime.default_ptx_arch)
+        else:
+            output_arch = self.arch
+
+        return output_arch
+

 """ Meta-type for arguments that can be resolved to a concrete Device.
 """
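The two helpers above centralize the PTX/CUBIN decision that previously lived on Module. A minimal sketch of how they behave, assuming a CUDA-capable device is present (these are internal helpers, shown here only to illustrate the values they return; on CPU devices both return None):

import warp as wp

wp.init()
device = wp.get_device("cuda:0")

print(device.get_cuda_output_format())  # "ptx" or "cubin"
print(device.get_cuda_compile_arch())   # compute capability as an int, e.g. 86 for sm_86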
@@ -4036,6 +4082,8 @@ class Runtime:
         self.core.wp_cuda_graph_insert_if_else.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
+            ctypes.c_int,
+            ctypes.c_bool,
             ctypes.POINTER(ctypes.c_int),
             ctypes.POINTER(ctypes.c_void_p),
             ctypes.POINTER(ctypes.c_void_p),

@@ -4045,6 +4093,8 @@ class Runtime:
         self.core.wp_cuda_graph_insert_while.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
+            ctypes.c_int,
+            ctypes.c_bool,
             ctypes.POINTER(ctypes.c_int),
             ctypes.POINTER(ctypes.c_void_p),
             ctypes.POINTER(ctypes.c_uint64),

@@ -4054,6 +4104,8 @@ class Runtime:
         self.core.wp_cuda_graph_set_condition.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
+            ctypes.c_int,
+            ctypes.c_bool,
             ctypes.POINTER(ctypes.c_int),
             ctypes.c_uint64,
         ]
@@ -7053,6 +7105,8 @@ def capture_if(
     if not runtime.core.wp_cuda_graph_insert_if_else(
         device.context,
         stream.cuda_stream,
+        device.get_cuda_compile_arch(),
+        device.get_cuda_output_format() == "ptx",
         ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
         None if on_true is None else ctypes.byref(graph_on_true),
         None if on_false is None else ctypes.byref(graph_on_false),
@@ -7117,7 +7171,9 @@ def capture_if(
     capture_resume(main_graph, stream=stream)


-def capture_while(
+def capture_while(
+    condition: warp.array(dtype=int), while_body: Callable | Graph, stream: Stream | None = None, **kwargs
+):
     """Create a dynamic loop based on a condition.

     The condition value is retrieved from the first element of the ``condition`` array.
@@ -7185,6 +7241,8 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph
     if not runtime.core.wp_cuda_graph_insert_while(
         device.context,
         stream.cuda_stream,
+        device.get_cuda_compile_arch(),
+        device.get_cuda_output_format() == "ptx",
         ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
         ctypes.byref(body_graph),
         ctypes.byref(cond_handle),
@@ -7218,6 +7276,8 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph
     if not runtime.core.wp_cuda_graph_set_condition(
         device.context,
         stream.cuda_stream,
+        device.get_cuda_compile_arch(),
+        device.get_cuda_output_format() == "ptx",
         ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
         cond_handle,
     ):
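The extra arch and PTX-flag arguments let the native conditional-graph helpers build their internal condition-evaluation code with the same target as regular kernels. A rough usage sketch of the public entry point, assuming a device and driver that support CUDA conditional graph nodes (the kernel and loop count are placeholders, not taken from the diff):

import warp as wp

@wp.kernel
def countdown(counter: wp.array(dtype=int)):
    # placeholder body: decrement until the condition array reaches zero
    counter[0] = counter[0] - 1

device = "cuda:0"
condition = wp.full(1, 5, dtype=int, device=device)

def loop_body():
    wp.launch(countdown, dim=1, inputs=[condition], device=device)

with wp.ScopedCapture(device=device) as capture:
    wp.capture_while(condition, loop_body)

wp.capture_launch(capture.graph)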
@@ -7748,6 +7808,7 @@ def export_stubs(file): # pragma: no cover
     print("from typing import Callable", file=file)
     print("from typing import TypeVar", file=file)
     print("from typing import Generic", file=file)
+    print("from typing import Sequence", file=file)
     print("from typing import overload as over", file=file)
     print(file=file)

@@ -7776,7 +7837,7 @@ def export_stubs(file): # pragma: no cover
     print(header, file=file)
     print(file=file)

-    def
+    def add_builtin_function_stub(f):
         args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())

         return_str = ""
@@ -7796,12 +7857,162 @@ def export_stubs(file): # pragma: no cover
         print('    """', file=file)
         print("    ...\n\n", file=file)

+    def add_vector_type_stub(cls, label):
+        cls_name = cls.__name__
+        scalar_type_name = cls._wp_scalar_type_.__name__
+
+        print(f"class {cls_name}:", file=file)
+
+        print("    @over", file=file)
+        print("    def __init__(self) -> None:", file=file)
+        print(f'        """Construct a zero-initialized {label}."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, other: {cls_name}) -> None:", file=file)
+        print(f'        """Construct a {label} by copy."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ", ".join(f"{x}: {scalar_type_name}" for x in "xyzw"[: cls._length_])
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its component values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, args: Sequence[{scalar_type_name}]) -> None:", file=file)
+        print(f'        """Construct a {label} from a sequence of values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, value: {scalar_type_name}) -> None:", file=file)
+        print(f'        """Construct a {label} filled with a value."""', file=file)
+        print("        ...\n\n", file=file)
+
+    def add_matrix_type_stub(cls, label):
+        cls_name = cls.__name__
+        scalar_type_name = cls._wp_scalar_type_.__name__
+        scalar_short_name = warp.types.scalar_short_name(cls._wp_scalar_type_)
+
+        print(f"class {cls_name}:", file=file)
+
+        print("    @over", file=file)
+        print("    def __init__(self) -> None:", file=file)
+        print(f'        """Construct a zero-initialized {label}."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, other: {cls_name}) -> None:", file=file)
+        print(f'        """Construct a {label} by copy."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ", ".join(f"m{i}{j}: {scalar_type_name}" for i in range(cls._shape_[0]) for j in range(cls._shape_[1]))
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its component values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ", ".join(f"v{i}: vec{cls._shape_[0]}{scalar_short_name}" for i in range(cls._shape_[0]))
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its row vectors."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, args: Sequence[{scalar_type_name}]) -> None:", file=file)
+        print(f'        """Construct a {label} from a sequence of values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, value: {scalar_type_name}) -> None:", file=file)
+        print(f'        """Construct a {label} filled with a value."""', file=file)
+        print("        ...\n\n", file=file)
+
+    def add_transform_type_stub(cls, label):
+        cls_name = cls.__name__
+        scalar_type_name = cls._wp_scalar_type_.__name__
+        scalar_short_name = warp.types.scalar_short_name(cls._wp_scalar_type_)
+
+        print(f"class {cls_name}:", file=file)
+
+        print("    @over", file=file)
+        print("    def __init__(self) -> None:", file=file)
+        print(f'        """Construct a zero-initialized {label}."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, other: {cls_name}) -> None:", file=file)
+        print(f'        """Construct a {label} by copy."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, p: vec3{scalar_short_name}, q: quat{scalar_short_name}) -> None:", file=file)
+        print(f'        """Construct a {label} from its p and q components."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ()
+        args += tuple(f"p{x}: {scalar_type_name}" for x in "xyz")
+        args += tuple(f"q{x}: {scalar_type_name}" for x in "xyzw")
+        args = ", ".join(args)
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its component values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(
+            f"    def __init__(self, p: Sequence[{scalar_type_name}], q: Sequence[{scalar_type_name}]) -> None:",
+            file=file,
+        )
+        print(f'        """Construct a {label} from two sequences of values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, value: {scalar_type_name}) -> None:", file=file)
+        print(f'        """Construct a {label} filled with a value."""', file=file)
+        print("        ...\n\n", file=file)
+
+    # Vector types.
+    suffixes = ("h", "f", "d", "b", "ub", "s", "us", "i", "ui", "l", "ul")
+    for length in (2, 3, 4):
+        for suffix in suffixes:
+            cls = getattr(warp.types, f"vec{length}{suffix}")
+            add_vector_type_stub(cls, "vector")
+
+        print(f"vec{length} = vec{length}f", file=file)
+
+    # Matrix types.
+    suffixes = ("h", "f", "d")
+    for length in (2, 3, 4):
+        shape = f"{length}{length}"
+        for suffix in suffixes:
+            cls = getattr(warp.types, f"mat{shape}{suffix}")
+            add_matrix_type_stub(cls, "matrix")
+
+        print(f"mat{shape} = mat{shape}f", file=file)
+
+    # Quaternion types.
+    suffixes = ("h", "f", "d")
+    for suffix in suffixes:
+        cls = getattr(warp.types, f"quat{suffix}")
+        add_vector_type_stub(cls, "quaternion")
+
+    print("quat = quatf", file=file)
+
+    # Transformation types.
+    suffixes = ("h", "f", "d")
+    for suffix in suffixes:
+        cls = getattr(warp.types, f"transform{suffix}")
+        add_transform_type_stub(cls, "transformation")
+
+    print("transform = transformf", file=file)
+
     for g in builtin_functions.values():
         if hasattr(g, "overloads"):
             for f in g.overloads:
-
+                add_builtin_function_stub(f)
         elif isinstance(g, Function):
-
+            add_builtin_function_stub(g)


 def export_builtins(file: io.TextIOBase): # pragma: no cover
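With these helpers, the regenerated __init__.pyi gains constructor overloads for every vector, matrix, quaternion, and transform type. Roughly, the emitted stub for vec3f should look like the following, reconstructed from the generator above rather than copied from the shipped file (it relies on the Sequence, float32, and over names printed at the top of the stub):

class vec3f:
    @over
    def __init__(self) -> None:
        """Construct a zero-initialized vector."""
        ...

    @over
    def __init__(self, other: vec3f) -> None:
        """Construct a vector by copy."""
        ...

    @over
    def __init__(self, x: float32, y: float32, z: float32) -> None:
        """Construct a vector from its component values."""
        ...

    @over
    def __init__(self, args: Sequence[float32]) -> None:
        """Construct a vector from a sequence of values."""
        ...

    @over
    def __init__(self, value: float32) -> None:
        """Construct a vector filled with a value."""
        ...

vec3 = vec3f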
warp/examples/interop/example_jax_kernel.py
CHANGED

@@ -45,7 +45,8 @@ def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float),
 @wp.kernel
 def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
     tid = wp.tid()
-
+    d = float(tid + 1)
+    output[tid] = wp.mat33(d, 0.0, 0.0, 0.0, d * 2.0, 0.0, 0.0, 0.0, d * 3.0)


 @wp.kernel
warp/jax_experimental/custom_call.py
CHANGED

@@ -19,6 +19,7 @@ import warp as wp
 from warp.context import type_str
 from warp.jax import get_jax_device
 from warp.types import array_t, launch_bounds_t, strides_from_shape
+from warp.utils import warn

 _jax_warp_p = None

@@ -28,7 +29,7 @@ _registered_kernels = [None]
 _registered_kernel_to_id = {}


-def jax_kernel(kernel, launch_dims=None):
+def jax_kernel(kernel, launch_dims=None, quiet=False):
     """Create a Jax primitive from a Warp kernel.

     NOTE: This is an experimental feature under development.

@@ -38,6 +39,7 @@ def jax_kernel(kernel, launch_dims=None):
         launch_dims: Optional. Specify the kernel launch dimensions. If None,
             dimensions are inferred from the shape of the first argument.
            This option when set will specify the output dimensions.
+        quiet: Optional. If True, suppress deprecation warnings with newer JAX versions.

     Limitations:
         - All kernel arguments must be contiguous arrays.

@@ -46,6 +48,27 @@ def jax_kernel(kernel, launch_dims=None):
         - Only the CUDA backend is supported.
     """

+    import jax
+
+    # check if JAX version supports this
+    if jax.__version_info__ < (0, 4, 25) or jax.__version_info__ >= (0, 8, 0):
+        msg = (
+            "This version of jax_kernel() requires JAX version 0.4.25 - 0.7.x, "
+            f"but installed JAX version is {jax.__version_info__}."
+        )
+        if jax.__version_info__ >= (0, 8, 0):
+            msg += " Please use warp.jax_experimental.ffi.jax_kernel instead."
+        raise RuntimeError(msg)
+
+    # deprecation warning
+    if jax.__version_info__ >= (0, 5, 0) and not quiet:
+        warn(
+            "This version of jax_kernel() is deprecated and will not be supported with newer JAX versions. "
+            "Please use the newer FFI version instead (warp.jax_experimental.ffi.jax_kernel). "
+            "In Warp release 1.10, the FFI version will become the default implementation of jax_kernel().",
+            DeprecationWarning,
+        )
+
     if _jax_warp_p is None:
         # Create and register the primitive
         _create_jax_warp_primitive()
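Per the deprecation message above, the FFI-based implementation is the forward path. A minimal migration sketch (the kernel is a placeholder; for simple cases only the import changes):

import warp as wp
from warp.jax_experimental.ffi import jax_kernel  # instead of warp.jax_experimental.jax_kernel

@wp.kernel
def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = 2.0 * x[i]

jax_scale = jax_kernel(scale)  # callable from JAX; returns the output array(s)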
warp/jax_experimental/ffi.py
CHANGED
@@ -29,6 +29,18 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_war
 from .xla_ffi import *


+def check_jax_version():
+    # check if JAX version supports this
+    if jax.__version_info__ < (0, 5, 0):
+        msg = (
+            "This version of jax_kernel() requires JAX version 0.5.0 or higher, "
+            f"but installed JAX version is {jax.__version_info__}."
+        )
+        if jax.__version_info__ >= (0, 4, 25):
+            msg += " Please use warp.jax_experimental.custom_call.jax_kernel instead."
+        raise RuntimeError(msg)
+
+
 class GraphMode(IntEnum):
     NONE = 0  # don't capture a graph
     JAX = 1  # let JAX capture a graph

@@ -668,8 +680,12 @@ def jax_kernel(
         - There must be at least one output or input-output argument.
         - Only the CUDA backend is supported.
     """
+
+    check_jax_version()
+
     key = (
         kernel.func,
+        kernel.sig,
         num_outputs,
         vmap_method,
         tuple(launch_dims) if launch_dims else launch_dims,

@@ -726,6 +742,8 @@ def jax_callable(
         - Only the CUDA backend is supported.
     """

+    check_jax_version()
+
     if graph_compatible is not None:
         wp.utils.warn(
             "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",

@@ -772,6 +790,8 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr
         graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
     """

+    check_jax_version()
+
     # TODO check that the name is not already registered

     def ffi_callback(call_frame):
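check_jax_version() gives the FFI entry points an explicit floor. An equivalent caller-side guard, using the same threshold as the code above:

import jax

if jax.__version_info__ < (0, 5, 0):
    raise RuntimeError("warp.jax_experimental.ffi requires JAX 0.5.0 or newer")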
warp/jax_experimental/xla_ffi.py
CHANGED
@@ -475,17 +475,26 @@ _xla_data_type_to_constructor = {
     XLA_FFI_DataType.C64: jnp.complex64,
     XLA_FFI_DataType.C128: jnp.complex128,
     # XLA_FFI_DataType.TOKEN
-    XLA_FFI_DataType.F8E5M2: jnp.float8_e5m2,
-    XLA_FFI_DataType.F8E3M4: jnp.float8_e3m4,
-    XLA_FFI_DataType.F8E4M3: jnp.float8_e4m3,
-    XLA_FFI_DataType.F8E4M3FN: jnp.float8_e4m3fn,
-    XLA_FFI_DataType.F8E4M3B11FNUZ: jnp.float8_e4m3b11fnuz,
-    XLA_FFI_DataType.F8E5M2FNUZ: jnp.float8_e5m2fnuz,
-    XLA_FFI_DataType.F8E4M3FNUZ: jnp.float8_e4m3fnuz,
     # XLA_FFI_DataType.F4E2M1FN: jnp.float4_e2m1fn.dtype,
     # XLA_FFI_DataType.F8E8M0FNU: jnp.float8_e8m0fnu.dtype,
 }

+# newer types not supported by older versions
+if hasattr(jnp, "float8_e5m2"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E5M2] = jnp.float8_e5m2
+if hasattr(jnp, "float8_e3m4"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E3M4] = jnp.float8_e3m4
+if hasattr(jnp, "float8_e4m3"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3] = jnp.float8_e4m3
+if hasattr(jnp, "float8_e4m3fn"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3FN] = jnp.float8_e4m3fn
+if hasattr(jnp, "float8_e4m3b11fnuz"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3B11FNUZ] = jnp.float8_e4m3b11fnuz
+if hasattr(jnp, "float8_e5m2fnuz"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E5M2FNUZ] = jnp.float8_e5m2fnuz
+if hasattr(jnp, "float8_e4m3fnuz"):
+    _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3FNUZ] = jnp.float8_e4m3fnuz
+

 ########################################################################
 # Helpers for translating between ctypes and python types
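The lookup table is now populated defensively, so jax.numpy builds that lack some of the float8 types still import cleanly. The same probe works from user code, for example:

import jax.numpy as jnp

# Only reference float8 dtypes that this JAX build actually exposes.
available = [name for name in ("float8_e5m2", "float8_e4m3fn", "float8_e4m3fnuz") if hasattr(jnp, name)]
print(available)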
warp/native/builtin.h
CHANGED
@@ -1093,8 +1093,8 @@ CUDA_CALLABLE inline T select(const C& cond, const T& a, const T& b)
     return (!!cond) ? b : a;
 }

-template <typename C, typename
-CUDA_CALLABLE inline void adj_select(const C& cond, const
+template <typename C, typename TA, typename TB, typename TRet>
+CUDA_CALLABLE inline void adj_select(const C& cond, const TA& a, const TB& b, C& adj_cond, TA& adj_a, TB& adj_b, const TRet& adj_ret)
 {
     // The double NOT operator !! casts to bool without compiler warnings.
     if (!!cond)

@@ -1110,8 +1110,8 @@ CUDA_CALLABLE inline T where(const C& cond, const T& a, const T& b)
     return (!!cond) ? a : b;
 }

-template <typename C, typename
-CUDA_CALLABLE inline void adj_where(const C& cond, const
+template <typename C, typename TA, typename TB, typename TRet>
+CUDA_CALLABLE inline void adj_where(const C& cond, const TA& a, const TB& b, C& adj_cond, TA& adj_a, TB& adj_b, const TRet& adj_ret)
 {
     // The double NOT operator !! casts to bool without compiler warnings.
     if (!!cond)
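The regenerated adjoints keep the forward semantics shown in the context lines: select() returns its third argument when the condition is true, while where() returns its second. A small Warp-side illustration (kernel and array names are placeholders):

import warp as wp

@wp.kernel
def pick(out: wp.array(dtype=float)):
    c = True
    out[0] = wp.select(c, 1.0, 2.0)  # 2.0: select(cond, a, b) -> b if cond else a
    out[1] = wp.where(c, 1.0, 2.0)   # 1.0: where(cond, a, b) -> a if cond else b

out = wp.zeros(2, dtype=float)
wp.launch(pick, dim=1, outputs=[out])
print(out.numpy())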
warp/native/sort.cu
CHANGED
@@ -23,7 +23,7 @@

 #include <cub/cub.cuh>

-#include <
+#include <unordered_map>

 // temporary buffer for radix sort
 struct RadixSortTemp

@@ -32,8 +32,8 @@ struct RadixSortTemp
     size_t size = 0;
 };

-//
-static std::
+// use unique temp buffers per CUDA stream to avoid race conditions
+static std::unordered_map<void*, RadixSortTemp> g_radix_sort_temp_map;


 template <typename KeyType>

@@ -44,6 +44,8 @@ void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* s
     cub::DoubleBuffer<KeyType> d_keys;
     cub::DoubleBuffer<int> d_values;

+    CUstream stream = static_cast<CUstream>(wp_cuda_stream_get_current());
+
     // compute temporary memory required
     size_t sort_temp_size;
     check_cuda(cub::DeviceRadixSort::SortPairs(

@@ -52,12 +54,9 @@ void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* s
         d_keys,
         d_values,
         n, 0, sizeof(KeyType)*8,
-
-
-    if (!context)
-        context = wp_cuda_context_get_current();
+        stream));

-    RadixSortTemp& temp = g_radix_sort_temp_map[
+    RadixSortTemp& temp = g_radix_sort_temp_map[stream];

     if (sort_temp_size > temp.size)
     {

@@ -77,6 +76,17 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
     radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
 }

+void radix_sort_release(void* context, void* stream)
+{
+    // release temporary buffer for the given stream, if it exists
+    auto it = g_radix_sort_temp_map.find(stream);
+    if (it != g_radix_sort_temp_map.end())
+    {
+        wp_free_device(context, it->second.mem);
+        g_radix_sort_temp_map.erase(it);
+    }
+}
+
 template <typename KeyType>
 void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
 {

@@ -153,6 +163,8 @@ void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_o
     int* start_indices = NULL;
     int* end_indices = NULL;

+    CUstream stream = static_cast<CUstream>(wp_cuda_stream_get_current());
+
     // compute temporary memory required
     size_t sort_temp_size;
     check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(

@@ -166,12 +178,9 @@ void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_o
         end_indices,
         0,
         32,
-
-
-    if (!context)
-        context = wp_cuda_context_get_current();
+        stream));

-    RadixSortTemp& temp = g_radix_sort_temp_map[
+    RadixSortTemp& temp = g_radix_sort_temp_map[stream];

     if (sort_temp_size > temp.size)
     {
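Keying the scratch buffer on the stream, rather than sharing one global buffer, removes a race between concurrent streams. A Python analogue of the caching pattern, for illustration only (the names here are hypothetical, not Warp API):

_temp_by_stream = {}

def reserve_temp(stream, required_size, alloc):
    # grow-only scratch buffer, one per stream handle
    entry = _temp_by_stream.get(stream)
    if entry is None or entry["size"] < required_size:
        entry = {"mem": alloc(required_size), "size": required_size}
        _temp_by_stream[stream] = entry
    return entry["mem"]

def release_temp(stream, free):
    # mirrors radix_sort_release(): drop the cached buffer for one stream
    entry = _temp_by_stream.pop(stream, None)
    if entry is not None:
        free(entry["mem"])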
warp/native/sort.h
CHANGED
@@ -20,6 +20,8 @@
 #include <stddef.h>

 void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
+void radix_sort_release(void* context, void* stream);
+
 void radix_sort_pairs_host(int* keys, int* values, int n);
 void radix_sort_pairs_host(float* keys, int* values, int n);
 void radix_sort_pairs_host(int64_t* keys, int* values, int n);