warp-lang 1.7.0-py3-none-manylinux_2_34_aarch64.whl → 1.7.2-py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/autograd.py +12 -2
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +1 -1
- warp/builtins.py +103 -66
- warp/codegen.py +48 -27
- warp/config.py +1 -1
- warp/context.py +112 -49
- warp/examples/benchmarks/benchmark_cloth.py +1 -1
- warp/examples/distributed/example_jacobi_mpi.py +507 -0
- warp/fem/cache.py +1 -1
- warp/fem/field/field.py +11 -1
- warp/fem/field/nodal_field.py +36 -22
- warp/fem/geometry/adaptive_nanogrid.py +7 -3
- warp/fem/geometry/trimesh.py +4 -12
- warp/jax_experimental/custom_call.py +14 -2
- warp/jax_experimental/ffi.py +100 -67
- warp/native/builtin.h +91 -65
- warp/native/svd.h +59 -49
- warp/native/tile.h +55 -26
- warp/native/volume.cpp +2 -2
- warp/native/volume_builder.cu +33 -22
- warp/native/warp.cu +1 -1
- warp/render/render_opengl.py +41 -34
- warp/render/render_usd.py +96 -6
- warp/sim/collide.py +11 -9
- warp/sim/inertia.py +189 -156
- warp/sim/integrator_euler.py +3 -0
- warp/sim/integrator_xpbd.py +3 -0
- warp/sim/model.py +56 -31
- warp/sim/render.py +4 -0
- warp/sparse.py +1 -1
- warp/stubs.py +73 -25
- warp/tests/assets/torus.usda +1 -1
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/sim/test_collision.py +237 -206
- warp/tests/sim/test_inertia.py +161 -0
- warp/tests/sim/test_model.py +5 -3
- warp/tests/sim/{flaky_test_sim_grad.py → test_sim_grad.py} +1 -4
- warp/tests/sim/test_xpbd.py +399 -0
- warp/tests/test_array.py +8 -7
- warp/tests/test_atomic.py +181 -2
- warp/tests/test_builtins_resolution.py +38 -38
- warp/tests/test_codegen.py +24 -3
- warp/tests/test_examples.py +16 -6
- warp/tests/test_fem.py +93 -14
- warp/tests/test_func.py +1 -1
- warp/tests/test_mat.py +416 -119
- warp/tests/test_quat.py +321 -137
- warp/tests/test_struct.py +116 -0
- warp/tests/test_vec.py +320 -174
- warp/tests/tile/test_tile.py +27 -0
- warp/tests/tile/test_tile_load.py +124 -0
- warp/tests/unittest_suites.py +2 -5
- warp/types.py +107 -9
- {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/METADATA +41 -19
- {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/RECORD +60 -57
- {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/WHEEL +1 -1
- {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/licenses/LICENSE.md +0 -26
- {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/top_level.txt +0 -0
warp/fem/geometry/trimesh.py
CHANGED
@@ -190,7 +190,7 @@ class Trimesh(Geometry):
         return args
 
     def _bvh_id(self, device):
-        if self._tri_bvh is None or self._tri_bvh.device != device:
+        if self._tri_bvh is None or self._tri_bvh.device != wp.get_device(device):
             return _NULL_BVH
         return self._tri_bvh.id
 
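The fix above normalizes the caller-supplied device through wp.get_device(), so strings, aliases, and wp.Device objects all compare consistently against the BVH's canonical device. A minimal sketch of why that matters (wp.get_device() is public Warp API; the rest is illustrative):

    import warp as wp

    wp.init()
    # "cuda" is an alias that resolves to a canonical device such as cuda:0;
    # comparing a raw alias string against a stored wp.Device may not match.
    print(wp.get_device("cuda") == wp.get_device("cuda:0"))  # True on a one-GPU machine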
@@ -519,7 +519,7 @@ class Trimesh(Geometry):
     @wp.kernel
     def _compute_tri_bounds(
         tri_vertex_indices: wp.array2d(dtype=int),
-        positions: wp.array(dtype=wp.vec3),
+        positions: wp.array(dtype=Any),
         lowers: wp.array(dtype=wp.vec3),
         uppers: wp.array(dtype=wp.vec3),
     ):
@@ -528,16 +528,8 @@ class Trimesh(Geometry):
         p1 = _bvh_vec(positions[tri_vertex_indices[t, 1]])
         p2 = _bvh_vec(positions[tri_vertex_indices[t, 2]])
 
-        lowers[t] = wp.vec3(
-            wp.min(wp.min(p0[0], p1[0]), p2[0]),
-            wp.min(wp.min(p0[1], p1[1]), p2[1]),
-            wp.min(wp.min(p0[2], p1[2]), p2[2]),
-        )
-        uppers[t] = wp.vec3(
-            wp.max(wp.max(p0[0], p1[0]), p2[0]),
-            wp.max(wp.max(p0[1], p1[1]), p2[1]),
-            wp.max(wp.max(p0[2], p1[2]), p2[2]),
-        )
+        lowers[t] = wp.min(wp.min(p0, p1), p2)
+        uppers[t] = wp.max(wp.max(p0, p1), p2)
 
 
     @wp.struct
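The simplified kernel relies on wp.min() and wp.max() being componentwise when given vectors, so the per-component expressions collapse into one call each. A standalone sketch (illustrative kernel, not part of the package):

    import warp as wp

    @wp.kernel
    def pairwise_lower(a: wp.array(dtype=wp.vec3),
                       b: wp.array(dtype=wp.vec3),
                       lo: wp.array(dtype=wp.vec3)):
        tid = wp.tid()
        lo[tid] = wp.min(a[tid], b[tid])  # componentwise minimum of two vec3s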
warp/jax_experimental/custom_call.py
CHANGED
@@ -126,7 +126,14 @@ def _create_jax_warp_primitive():
 
     # Create and register the primitive.
     # TODO add default implementation that calls the kernel via warp.
-    _jax_warp_p = jax.core.Primitive("jax_warp")
+    try:
+        # newer JAX versions
+        import jax.extend
+
+        _jax_warp_p = jax.extend.core.Primitive("jax_warp")
+    except (ImportError, AttributeError):
+        # older JAX versions
+        _jax_warp_p = jax.core.Primitive("jax_warp")
     _jax_warp_p.multiple_results = True
 
     # TODO Just launch the kernel directly, but make sure the argument
@@ -262,7 +269,12 @@ def _create_jax_warp_primitive():
     capsule = PyCapsule_New(ccall_address.value, b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0))
 
     # Register the callback in XLA.
-    jax.lib.xla_client.register_custom_call_target("warp_call", capsule, platform="gpu")
+    try:
+        # newer JAX versions
+        jax.ffi.register_ffi_target("warp_call", capsule, platform="gpu", api_version=0)
+    except AttributeError:
+        # older JAX versions
+        jax.lib.xla_client.register_custom_call_target("warp_call", capsule, platform="gpu")
 
     def default_layout(shape):
         return range(len(shape) - 1, -1, -1)
warp/jax_experimental/ffi.py
CHANGED
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import ctypes
+import threading
 import traceback
 from typing import Callable
 
@@ -27,68 +28,6 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp
 from .xla_ffi import *
 
 
-def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None):
-    """Create a JAX callback from a Warp kernel.
-
-    NOTE: This is an experimental feature under development.
-
-    Args:
-        kernel: The Warp kernel to launch.
-        num_outputs: Optional. Specify the number of output arguments if greater than 1.
-        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
-            This argument can also be specified for individual calls.
-        launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
-            dimensions are inferred from the shape of the first array argument.
-            This argument can also be specified for individual calls.
-        output_dims: Optional. Specify the default dimensions of output arrays. If None, output
-            dimensions are inferred from the launch dimensions.
-            This argument can also be specified for individual calls.
-
-    Limitations:
-        - All kernel arguments must be contiguous arrays or scalars.
-        - Scalars must be static arguments in JAX.
-        - Input arguments are followed by output arguments in the Warp kernel definition.
-        - There must be at least one output argument.
-        - Only the CUDA backend is supported.
-    """
-
-    return FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
-
-
-def jax_callable(
-    func: Callable,
-    num_outputs: int = 1,
-    graph_compatible: bool = True,
-    vmap_method: str = "broadcast_all",
-    output_dims=None,
-):
-    """Create a JAX callback from an annotated Python function.
-
-    The Python function arguments must have type annotations like Warp kernels.
-
-    NOTE: This is an experimental feature under development.
-
-    Args:
-        func: The Python function to call.
-        num_outputs: Optional. Specify the number of output arguments if greater than 1.
-        graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
-        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
-            This argument can also be specified for individual calls.
-        output_dims: Optional. Specify the default dimensions of output arrays.
-            If ``None``, output dimensions are inferred from the launch dimensions.
-            This argument can also be specified for individual calls.
-
-    Limitations:
-        - All kernel arguments must be contiguous arrays or scalars.
-        - Scalars must be static arguments in JAX.
-        - Input arguments are followed by output arguments in the Warp kernel definition.
-        - There must be at least one output argument.
-        - Only the CUDA backend is supported.
-    """
-
-    return FfiCallable(func, num_outputs, graph_compatible, vmap_method, output_dims)
-
-
 class FfiArg:
     def __init__(self, name, type):
         self.name = name
@@ -560,7 +499,11 @@ class FfiCallable:
 
             # call the Python function with reconstructed arguments
             with wp.ScopedStream(stream, sync_enter=False):
-                self.func(*arg_list)
+                if stream.is_capturing:
+                    with wp.ScopedCapture(stream=stream, external=True):
+                        self.func(*arg_list)
+                else:
+                    self.func(*arg_list)
 
         except Exception as e:
             print(traceback.format_exc())
@@ -571,6 +514,98 @@ class FfiCallable:
             return None
 
 
+# Holders for the custom callbacks to keep them alive.
+_FFI_CALLABLE_REGISTRY: dict[str, FfiCallable] = {}
+_FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {}
+_FFI_REGISTRY_LOCK = threading.Lock()
+
+
+def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None):
+    """Create a JAX callback from a Warp kernel.
+
+    NOTE: This is an experimental feature under development.
+
+    Args:
+        kernel: The Warp kernel to launch.
+        num_outputs: Optional. Specify the number of output arguments if greater than 1.
+        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
+            This argument can also be specified for individual calls.
+        launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
+            dimensions are inferred from the shape of the first array argument.
+            This argument can also be specified for individual calls.
+        output_dims: Optional. Specify the default dimensions of output arrays. If None, output
+            dimensions are inferred from the launch dimensions.
+            This argument can also be specified for individual calls.
+
+    Limitations:
+        - All kernel arguments must be contiguous arrays or scalars.
+        - Scalars must be static arguments in JAX.
+        - Input arguments are followed by output arguments in the Warp kernel definition.
+        - There must be at least one output argument.
+        - Only the CUDA backend is supported.
+    """
+    key = (
+        kernel.func,
+        num_outputs,
+        vmap_method,
+        tuple(launch_dims) if launch_dims else launch_dims,
+        tuple(sorted(output_dims.items())) if output_dims else output_dims,
+    )
+
+    with _FFI_REGISTRY_LOCK:
+        if key not in _FFI_KERNEL_REGISTRY:
+            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
+            _FFI_KERNEL_REGISTRY[key] = new_kernel
+
+    return _FFI_KERNEL_REGISTRY[key]
+
+
+def jax_callable(
+    func: Callable,
+    num_outputs: int = 1,
+    graph_compatible: bool = True,
+    vmap_method: str = "broadcast_all",
+    output_dims=None,
+):
+    """Create a JAX callback from an annotated Python function.
+
+    The Python function arguments must have type annotations like Warp kernels.
+
+    NOTE: This is an experimental feature under development.
+
+    Args:
+        func: The Python function to call.
+        num_outputs: Optional. Specify the number of output arguments if greater than 1.
+        graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
+        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
+            This argument can also be specified for individual calls.
+        output_dims: Optional. Specify the default dimensions of output arrays.
+            If ``None``, output dimensions are inferred from the launch dimensions.
+            This argument can also be specified for individual calls.
+
+    Limitations:
+        - All kernel arguments must be contiguous arrays or scalars.
+        - Scalars must be static arguments in JAX.
+        - Input arguments are followed by output arguments in the Warp kernel definition.
+        - There must be at least one output argument.
+        - Only the CUDA backend is supported.
+    """
+    key = (
+        func,
+        num_outputs,
+        graph_compatible,
+        vmap_method,
+        tuple(sorted(output_dims.items())) if output_dims else output_dims,
+    )
+
+    with _FFI_REGISTRY_LOCK:
+        if key not in _FFI_CALLABLE_REGISTRY:
+            new_callable = FfiCallable(func, num_outputs, graph_compatible, vmap_method, output_dims)
+            _FFI_CALLABLE_REGISTRY[key] = new_callable
+
+    return _FFI_CALLABLE_REGISTRY[key]
+
+
 ###############################################################################
 #
 # Generic FFI callbacks for Python functions of the form
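The registries above make jax_kernel() and jax_callable() idempotent: wrapping the same kernel or function with the same options returns the cached wrapper instead of registering a fresh FFI target. A sketch of the resulting behavior (scale_kernel is a hypothetical @wp.kernel):

    f1 = jax_kernel(scale_kernel)
    f2 = jax_kernel(scale_kernel)
    assert f1 is f2  # served from _FFI_KERNEL_REGISTRY, no duplicate registration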
@@ -578,9 +613,6 @@ class FfiCallable:
 #
 ###############################################################################
 
-# Holder for the custom callbacks to keep them alive.
-ffi_callbacks = {}
-
 
 def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
     """Create a JAX callback from a Python function.
@@ -640,7 +672,8 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
 
     FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
     callback_func = FFI_CCALLFUNC(ffi_callback)
-    ffi_callbacks[name] = callback_func
+    with _FFI_REGISTRY_LOCK:
+        _FFI_CALLABLE_REGISTRY[name] = callback_func
     ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
     ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
     jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA")
warp/native/builtin.h
CHANGED
@@ -1271,6 +1271,29 @@ inline CUDA_CALLABLE T atomic_add(T* buf, T value)
 #endif
 }
 
+// emulate atomic int64 add with atomicCAS()
+template <>
+inline CUDA_CALLABLE int64 atomic_add(int64* address, int64 val)
+{
+#if defined(__CUDA_ARCH__)
+    unsigned long long int *address_as_ull = (unsigned long long int*)address;
+    unsigned long long int old = *address_as_ull, assumed;
+
+    do
+    {
+        assumed = old;
+        old = atomicCAS(address_as_ull, assumed, (unsigned long long int)((int64)assumed + val));
+    } while (assumed != old);
+
+    return (int64)old;
+
+#else
+    int64 old = *address;
+    *address = old + val;
+    return old;
+#endif
+}
+
 template<>
 inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
 {
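On the Python side this backs wp.atomic_add() on int64 arrays inside kernels; a minimal sketch of the kind of coverage added in warp/tests/test_atomic.py (the kernel here is illustrative):

    import warp as wp

    @wp.kernel
    def sum_kernel(values: wp.array(dtype=wp.int64), total: wp.array(dtype=wp.int64)):
        tid = wp.tid()
        wp.atomic_add(total, 0, values[tid])

    wp.init()
    values = wp.array([1, 2, 3, 4], dtype=wp.int64)
    total = wp.zeros(1, dtype=wp.int64)
    wp.launch(sum_kernel, dim=values.shape[0], inputs=[values, total])
    print(total.numpy())  # [10]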
@@ -1306,53 +1329,6 @@ inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
 #undef __PTR
 
 #endif // CUDA compiled by NVRTC
-
-}
-
-// emulate atomic float max with atomicCAS()
-inline CUDA_CALLABLE float atomic_max(float* address, float val)
-{
-#if defined(__CUDA_ARCH__)
-    int *address_as_int = (int*)address;
-    int old = *address_as_int, assumed;
-
-    while (val > __int_as_float(old))
-    {
-        assumed = old;
-        old = atomicCAS(address_as_int, assumed,
-                        __float_as_int(val));
-    }
-
-    return __int_as_float(old);
-
-#else
-    float old = *address;
-    *address = max(old, val);
-    return old;
-#endif
-}
-
-// emulate atomic float min with atomicCAS()
-inline CUDA_CALLABLE float atomic_min(float* address, float val)
-{
-#if defined(__CUDA_ARCH__)
-    int *address_as_int = (int*)address;
-    int old = *address_as_int, assumed;
-
-    while (val < __int_as_float(old))
-    {
-        assumed = old;
-        old = atomicCAS(address_as_int, assumed,
-                        __float_as_int(val));
-    }
-
-    return __int_as_float(old);
-
-#else
-    float old = *address;
-    *address = min(old, val);
-    return old;
-#endif
 }
 
 template<>
@@ -1388,33 +1364,47 @@ inline CUDA_CALLABLE float64 atomic_add(float64* buf, float64 value)
 #undef __PTR
 
 #endif // CUDA compiled by NVRTC
+}
+
+template <typename T>
+inline CUDA_CALLABLE T atomic_min(T* address, T val)
+{
+#if defined(__CUDA_ARCH__)
+    return atomicMin(address, val);
 
+#else
+    T old = *address;
+    *address = min(old, val);
+    return old;
+#endif
 }
 
-// emulate atomic double max with atomicCAS()
-inline CUDA_CALLABLE double atomic_max(double* address, double val)
+// emulate atomic float min with atomicCAS()
+template <>
+inline CUDA_CALLABLE float atomic_min(float* address, float val)
 {
 #if defined(__CUDA_ARCH__)
-    unsigned long long int *address_as_ull = (unsigned long long int*)address;
-    unsigned long long int old = *address_as_ull, assumed;
-
-    while (val > __longlong_as_double(old))
+    int *address_as_int = (int*)address;
+    int old = *address_as_int, assumed;
+
+    while (val < __int_as_float(old))
     {
         assumed = old;
-        old = atomicCAS(address_as_ull, assumed,
-                        __double_as_longlong(val));
+        old = atomicCAS(address_as_int, assumed,
+                        __float_as_int(val));
     }
 
-    return __longlong_as_double(old);
+    return __int_as_float(old);
 
 #else
-    double old = *address;
-    *address = max(old, val);
+    float old = *address;
+    *address = min(old, val);
     return old;
 #endif
 }
 
 // emulate atomic double min with atomicCAS()
+template <>
 inline CUDA_CALLABLE double atomic_min(double* address, double val)
 {
 #if defined(__CUDA_ARCH__)
@@ -1437,27 +1427,63 @@ inline CUDA_CALLABLE double atomic_min(double* address, double val)
 #endif
 }
 
-template<typename T> inline CUDA_CALLABLE T atomic_max(T* address, T val)
+template <typename T>
+inline CUDA_CALLABLE T atomic_max(T* address, T val)
 {
 #if defined(__CUDA_ARCH__)
     return atomicMax(address, val);
 
 #else
-    T old = *address;
+    T old = *address;
     *address = max(old, val);
     return old;
 #endif
 }
 
-// atomic min
-template<typename T> inline CUDA_CALLABLE T atomic_min(T* address, T val)
+// emulate atomic float max with atomicCAS()
+template<>
+inline CUDA_CALLABLE float atomic_max(float* address, float val)
 {
 #if defined(__CUDA_ARCH__)
-    return atomicMin(address, val);
+    int *address_as_int = (int*)address;
+    int old = *address_as_int, assumed;
+
+    while (val > __int_as_float(old))
+    {
+        assumed = old;
+        old = atomicCAS(address_as_int, assumed,
+                        __float_as_int(val));
+    }
+
+    return __int_as_float(old);
 
 #else
-    T old = *address;
-    *address = min(old, val);
+    float old = *address;
+    *address = max(old, val);
+    return old;
+#endif
+}
+
+// emulate atomic double max with atomicCAS()
+template<>
+inline CUDA_CALLABLE double atomic_max(double* address, double val)
+{
+#if defined(__CUDA_ARCH__)
+    unsigned long long int *address_as_ull = (unsigned long long int*)address;
+    unsigned long long int old = *address_as_ull, assumed;
+
+    while (val > __longlong_as_double(old))
+    {
+        assumed = old;
+        old = atomicCAS(address_as_ull, assumed,
+                        __double_as_longlong(val));
+    }
+
+    return __longlong_as_double(old);
+
+#else
+    double old = *address;
+    *address = max(old, val);
     return old;
 #endif
 }
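These specializations back wp.atomic_min() and wp.atomic_max() on float and double arrays from Python; a small illustrative reduction (not from the package):

    import warp as wp

    @wp.kernel
    def minmax_kernel(values: wp.array(dtype=float),
                      lo: wp.array(dtype=float),
                      hi: wp.array(dtype=float)):
        tid = wp.tid()
        wp.atomic_min(lo, 0, values[tid])
        wp.atomic_max(hi, 0, values[tid])

    wp.init()
    values = wp.array([3.0, -1.0, 7.5, 0.25], dtype=float)
    lo = wp.array([1.0e30], dtype=float)
    hi = wp.array([-1.0e30], dtype=float)
    wp.launch(minmax_kernel, dim=values.shape[0], inputs=[values, lo, hi])
    print(lo.numpy(), hi.numpy())  # [-1.] [7.5]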
warp/native/svd.h
CHANGED
@@ -60,17 +60,17 @@ struct _svd_config<double> {
     static constexpr int JACOBI_ITERATIONS = 8;
 };
 
-
-
-// TODO: replace sqrt with rsqrt
-
-template<typename Type>
-inline CUDA_CALLABLE
-Type accurateSqrt(Type x)
+template <typename Type> inline CUDA_CALLABLE Type recipSqrt(Type x)
 {
-    return x * rsqrt(x);
+#if defined(__CUDA_ARCH__)
+    return ::rsqrt(x);
+#else
+    return Type(1) / sqrt(x);
+#endif
 }
 
+template <> inline CUDA_CALLABLE wp::half recipSqrt(wp::half x) { return wp::half(1) / sqrt(x); }
+
 template<typename Type>
 inline CUDA_CALLABLE
 void condSwap(bool c, Type &X, Type &Y)
@@ -175,7 +175,7 @@ void approximateGivensQuaternion(Type a11, Type a12, Type a22, Type &ch, Type &sh)
     ch = Type(2)*(a11-a22);
     sh = a12;
     bool b = Type(_gamma)*sh*sh < ch*ch;
-    Type w = rsqrt(ch*ch+sh*sh);
+    Type w = recipSqrt(ch*ch+sh*sh);
     ch=b?w*ch:Type(_cstar);
     sh=b?w*sh:Type(_sstar);
 }
@@ -304,13 +304,13 @@ void QRGivensQuaternion(Type a1, Type a2, Type &ch, Type &sh)
     // a1 = pivot point on diagonal
     // a2 = lower triangular entry we want to annihilate
     const Type epsilon = _svd_config<Type>::QR_GIVENS_EPSILON;
-    Type rho = accurateSqrt(a1*a1 + a2*a2);
+    Type rho = sqrt(a1*a1 + a2*a2);
 
     sh = rho > epsilon ? a2 : Type(0);
     ch = abs(a1) + max(rho,epsilon);
     bool b = a1 < Type(0);
     condSwap(b,sh,ch);
-    Type w = rsqrt(ch*ch+sh*sh);
+    Type w = recipSqrt(ch*ch+sh*sh);
     ch *= w;
     sh *= w;
 }
@@ -432,21 +432,15 @@ void _svd(// input A
     );
 }
 
-
-template<typename Type>
-inline CUDA_CALLABLE
-void _svd_2(// input A
-            Type a11, Type a12,
-            Type a21, Type a22,
-            // output U
-            Type &u11, Type &u12,
-            Type &u21, Type &u22,
-            // output S
-            Type &s11, Type &s12,
-            Type &s21, Type &s22,
-            // output V
-            Type &v11, Type &v12,
-            Type &v21, Type &v22)
+template <typename Type>
+inline CUDA_CALLABLE void _svd_2( // input A
+    Type a11, Type a12, Type a21, Type a22,
+    // output U
+    Type& u11, Type& u12, Type& u21, Type& u22,
+    // output S
+    Type& s1, Type& s2,
+    // output V
+    Type& v11, Type& v12, Type& v21, Type& v22)
 {
     // Step 1: Compute ATA
     Type ATA11 = a11 * a11 + a21 * a21;
@@ -455,39 +449,56 @@ void _svd_2(// input A
 
     // Step 2: Eigenanalysis
     Type trace = ATA11 + ATA22;
-    Type det = ATA11 * ATA22 - ATA12 * ATA12;
-    Type sqrt_term = sqrt(trace * trace - Type(4) * det);
-    Type lambda1 = (trace + sqrt_term) * Type(0.5);
-    Type lambda2 = (trace - sqrt_term) * Type(0.5);
+    Type diff = ATA11 - ATA22;
+    Type discriminant = diff * diff + Type(4) * ATA12 * ATA12;
 
     // Step 3: Singular values
-    Type sigma1 = sqrt(lambda1);
+    if (discriminant == Type(0))
+    {
+        // Duplicate eigenvalue, A ~ s Id
+        s1 = s2 = sqrt(Type(0.5) * trace);
+        u11 = v11 = Type(1);
+        u12 = v12 = Type(0);
+        u21 = v21 = Type(0);
+        u22 = v22 = Type(1);
+        return;
+    }
+
+    // General case
+    Type sqrt_term = sqrt(discriminant);
+    Type lambda1 = (trace + sqrt_term) * Type(0.5);
+    Type lambda2 = (trace - sqrt_term) * Type(0.5);
+    Type inv_sigma1 = recipSqrt(lambda1);
+    Type sigma1 = Type(1) / inv_sigma1;
     Type sigma2 = sqrt(lambda2);
 
     // Step 4: Eigenvectors (find V)
-    Type v1x = ATA12;
-    Type v1y = lambda1 - ATA11;
-    Type len1 = sqrt(v1x * v1x + v1y * v1y);
-    v11 = v1x / len1;
-    v21 = v1y / len1;
-    v12 = -v21;
-    v22 = v11;
+    Type v1y = diff - sqrt_term + Type(2) * ATA12, v1x = diff + sqrt_term - Type(2) * ATA12;
+    Type len1_sq = v1x * v1x + v1y * v1y;
+    if (len1_sq == Type(0)) {
+        v11 = Type(0.707106781186547524401); // M_SQRT1_2
+        v21 = v11;
+    } else {
+        Type inv_len1 = recipSqrt(len1_sq);
+        v11 = v1x * inv_len1;
+        v21 = v1y * inv_len1;
+    }
+    v12 = -v21;
+    v22 = v11;
 
     // Step 5: Compute U
-    Type inv_sigma1 = (sigma1 > Type(1e-6)) ? Type(1.0) / sigma1 : Type(0.0);
-    Type inv_sigma2 = (sigma2 > Type(1e-6)) ? Type(1.0) / sigma2 : Type(0.0);
-
     u11 = (a11 * v11 + a12 * v21) * inv_sigma1;
-    u12 = (a11 * v12 + a12 * v22) * inv_sigma2;
     u21 = (a21 * v11 + a22 * v21) * inv_sigma1;
-    u22 = (a21 * v12 + a22 * v22) * inv_sigma2;
+    // sigma2 may be zero, but we can complete U orthogonally up to determinant's sign
+    Type det_sign = wp::sign(a11 * a22 - a12 * a21);
+    u12 = -u21 * det_sign;
+    u22 = u11 * det_sign;
 
     // Step 6: Set S
-    s11 = sigma1; s12 = Type(0);
-    s21 = Type(0); s22 = sigma2;
+    s1 = sigma1;
+    s2 = sigma2;
 }
 
-
 template<typename Type>
 inline CUDA_CALLABLE void svd3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& U, vec_t<3,Type>& sigma, mat_t<3,3,Type>& V) {
     Type s12, s13, s21, s23, s31, s32;
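For reference, the rewritten _svd_2 is the textbook eigendecomposition of A^T A, in the code's own notation:

    \mathrm{tr} = \mathrm{ATA}_{11} + \mathrm{ATA}_{22}, \qquad
    \Delta = (\mathrm{ATA}_{11} - \mathrm{ATA}_{22})^2 + 4\,\mathrm{ATA}_{12}^2,

    \lambda_{1,2} = \tfrac{1}{2}\left(\mathrm{tr} \pm \sqrt{\Delta}\right), \qquad
    \sigma_1 = \sqrt{\lambda_1}, \quad \sigma_2 = \sqrt{\lambda_2}.

When \Delta = 0 the singular values coincide and the early-out branch returns \sigma_1 = \sigma_2 = \sqrt{\mathrm{tr}/2} with U = V = I; otherwise the first column of V is the normalized eigenvector for \lambda_1, and U is completed orthogonally using the sign of det(A) so the factorization stays valid even when \sigma_2 = 0.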
@@ -550,15 +561,14 @@ inline CUDA_CALLABLE void adj_svd3(const mat_t<3,3,Type>& A,
 
 template<typename Type>
 inline CUDA_CALLABLE void svd2(const mat_t<2,2,Type>& A, mat_t<2,2,Type>& U, vec_t<2,Type>& sigma, mat_t<2,2,Type>& V) {
-    Type s12, s21;
     _svd_2(A.data[0][0], A.data[0][1],
            A.data[1][0], A.data[1][1],
 
            U.data[0][0], U.data[0][1],
            U.data[1][0], U.data[1][1],
 
-           sigma[0], s12,
-           s21, sigma[1],
+           sigma[0],
+           sigma[1],
 
            V.data[0][0], V.data[0][1],
            V.data[1][0], V.data[1][1]);