warp-lang 1.7.0-py3-none-manylinux_2_34_aarch64.whl → 1.7.2-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (60)
  1. warp/autograd.py +12 -2
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +1 -1
  5. warp/builtins.py +103 -66
  6. warp/codegen.py +48 -27
  7. warp/config.py +1 -1
  8. warp/context.py +112 -49
  9. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  10. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  11. warp/fem/cache.py +1 -1
  12. warp/fem/field/field.py +11 -1
  13. warp/fem/field/nodal_field.py +36 -22
  14. warp/fem/geometry/adaptive_nanogrid.py +7 -3
  15. warp/fem/geometry/trimesh.py +4 -12
  16. warp/jax_experimental/custom_call.py +14 -2
  17. warp/jax_experimental/ffi.py +100 -67
  18. warp/native/builtin.h +91 -65
  19. warp/native/svd.h +59 -49
  20. warp/native/tile.h +55 -26
  21. warp/native/volume.cpp +2 -2
  22. warp/native/volume_builder.cu +33 -22
  23. warp/native/warp.cu +1 -1
  24. warp/render/render_opengl.py +41 -34
  25. warp/render/render_usd.py +96 -6
  26. warp/sim/collide.py +11 -9
  27. warp/sim/inertia.py +189 -156
  28. warp/sim/integrator_euler.py +3 -0
  29. warp/sim/integrator_xpbd.py +3 -0
  30. warp/sim/model.py +56 -31
  31. warp/sim/render.py +4 -0
  32. warp/sparse.py +1 -1
  33. warp/stubs.py +73 -25
  34. warp/tests/assets/torus.usda +1 -1
  35. warp/tests/cuda/test_streams.py +1 -1
  36. warp/tests/sim/test_collision.py +237 -206
  37. warp/tests/sim/test_inertia.py +161 -0
  38. warp/tests/sim/test_model.py +5 -3
  39. warp/tests/sim/{flaky_test_sim_grad.py → test_sim_grad.py} +1 -4
  40. warp/tests/sim/test_xpbd.py +399 -0
  41. warp/tests/test_array.py +8 -7
  42. warp/tests/test_atomic.py +181 -2
  43. warp/tests/test_builtins_resolution.py +38 -38
  44. warp/tests/test_codegen.py +24 -3
  45. warp/tests/test_examples.py +16 -6
  46. warp/tests/test_fem.py +93 -14
  47. warp/tests/test_func.py +1 -1
  48. warp/tests/test_mat.py +416 -119
  49. warp/tests/test_quat.py +321 -137
  50. warp/tests/test_struct.py +116 -0
  51. warp/tests/test_vec.py +320 -174
  52. warp/tests/tile/test_tile.py +27 -0
  53. warp/tests/tile/test_tile_load.py +124 -0
  54. warp/tests/unittest_suites.py +2 -5
  55. warp/types.py +107 -9
  56. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/METADATA +41 -19
  57. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/RECORD +60 -57
  58. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/WHEEL +1 -1
  59. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/licenses/LICENSE.md +0 -26
  60. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/top_level.txt +0 -0
warp/fem/geometry/trimesh.py CHANGED
@@ -190,7 +190,7 @@ class Trimesh(Geometry):
         return args
 
     def _bvh_id(self, device):
-        if self._tri_bvh is None or self._tri_bvh.device != device:
+        if self._tri_bvh is None or self._tri_bvh.device != wp.get_device(device):
             return _NULL_BVH
         return self._tri_bvh.id
 
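Note: the fix above matters because callers may identify a device by string, ordinal, or wp.context.Device object, and comparing a cached BVH's Device against the raw argument can spuriously fail. A minimal sketch of the normalization, assuming a CPU-only setup (wp.get_device() is the public Warp API used in the fix):

import warp as wp

dev = wp.get_device("cpu")   # string spelling -> canonical Device object
same = wp.get_device(dev)    # passing a Device through is a no-op
assert dev == same

arr = wp.zeros(4, dtype=wp.vec3, device="cpu")
# after normalization, the cache check compares Device to Device
assert arr.device == wp.get_device("cpu")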
@@ -519,7 +519,7 @@ class Trimesh(Geometry):
     @wp.kernel
     def _compute_tri_bounds(
         tri_vertex_indices: wp.array2d(dtype=int),
-        positions: wp.array(dtype=wp.vec2),
+        positions: wp.array(dtype=Any),
         lowers: wp.array(dtype=wp.vec3),
         uppers: wp.array(dtype=wp.vec3),
     ):
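Note: switching positions to dtype=Any makes _compute_tri_bounds a generic kernel, so one body serves both vec2 and vec3 meshes, with the concrete type resolved when the kernel is instantiated at launch. A standalone sketch of the pattern (first_component is a hypothetical kernel, not from the package):

from typing import Any

import warp as wp

@wp.kernel
def first_component(positions: wp.array(dtype=Any), out: wp.array(dtype=float)):
    i = wp.tid()
    out[i] = positions[i][0]

out = wp.zeros(1, dtype=float)
pts2 = wp.array([[1.0, 2.0]], dtype=wp.vec2)
pts3 = wp.array([[3.0, 4.0, 5.0]], dtype=wp.vec3)
wp.launch(first_component, dim=1, inputs=[pts2, out])  # instantiated for vec2
wp.launch(first_component, dim=1, inputs=[pts3, out])  # re-instantiated for vec3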
@@ -528,16 +528,8 @@ class Trimesh(Geometry):
         p1 = _bvh_vec(positions[tri_vertex_indices[t, 1]])
         p2 = _bvh_vec(positions[tri_vertex_indices[t, 2]])
 
-        lowers[t] = wp.vec3(
-            wp.min(wp.min(p0[0], p1[0]), p2[0]),
-            wp.min(wp.min(p0[1], p1[1]), p2[1]),
-            wp.min(wp.min(p0[2], p1[2]), p2[2]),
-        )
-        uppers[t] = wp.vec3(
-            wp.max(wp.max(p0[0], p1[0]), p2[0]),
-            wp.max(wp.max(p0[1], p1[1]), p2[1]),
-            wp.max(wp.max(p0[2], p1[2]), p2[2]),
-        )
+        lowers[t] = wp.min(wp.min(p0, p1), p2)
+        uppers[t] = wp.max(wp.max(p0, p1), p2)
 
 
     @wp.struct
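Note: the replacement relies on wp.min()/wp.max() accepting vectors and operating component-wise, collapsing nine scalar calls into two. A quick check (minmax_demo is a hypothetical kernel):

import warp as wp

@wp.kernel
def minmax_demo(lo: wp.array(dtype=wp.vec3), hi: wp.array(dtype=wp.vec3)):
    p0 = wp.vec3(0.0, 2.0, -1.0)
    p1 = wp.vec3(1.0, -3.0, 5.0)
    p2 = wp.vec3(-2.0, 4.0, 0.0)
    lo[0] = wp.min(wp.min(p0, p1), p2)  # component-wise: (-2, -3, -1)
    hi[0] = wp.max(wp.max(p0, p1), p2)  # component-wise: (1, 4, 5)

lo = wp.zeros(1, dtype=wp.vec3)
hi = wp.zeros(1, dtype=wp.vec3)
wp.launch(minmax_demo, dim=1, inputs=[lo, hi])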
warp/jax_experimental/custom_call.py CHANGED
@@ -126,7 +126,14 @@ def _create_jax_warp_primitive():
 
     # Create and register the primitive.
    # TODO add default implementation that calls the kernel via warp.
-    _jax_warp_p = jax.core.Primitive("jax_warp")
+    try:
+        # newer JAX versions
+        import jax.extend
+
+        _jax_warp_p = jax.extend.core.Primitive("jax_warp")
+    except (ImportError, AttributeError):
+        # older JAX versions
+        _jax_warp_p = jax.core.Primitive("jax_warp")
     _jax_warp_p.multiple_results = True
 
     # TODO Just launch the kernel directly, but make sure the argument
@@ -262,7 +269,12 @@ def _create_jax_warp_primitive():
     capsule = PyCapsule_New(ccall_address.value, b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0))
 
     # Register the callback in XLA.
-    jax.lib.xla_client.register_custom_call_target("warp_call", capsule, platform="gpu")
+    try:
+        # newer JAX versions
+        jax.ffi.register_ffi_target("warp_call", capsule, platform="gpu", api_version=0)
+    except AttributeError:
+        # older JAX versions
+        jax.lib.xla_client.register_custom_call_target("warp_call", capsule, platform="gpu")
 
     def default_layout(shape):
         return range(len(shape) - 1, -1, -1)
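Note: both hunks follow the same compatibility idiom, probing the newer public JAX surface first and falling back to the legacy path on older releases. A distilled sketch, assuming only that JAX is installed:

import jax

try:
    # newer JAX exposes Primitive via jax.extend
    from jax.extend.core import Primitive
except (ImportError, AttributeError):
    # older releases keep it in jax.core
    from jax.core import Primitive

p = Primitive("demo")
p.multiple_results = True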
warp/jax_experimental/ffi.py CHANGED
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import ctypes
+import threading
 import traceback
 from typing import Callable
 
@@ -27,68 +28,6 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp
 from .xla_ffi import *
 
 
-def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None):
-    """Create a JAX callback from a Warp kernel.
-
-    NOTE: This is an experimental feature under development.
-
-    Args:
-        kernel: The Warp kernel to launch.
-        num_outputs: Optional. Specify the number of output arguments if greater than 1.
-        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
-            This argument can also be specified for individual calls.
-        launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
-            dimensions are inferred from the shape of the first array argument.
-            This argument can also be specified for individual calls.
-        output_dims: Optional. Specify the default dimensions of output arrays. If None, output
-            dimensions are inferred from the launch dimensions.
-            This argument can also be specified for individual calls.
-
-    Limitations:
-        - All kernel arguments must be contiguous arrays or scalars.
-        - Scalars must be static arguments in JAX.
-        - Input arguments are followed by output arguments in the Warp kernel definition.
-        - There must be at least one output argument.
-        - Only the CUDA backend is supported.
-    """
-
-    return FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
-
-
-def jax_callable(
-    func: Callable,
-    num_outputs: int = 1,
-    graph_compatible: bool = True,
-    vmap_method: str = "broadcast_all",
-    output_dims=None,
-):
-    """Create a JAX callback from an annotated Python function.
-
-    The Python function arguments must have type annotations like Warp kernels.
-
-    NOTE: This is an experimental feature under development.
-
-    Args:
-        func: The Python function to call.
-        num_outputs: Optional. Specify the number of output arguments if greater than 1.
-        graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
-        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
-            This argument can also be specified for individual calls.
-        output_dims: Optional. Specify the default dimensions of output arrays.
-            If ``None``, output dimensions are inferred from the launch dimensions.
-            This argument can also be specified for individual calls.
-
-    Limitations:
-        - All kernel arguments must be contiguous arrays or scalars.
-        - Scalars must be static arguments in JAX.
-        - Input arguments are followed by output arguments in the Warp kernel definition.
-        - There must be at least one output argument.
-        - Only the CUDA backend is supported.
-    """
-
-    return FfiCallable(func, num_outputs, graph_compatible, vmap_method, output_dims)
-
-
 class FfiArg:
     def __init__(self, name, type):
         self.name = name
@@ -560,7 +499,11 @@ class FfiCallable:
 
             # call the Python function with reconstructed arguments
             with wp.ScopedStream(stream, sync_enter=False):
-                self.func(*arg_list)
+                if stream.is_capturing:
+                    with wp.ScopedCapture(stream=stream, external=True):
+                        self.func(*arg_list)
+                else:
+                    self.func(*arg_list)
 
         except Exception as e:
             print(traceback.format_exc())
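Note: the new branch lets the callback participate in a CUDA graph capture started by XLA. wp.ScopedCapture(external=True) tells Warp the capture is owned externally, so Warp records into it rather than beginning and ending its own graph. A guarded sketch, assuming a CUDA device is present (work() is a placeholder):

import warp as wp

def work():
    pass  # launches here are recorded into whichever stream is current

wp.init()
if wp.get_cuda_device_count() > 0:
    stream = wp.Stream("cuda:0")
    with wp.ScopedStream(stream, sync_enter=False):
        if stream.is_capturing:
            # somebody else (e.g. XLA) started the capture; join it
            with wp.ScopedCapture(stream=stream, external=True):
                work()
        else:
            work()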
@@ -571,6 +514,98 @@ class FfiCallable:
         return None
 
 
+# Holders for the custom callbacks to keep them alive.
+_FFI_CALLABLE_REGISTRY: dict[str, FfiCallable] = {}
+_FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {}
+_FFI_REGISTRY_LOCK = threading.Lock()
+
+
+def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None):
+    """Create a JAX callback from a Warp kernel.
+
+    NOTE: This is an experimental feature under development.
+
+    Args:
+        kernel: The Warp kernel to launch.
+        num_outputs: Optional. Specify the number of output arguments if greater than 1.
+        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
+            This argument can also be specified for individual calls.
+        launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
+            dimensions are inferred from the shape of the first array argument.
+            This argument can also be specified for individual calls.
+        output_dims: Optional. Specify the default dimensions of output arrays. If None, output
+            dimensions are inferred from the launch dimensions.
+            This argument can also be specified for individual calls.
+
+    Limitations:
+        - All kernel arguments must be contiguous arrays or scalars.
+        - Scalars must be static arguments in JAX.
+        - Input arguments are followed by output arguments in the Warp kernel definition.
+        - There must be at least one output argument.
+        - Only the CUDA backend is supported.
+    """
+    key = (
+        kernel.func,
+        num_outputs,
+        vmap_method,
+        tuple(launch_dims) if launch_dims else launch_dims,
+        tuple(sorted(output_dims.items())) if output_dims else output_dims,
+    )
+
+    with _FFI_REGISTRY_LOCK:
+        if key not in _FFI_KERNEL_REGISTRY:
+            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
+            _FFI_KERNEL_REGISTRY[key] = new_kernel
+
+    return _FFI_KERNEL_REGISTRY[key]
+
+
+def jax_callable(
+    func: Callable,
+    num_outputs: int = 1,
+    graph_compatible: bool = True,
+    vmap_method: str = "broadcast_all",
+    output_dims=None,
+):
+    """Create a JAX callback from an annotated Python function.
+
+    The Python function arguments must have type annotations like Warp kernels.
+
+    NOTE: This is an experimental feature under development.
+
+    Args:
+        func: The Python function to call.
+        num_outputs: Optional. Specify the number of output arguments if greater than 1.
+        graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
+        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
+            This argument can also be specified for individual calls.
+        output_dims: Optional. Specify the default dimensions of output arrays.
+            If ``None``, output dimensions are inferred from the launch dimensions.
+            This argument can also be specified for individual calls.
+
+    Limitations:
+        - All kernel arguments must be contiguous arrays or scalars.
+        - Scalars must be static arguments in JAX.
+        - Input arguments are followed by output arguments in the Warp kernel definition.
+        - There must be at least one output argument.
+        - Only the CUDA backend is supported.
+    """
+    key = (
+        func,
+        num_outputs,
+        graph_compatible,
+        vmap_method,
+        tuple(sorted(output_dims.items())) if output_dims else output_dims,
+    )
+
+    with _FFI_REGISTRY_LOCK:
+        if key not in _FFI_CALLABLE_REGISTRY:
+            new_callable = FfiCallable(func, num_outputs, graph_compatible, vmap_method, output_dims)
+            _FFI_CALLABLE_REGISTRY[key] = new_callable
+
+    return _FFI_CALLABLE_REGISTRY[key]
+
+
 ###############################################################################
 #
 # Generic FFI callbacks for Python functions of the form
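Note: one observable effect of the new registries is that wrapping the same kernel twice with identical options now returns the same cached object instead of creating and registering a fresh FFI target each time. A hedged usage sketch, assuming a CUDA-enabled JAX + Warp install (scale is a hypothetical kernel):

import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = 2.0 * x[i]

f1 = jax_kernel(scale)
f2 = jax_kernel(scale)
assert f1 is f2  # identical key -> cached FfiKernel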
@@ -578,9 +613,6 @@ class FfiCallable:
 #
 ###############################################################################
 
-# Holder for the custom callbacks to keep them alive.
-ffi_callbacks = {}
-
 
 def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
     """Create a JAX callback from a Python function.
@@ -640,7 +672,8 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
 
     FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
     callback_func = FFI_CCALLFUNC(ffi_callback)
-    ffi_callbacks[name] = callback_func
+    with _FFI_REGISTRY_LOCK:
+        _FFI_CALLABLE_REGISTRY[name] = callback_func
     ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
     ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
     jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA")
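Note: the lock-guarded registry used in these hunks is a standard create-once idiom. A generic reduction of it, using only the standard library (names here are illustrative, not the package's):

import threading

_registry = {}
_lock = threading.Lock()

def get_or_create(key, factory):
    # holding the lock across the check-and-insert prevents two threads
    # from both entering the factory for the same key
    with _lock:
        if key not in _registry:
            _registry[key] = factory()
        return _registry[key]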
warp/native/builtin.h CHANGED
@@ -1271,6 +1271,29 @@ inline CUDA_CALLABLE T atomic_add(T* buf, T value)
 #endif
 }
 
+// emulate atomic int64 min with atomicCAS()
+template <>
+inline CUDA_CALLABLE int64 atomic_min(int64* address, int64 val)
+{
+#if defined(__CUDA_ARCH__)
+    unsigned long long int *address_as_ull = (unsigned long long int*)address;
+    unsigned long long int old = *address_as_ull, assumed;
+
+    while (val < (int64)old)
+    {
+        assumed = old;
+        old = atomicCAS(address_as_ull, assumed, (unsigned long long int)val);
+    }
+
+    return (int64)old;
+
+#else
+    int64 old = *address;
+    *address = min(old, val);
+    return old;
+#endif
+}
+
 template<>
 inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
 {
@@ -1306,53 +1329,6 @@ inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
 #undef __PTR
 
 #endif // CUDA compiled by NVRTC
-
-}
-
-// emulate atomic float max with atomicCAS()
-inline CUDA_CALLABLE float atomic_max(float* address, float val)
-{
-#if defined(__CUDA_ARCH__)
-    int *address_as_int = (int*)address;
-    int old = *address_as_int, assumed;
-
-    while (val > __int_as_float(old))
-    {
-        assumed = old;
-        old = atomicCAS(address_as_int, assumed,
-            __float_as_int(val));
-    }
-
-    return __int_as_float(old);
-
-#else
-    float old = *address;
-    *address = max(old, val);
-    return old;
-#endif
-}
-
-// emulate atomic float min with atomicCAS()
-inline CUDA_CALLABLE float atomic_min(float* address, float val)
-{
-#if defined(__CUDA_ARCH__)
-    int *address_as_int = (int*)address;
-    int old = *address_as_int, assumed;
-
-    while (val < __int_as_float(old))
-    {
-        assumed = old;
-        old = atomicCAS(address_as_int, assumed,
-            __float_as_int(val));
-    }
-
-    return __int_as_float(old);
-
-#else
-    float old = *address;
-    *address = min(old, val);
-    return old;
-#endif
 }
 
 template<>
@@ -1388,33 +1364,47 @@ inline CUDA_CALLABLE float64 atomic_add(float64* buf, float64 value)
 #undef __PTR
 
 #endif // CUDA compiled by NVRTC
+}
+
+template <typename T>
+inline CUDA_CALLABLE T atomic_min(T* address, T val)
+{
+#if defined(__CUDA_ARCH__)
+    return atomicMin(address, val);
 
+#else
+    T old = *address;
+    *address = min(old, val);
+    return old;
+#endif
 }
 
-// emulate atomic double max with atomicCAS()
-inline CUDA_CALLABLE double atomic_max(double* address, double val)
+// emulate atomic float min with atomicCAS()
+template <>
+inline CUDA_CALLABLE float atomic_min(float* address, float val)
 {
 #if defined(__CUDA_ARCH__)
-    unsigned long long int *address_as_ull = (unsigned long long int*)address;
-    unsigned long long int old = *address_as_ull, assumed;
-
-    while (val > __longlong_as_double(old))
+    int *address_as_int = (int*)address;
+    int old = *address_as_int, assumed;
+
+    while (val < __int_as_float(old))
     {
         assumed = old;
-        old = atomicCAS(address_as_ull, assumed,
-            __double_as_longlong(val));
+        old = atomicCAS(address_as_int, assumed,
+            __float_as_int(val));
     }
 
-    return __longlong_as_double(old);
+    return __int_as_float(old);
 
 #else
-    double old = *address;
-    *address = max(old, val);
+    float old = *address;
+    *address = min(old, val);
     return old;
 #endif
 }
 
 // emulate atomic double min with atomicCAS()
+template <>
 inline CUDA_CALLABLE double atomic_min(double* address, double val)
 {
 #if defined(__CUDA_ARCH__)
@@ -1437,27 +1427,63 @@ inline CUDA_CALLABLE double atomic_min(double* address, double val)
 #endif
 }
 
-inline CUDA_CALLABLE int atomic_max(int* address, int val)
+template <typename T>
+inline CUDA_CALLABLE T atomic_max(T* address, T val)
 {
 #if defined(__CUDA_ARCH__)
     return atomicMax(address, val);
 
 #else
-    int old = *address;
+    T old = *address;
     *address = max(old, val);
     return old;
 #endif
 }
 
-// atomic int min
-inline CUDA_CALLABLE int atomic_min(int* address, int val)
+// emulate atomic float max with atomicCAS()
+template<>
+inline CUDA_CALLABLE float atomic_max(float* address, float val)
 {
 #if defined(__CUDA_ARCH__)
-    return atomicMin(address, val);
+    int *address_as_int = (int*)address;
+    int old = *address_as_int, assumed;
+
+    while (val > __int_as_float(old))
+    {
+        assumed = old;
+        old = atomicCAS(address_as_int, assumed,
+            __float_as_int(val));
+    }
+
+    return __int_as_float(old);
 
 #else
-    int old = *address;
-    *address = min(old, val);
+    float old = *address;
+    *address = max(old, val);
+    return old;
+#endif
+}
+
+// emulate atomic double max with atomicCAS()
+template<>
+inline CUDA_CALLABLE double atomic_max(double* address, double val)
+{
+#if defined(__CUDA_ARCH__)
+    unsigned long long int *address_as_ull = (unsigned long long int*)address;
+    unsigned long long int old = *address_as_ull, assumed;
+
+    while (val > __longlong_as_double(old))
+    {
+        assumed = old;
+        old = atomicCAS(address_as_ull, assumed,
+            __double_as_longlong(val));
+    }
+
+    return __longlong_as_double(old);
+
+#else
+    double old = *address;
+    *address = max(old, val);
     return old;
 #endif
 }
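Note: the consolidation above replaces ad-hoc per-type overloads with generic atomic_min()/atomic_max() templates plus atomicCAS-emulated specializations for float, double, and int64, while integer types hit the native atomicMin/atomicMax on CUDA. From Python these are reached through wp.atomic_min()/wp.atomic_max(); a small reduction sketch (minmax_reduce is a hypothetical kernel):

import warp as wp

@wp.kernel
def minmax_reduce(x: wp.array(dtype=float), lo: wp.array(dtype=float), hi: wp.array(dtype=float)):
    i = wp.tid()
    wp.atomic_min(lo, 0, x[i])  # float path -> atomicCAS emulation on CUDA
    wp.atomic_max(hi, 0, x[i])

x = wp.array([3.0, -1.5, 7.25, 0.5], dtype=float)
lo = wp.full(1, 1.0e30, dtype=float)
hi = wp.full(1, -1.0e30, dtype=float)
wp.launch(minmax_reduce, dim=x.shape[0], inputs=[x, lo, hi])
print(lo.numpy()[0], hi.numpy()[0])  # -1.5 7.25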
warp/native/svd.h CHANGED
@@ -60,17 +60,17 @@ struct _svd_config<double> {
     static constexpr int JACOBI_ITERATIONS = 8;
 };
 
-
-
-// TODO: replace sqrt with rsqrt
-
-template<typename Type>
-inline CUDA_CALLABLE
-Type accurateSqrt(Type x)
+template <typename Type> inline CUDA_CALLABLE Type recipSqrt(Type x)
 {
-    return x / sqrt(x);
+#if defined(__CUDA_ARCH__)
+    return ::rsqrt(x);
+#else
+    return Type(1) / sqrt(x);
+#endif
 }
 
+template <> inline CUDA_CALLABLE wp::half recipSqrt(wp::half x) { return wp::half(1) / sqrt(x); }
+
 template<typename Type>
 inline CUDA_CALLABLE
 void condSwap(bool c, Type &X, Type &Y)
@@ -175,7 +175,7 @@ void approximateGivensQuaternion(Type a11, Type a12, Type a22, Type &ch, Type &sh)
     ch = Type(2)*(a11-a22);
     sh = a12;
     bool b = Type(_gamma)*sh*sh < ch*ch;
-    Type w = Type(1) / sqrt(ch*ch+sh*sh);
+    Type w = recipSqrt(ch*ch+sh*sh);
     ch=b?w*ch:Type(_cstar);
     sh=b?w*sh:Type(_sstar);
 }
@@ -304,13 +304,13 @@ void QRGivensQuaternion(Type a1, Type a2, Type &ch, Type &sh)
     // a1 = pivot point on diagonal
     // a2 = lower triangular entry we want to annihilate
     const Type epsilon = _svd_config<Type>::QR_GIVENS_EPSILON;
-    Type rho = accurateSqrt(a1*a1 + a2*a2);
+    Type rho = sqrt(a1*a1 + a2*a2);
 
     sh = rho > epsilon ? a2 : Type(0);
     ch = abs(a1) + max(rho,epsilon);
     bool b = a1 < Type(0);
     condSwap(b,sh,ch);
-    Type w = Type(1) / sqrt(ch*ch+sh*sh);
+    Type w = recipSqrt(ch*ch+sh*sh);
     ch *= w;
     sh *= w;
 }
@@ -432,21 +432,15 @@ void _svd(// input A
     );
 }
 
-
-template<typename Type>
-inline CUDA_CALLABLE
-void _svd_2(// input A
-            Type a11, Type a12,
-            Type a21, Type a22,
-            // output U
-            Type &u11, Type &u12,
-            Type &u21, Type &u22,
-            // output S
-            Type &s11, Type &s12,
-            Type &s21, Type &s22,
-            // output V
-            Type &v11, Type &v12,
-            Type &v21, Type &v22)
+template <typename Type>
+inline CUDA_CALLABLE void _svd_2(  // input A
+    Type a11, Type a12, Type a21, Type a22,
+    // output U
+    Type& u11, Type& u12, Type& u21, Type& u22,
+    // output S
+    Type& s1, Type& s2,
+    // output V
+    Type& v11, Type& v12, Type& v21, Type& v22)
 {
     // Step 1: Compute ATA
     Type ATA11 = a11 * a11 + a21 * a21;
  // Step 1: Compute ATA
452
446
  Type ATA11 = a11 * a11 + a21 * a21;
@@ -455,39 +449,56 @@ void _svd_2(// input A
455
449
 
456
450
  // Step 2: Eigenanalysis
457
451
  Type trace = ATA11 + ATA22;
458
- Type det = ATA11 * ATA22 - ATA12 * ATA12;
459
- Type sqrt_term = sqrt(trace * trace - Type(4.0) * det);
460
- Type lambda1 = (trace + sqrt_term) * Type(0.5);
461
- Type lambda2 = (trace - sqrt_term) * Type(0.5);
452
+ Type diff = ATA11 - ATA22;
453
+ Type discriminant = diff * diff + Type(4) * ATA12 * ATA12;
462
454
 
463
455
  // Step 3: Singular values
464
- Type sigma1 = sqrt(lambda1);
456
+ if (discriminant == Type(0))
457
+ {
458
+ // Duplicate eigenvalue, A ~ s Id
459
+ s1 = s2 = sqrt(Type(0.5) * trace);
460
+ u11 = v11 = Type(1);
461
+ u12 = v12 = Type(0);
462
+ u21 = v21 = Type(0);
463
+ u22 = v22 = Type(1);
464
+ return;
465
+ }
466
+
467
+ // General case
468
+ Type sqrt_term = sqrt(discriminant);
469
+ Type lambda1 = (trace + sqrt_term) * Type(0.5);
470
+ Type lambda2 = (trace - sqrt_term) * Type(0.5);
471
+ Type inv_sigma1 = recipSqrt(lambda1);
472
+ Type sigma1 = Type(1) / inv_sigma1;
465
473
  Type sigma2 = sqrt(lambda2);
466
474
 
467
475
  // Step 4: Eigenvectors (find V)
468
- Type v1x = ATA12, v1y = lambda1 - ATA11; // For first eigenvector
469
- Type v2x = ATA12, v2y = lambda2 - ATA11; // For second eigenvector
470
- Type norm1 = sqrt(v1x * v1x + v1y * v1y);
471
- Type norm2 = sqrt(v2x * v2x + v2y * v2y);
472
-
473
- v11 = v1x / norm1; v12 = v2x / norm2;
474
- v21 = v1y / norm1; v22 = v2y / norm2;
476
+ Type v1y = diff - sqrt_term + Type(2) * ATA12, v1x = diff + sqrt_term - Type(2) * ATA12;
477
+ Type len1_sq = v1x * v1x + v1y * v1y;
478
+ if (len1_sq == Type(0)) {
479
+ v11 = Type(0.707106781186547524401); // M_SQRT1_2
480
+ v21 = v11;
481
+ } else {
482
+ Type inv_len1 = recipSqrt(len1_sq);
483
+ v11 = v1x * inv_len1;
484
+ v21 = v1y * inv_len1;
485
+ }
486
+ v12 = -v21;
487
+ v22 = v11;
475
488
 
476
489
  // Step 5: Compute U
477
- Type inv_sigma1 = (sigma1 > Type(1e-6)) ? Type(1.0) / sigma1 : Type(0.0);
478
- Type inv_sigma2 = (sigma2 > Type(1e-6)) ? Type(1.0) / sigma2 : Type(0.0);
479
-
480
490
  u11 = (a11 * v11 + a12 * v21) * inv_sigma1;
481
- u12 = (a11 * v12 + a12 * v22) * inv_sigma2;
482
491
  u21 = (a21 * v11 + a22 * v21) * inv_sigma1;
483
- u22 = (a21 * v12 + a22 * v22) * inv_sigma2;
492
+ // sigma2 may be zero, but we can complete U orthogonally up to determinant's sign
493
+ Type det_sign = wp::sign(a11 * a22 - a12 * a21);
494
+ u12 = -u21 * det_sign;
495
+ u22 = u11 * det_sign;
484
496
 
485
497
  // Step 6: Set S
486
- s11 = sigma1; s12 = Type(0.0);
487
- s21 = Type(0.0); s22 = sigma2;
498
+ s1 = sigma1;
499
+ s2 = sigma2;
488
500
  }
489
501
 
490
-
491
502
  template<typename Type>
492
503
  inline CUDA_CALLABLE void svd3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& U, vec_t<3,Type>& sigma, mat_t<3,3,Type>& V) {
493
504
  Type s12, s13, s21, s23, s31, s32;
@@ -550,15 +561,14 @@ inline CUDA_CALLABLE void adj_svd3(const mat_t<3,3,Type>& A,
 
 template<typename Type>
 inline CUDA_CALLABLE void svd2(const mat_t<2,2,Type>& A, mat_t<2,2,Type>& U, vec_t<2,Type>& sigma, mat_t<2,2,Type>& V) {
-    Type s12, s21;
     _svd_2(A.data[0][0], A.data[0][1],
            A.data[1][0], A.data[1][1],
 
            U.data[0][0], U.data[0][1],
            U.data[1][0], U.data[1][1],
 
-           sigma[0], s12,
-           s21, sigma[1],
+           sigma[0],
+           sigma[1],
 
            V.data[0][0], V.data[0][1],
            V.data[1][0], V.data[1][1]);
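Note: the reworked _svd_2 computes the 2x2 SVD in closed form. For ATA = A^T A with trace t and diff d = ATA11 - ATA22, the eigenvalues are (t ± sqrt(d^2 + 4*ATA12^2)) / 2 and the singular values are their square roots; V is built from one eigenvector plus its perpendicular, and U is completed orthogonally using the sign of det(A), which is why only the two diagonal S entries remain as outputs. A numpy cross-check of the singular-value formula, independent of Warp:

import numpy as np

A = np.array([[2.0, 1.0], [0.5, 3.0]])
ATA = A.T @ A
t = ATA[0, 0] + ATA[1, 1]
d = ATA[0, 0] - ATA[1, 1]
disc = d * d + 4.0 * ATA[0, 1] ** 2
s1 = np.sqrt(0.5 * (t + np.sqrt(disc)))
s2 = np.sqrt(0.5 * (t - np.sqrt(disc)))
print(s1, s2)                              # closed form
print(np.linalg.svd(A, compute_uv=False))  # matches, descending order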