warp-lang 1.7.2-py3-none-manylinux_2_34_aarch64.whl → 1.8.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (180)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/math.py CHANGED
@@ -22,11 +22,13 @@ Vector norm functions
 """
 
 __all__ = [
+    "norm_huber",
     "norm_l1",
     "norm_l2",
-    "norm_huber",
    "norm_pseudo_huber",
    "smooth_normalize",
+    "transform_compose",
+    "transform_decompose",
    "transform_from_matrix",
    "transform_to_matrix",
 ]
@@ -142,6 +144,19 @@ def create_transform_from_matrix_func(dtype):
         """
         Construct a transformation from a 4x4 matrix.
 
+        .. math::
+            M = \\begin{bmatrix}
+                    R_{00} & R_{01} & R_{02} & p_x \\\\
+                    R_{10} & R_{11} & R_{12} & p_y \\\\
+                    R_{20} & R_{21} & R_{22} & p_z \\\\
+                    0 & 0 & 0 & 1
+                \\end{bmatrix}
+
+        Where:
+
+        * :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
+        * :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
+
         Args:
             mat (Matrix[4, 4, Float]): Matrix to convert.
 
@@ -177,6 +192,19 @@ def create_transform_to_matrix_func(dtype):
         """
         Convert a transformation to a 4x4 matrix.
 
+        .. math::
+            M = \\begin{bmatrix}
+                    R_{00} & R_{01} & R_{02} & p_x \\\\
+                    R_{10} & R_{11} & R_{12} & p_y \\\\
+                    R_{20} & R_{21} & R_{22} & p_z \\\\
+                    0 & 0 & 0 & 1
+                \\end{bmatrix}
+
+        Where:
+
+        * :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
+        * :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
+
         Args:
             xform (Transformation[Float]): Transformation to convert.
 
@@ -212,6 +240,140 @@ wp.func(
 )
 
 
+def create_transform_compose_func(dtype):
+    mat44 = wp.types.matrix((4, 4), dtype)
+    quat = wp.types.quaternion(dtype)
+    vec3 = wp.types.vector(3, dtype)
+
+    def transform_compose(position: vec3, rotation: quat, scale: vec3):
+        """
+        Compose a 4x4 transformation matrix from a 3D position, quaternion orientation, and 3D scale.
+
+        .. math::
+            M = \\begin{bmatrix}
+                    s_x R_{00} & s_y R_{01} & s_z R_{02} & p_x \\\\
+                    s_x R_{10} & s_y R_{11} & s_z R_{12} & p_y \\\\
+                    s_x R_{20} & s_y R_{21} & s_z R_{22} & p_z \\\\
+                    0 & 0 & 0 & 1
+                \\end{bmatrix}
+
+        Where:
+
+        * :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
+        * :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
+        * :math:`s` is the 3D scale vector :math:`[s_x, s_y, s_z]` of the input transform.
+
+        Args:
+            position (Vector[3, Float]): The 3D position vector.
+            rotation (Quaternion[Float]): The quaternion orientation.
+            scale (Vector[3, Float]): The 3D scale vector.
+
+        Returns:
+            Matrix[4, 4, Float]: The transformation matrix.
+        """
+        R = wp.quat_to_matrix(rotation)
+        # fmt: off
+        return mat44(
+            scale[0] * R[0,0], scale[1] * R[0,1], scale[2] * R[0,2], position[0],
+            scale[0] * R[1,0], scale[1] * R[1,1], scale[2] * R[1,2], position[1],
+            scale[0] * R[2,0], scale[1] * R[2,1], scale[2] * R[2,2], position[2],
+            dtype(0.0), dtype(0.0), dtype(0.0), dtype(1.0),
+        )
+        # fmt: on
+
+    return transform_compose
+
+
+transform_compose = wp.func(
+    create_transform_compose_func(wp.float32),
+    name="transform_compose",
+)
+wp.func(
+    create_transform_compose_func(wp.float16),
+    name="transform_compose",
+)
+wp.func(
+    create_transform_compose_func(wp.float64),
+    name="transform_compose",
+)
+
+
+def create_transform_decompose_func(dtype):
+    mat44 = wp.types.matrix((4, 4), dtype)
+    vec3 = wp.types.vector(3, dtype)
+    mat33 = wp.types.matrix((3, 3), dtype)
+    zero = dtype(0.0)
+
+    def transform_decompose(m: mat44):
+        """
+        Decompose a 4x4 transformation matrix into 3D position, quaternion orientation, and 3D scale.
+
+        .. math::
+            M = \\begin{bmatrix}
+                    s_x R_{00} & s_y R_{01} & s_z R_{02} & p_x \\\\
+                    s_x R_{10} & s_y R_{11} & s_z R_{12} & p_y \\\\
+                    s_x R_{20} & s_y R_{21} & s_z R_{22} & p_z \\\\
+                    0 & 0 & 0 & 1
+                \\end{bmatrix}
+
+        Where:
+
+        * :math:`R` is the 3x3 rotation matrix created from the orientation quaternion of the input transform.
+        * :math:`p` is the 3D position vector :math:`[p_x, p_y, p_z]` of the input transform.
+        * :math:`s` is the 3D scale vector :math:`[s_x, s_y, s_z]` of the input transform.
+
+        Args:
+            m (Matrix[4, 4, Float]): The matrix to decompose.
+
+        Returns:
+            Tuple[Vector[3, Float], Quaternion[Float], Vector[3, Float]]: A tuple containing the position vector, quaternion orientation, and scale vector.
+        """
+        # extract position
+        position = vec3(m[0, 3], m[1, 3], m[2, 3])
+        # extract rotation matrix components
+        r00, r01, r02 = m[0, 0], m[0, 1], m[0, 2]
+        r10, r11, r12 = m[1, 0], m[1, 1], m[1, 2]
+        r20, r21, r22 = m[2, 0], m[2, 1], m[2, 2]
+        # get scale magnitudes
+        sx = wp.sqrt(r00 * r00 + r10 * r10 + r20 * r20)
+        sy = wp.sqrt(r01 * r01 + r11 * r11 + r21 * r21)
+        sz = wp.sqrt(r02 * r02 + r12 * r12 + r22 * r22)
+        # normalize rotation matrix components
+        if sx != zero:
+            r00 /= sx
+            r10 /= sx
+            r20 /= sx
+        if sy != zero:
+            r01 /= sy
+            r11 /= sy
+            r21 /= sy
+        if sz != zero:
+            r02 /= sz
+            r12 /= sz
+            r22 /= sz
+        # extract rotation (quaternion)
+        rotation = wp.quat_from_matrix(mat33(r00, r01, r02, r10, r11, r12, r20, r21, r22))
+        # extract scale
+        scale = vec3(sx, sy, sz)
+        return position, rotation, scale
+
+    return transform_decompose
+
+
+transform_decompose = wp.func(
+    create_transform_decompose_func(wp.float32),
+    name="transform_decompose",
+)
+wp.func(
+    create_transform_decompose_func(wp.float16),
+    name="transform_decompose",
+)
+wp.func(
+    create_transform_decompose_func(wp.float64),
+    name="transform_decompose",
+)
+
+
 # register API functions so they appear in the documentation
 
 wp.context.register_api_function(
@@ -242,3 +404,11 @@ wp.context.register_api_function(
     transform_to_matrix,
     group="Transformations",
 )
+wp.context.register_api_function(
+    transform_compose,
+    group="Transformations",
+)
+wp.context.register_api_function(
+    transform_decompose,
+    group="Transformations",
+)
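
Note: the new transform_compose / transform_decompose helpers above are registered the same way as the existing transform_from_matrix / transform_to_matrix, so they should be callable from kernels. The sketch below is not taken from the release; the warp.math import path and the tuple unpacking of the decompose result are assumptions based on this diff (1.8.0 also adds tuple support, see warp/native/tuple.h and warp/tests/test_tuple.py in the file list).

import warp as wp
import warp.math as wpm

@wp.kernel
def roundtrip(out_pos: wp.array(dtype=wp.vec3),
              out_rot: wp.array(dtype=wp.quat),
              out_scale: wp.array(dtype=wp.vec3)):
    # build a 4x4 affine matrix from position, rotation, and non-uniform scale
    p = wp.vec3(1.0, 2.0, 3.0)
    q = wp.quat_from_axis_angle(wp.vec3(0.0, 0.0, 1.0), 0.5)
    s = wp.vec3(2.0, 0.5, 1.0)
    m = wpm.transform_compose(p, q, s)

    # recover (position, rotation, scale) from the matrix
    pos, rot, scale = wpm.transform_decompose(m)
    out_pos[0] = pos
    out_rot[0] = rot
    out_scale[0] = scale

out_pos = wp.zeros(1, dtype=wp.vec3)
out_rot = wp.zeros(1, dtype=wp.quat)
out_scale = wp.zeros(1, dtype=wp.vec3)
wp.launch(roundtrip, dim=1, inputs=[out_pos, out_rot, out_scale])
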
warp/native/array.h CHANGED
@@ -743,6 +743,24 @@ inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, T value)
 template<template<typename> class A, typename T>
 inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
 
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, T old_value, T new_value) { return atomic_cas(&index(buf, i), old_value, new_value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, T old_value, T new_value) { return atomic_cas(&index(buf, i, j), old_value, new_value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, int k, T old_value, T new_value) { return atomic_cas(&index(buf, i, j, k), old_value, new_value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_cas(const A<T>& buf, int i, int j, int k, int l, T old_value, T new_value) { return atomic_cas(&index(buf, i, j, k, l), old_value, new_value); }
+
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, T value) { return atomic_exch(&index(buf, i), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, T value) { return atomic_exch(&index(buf, i, j), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, T value) { return atomic_exch(&index(buf, i, j, k), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_exch(&index(buf, i, j, k, l), value); }
+
 template<template<typename> class A, typename T>
 inline CUDA_CALLABLE T* address(const A<T>& buf, int i) { return &index(buf, i); }
 template<template<typename> class A, typename T>
@@ -1128,6 +1146,87 @@ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k,
     FP_VERIFY_ADJ_4(value, adj_value)
 }
 
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, T compare, T value, const A2<T>& adj_buf, int adj_i, T& adj_compare, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_cas(&index(buf, i), compare, value, &index(adj_buf, i), adj_compare, adj_value, adj_ret);
+    else if (buf.grad)
+        adj_atomic_cas(&index(buf, i), compare, value, &index_grad(buf, i), adj_compare, adj_value, adj_ret);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_compare, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_cas(&index(buf, i, j), compare, value, &index(adj_buf, i, j), adj_compare, adj_value, adj_ret);
+    else if (buf.grad)
+        adj_atomic_cas(&index(buf, i, j), compare, value, &index_grad(buf, i, j), adj_compare, adj_value, adj_ret);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, int k, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_compare, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_cas(&index(buf, i, j, k), compare, value, &index(adj_buf, i, j, k), adj_compare, adj_value, adj_ret);
+    else if (buf.grad)
+        adj_atomic_cas(&index(buf, i, j, k), compare, value, &index_grad(buf, i, j, k), adj_compare, adj_value, adj_ret);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_cas(const A1<T>& buf, int i, int j, int k, int l, T compare, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_compare, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_cas(&index(buf, i, j, k, l), compare, value, &index(adj_buf, i, j, k, l), adj_compare, adj_value, adj_ret);
+    else if (buf.grad)
+        adj_atomic_cas(&index(buf, i, j, k, l), compare, value, &index_grad(buf, i, j, k, l), adj_compare, adj_value, adj_ret);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_exch(&index(buf, i), value, &index(adj_buf, i), adj_value, adj_ret);
+    else if (buf.grad)
+        adj_atomic_exch(&index(buf, i), value, &index_grad(buf, i), adj_value, adj_ret);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_exch(&index(buf, i, j), value, &index(adj_buf, i, j), adj_value, adj_ret);
+    else if (buf.grad)
+        adj_atomic_exch(&index(buf, i, j), value, &index_grad(buf, i, j), adj_value, adj_ret);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_exch(&index(buf, i, j, k), value, &index(adj_buf, i, j, k), adj_value, adj_ret);
+    else if (buf.grad)
+        adj_atomic_exch(&index(buf, i, j, k), value, &index_grad(buf, i, j, k), adj_value, adj_ret);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_exch(&index(buf, i, j, k, l), value, &index(adj_buf, i, j, k, l), adj_value, adj_ret);
+    else if (buf.grad)
+        adj_atomic_exch(&index(buf, i, j, k, l), value, &index_grad(buf, i, j, k, l), adj_value, adj_ret);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}
+
+
 template<template<typename> class A, typename T>
 CUDA_CALLABLE inline int len(const A<T>& a)
 {
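
Note: these array overloads back the new compare-and-swap and exchange builtins exposed to kernels (see also the new warp/examples/core/example_spin_lock.py and warp/tests/test_atomic_cas.py in the file list). The following is a rough sketch of the intended usage, assuming Python-level signatures wp.atomic_cas(arr, i, compare, value) and wp.atomic_exch(arr, i, value), both returning the previous element value:

import warp as wp

@wp.kernel
def locked_accumulate(lock: wp.array(dtype=wp.int32),
                      total: wp.array(dtype=wp.float32),
                      values: wp.array(dtype=wp.float32)):
    tid = wp.tid()

    # acquire: keep trying to swap lock[0] from 0 (free) to 1 (held)
    prev = wp.atomic_cas(lock, 0, 0, 1)
    while prev == 1:
        prev = wp.atomic_cas(lock, 0, 0, 1)

    # critical section: an ordinary read-modify-write protected by the lock
    total[0] = total[0] + values[tid]

    # release: write 0 back, discarding the previous value
    wp.atomic_exch(lock, 0, 0)

As with any GPU spin lock, this pattern depends on threads being able to make independent forward progress and should be used sparingly; a plain wp.atomic_add remains the better choice when simple accumulation is all that is needed.
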
warp/native/builtin.h CHANGED
@@ -52,6 +52,11 @@
 __device__ void __debugbreak() {}
 #endif
 
+#if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
+// clang compiling CUDA code, device mode (NOTE: Used when building core library with Clang)
+#include <cuda_fp16.h>
+#endif
+
 namespace wp
 {
 
@@ -177,14 +182,14 @@ CUDA_CALLABLE inline float half_to_float(half x)
 #elif defined(__clang__)
 
 // _Float16 is Clang's native half-precision floating-point type
-inline half float_to_half(float x)
+CUDA_CALLABLE inline half float_to_half(float x)
 {
 
     _Float16 f16 = static_cast<_Float16>(x);
     return *reinterpret_cast<half*>(&f16);
 }
 
-inline float half_to_float(half h)
+CUDA_CALLABLE inline float half_to_float(half h)
 {
     _Float16 f16 = *reinterpret_cast<_Float16*>(&h);
     return static_cast<float>(f16);
@@ -1221,6 +1226,15 @@ inline CUDA_CALLABLE launch_coord_t launch_coord(size_t linear, const launch_bou
     return coord;
 }
 
+inline CUDA_CALLABLE int block_dim()
+{
+#if defined(__CUDA_ARCH__)
+    return blockDim.x;
+#else
+    return 1;
+#endif
+}
+
 inline CUDA_CALLABLE int tid(size_t index, const launch_bounds_t& bounds)
 {
     // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
@@ -1301,34 +1315,35 @@ inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
     float16 old = buf[0];
     buf[0] += value;
     return old;
-#elif defined(__clang__) // CUDA compiled by Clang
-    __half r = atomicAdd(reinterpret_cast<__half*>(buf), *reinterpret_cast<__half*>(&value));
-    return *reinterpret_cast<float16*>(&r);
 #else // CUDA compiled by NVRTC
-    //return atomicAdd(buf, value);
-
-    /* Define __PTR for atomicAdd prototypes below, undef after done */
-    #if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
-    #define __PTR "l"
-    #else
-    #define __PTR "r"
-    #endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
-
-    half r = 0.0;
-
 #if __CUDA_ARCH__ >= 700
-
-    asm volatile ("{ atom.add.noftz.f16 %0,[%1],%2; }\n"
-                  : "=h"(r.u)
-                  : __PTR(buf), "h"(value.u)
-                  : "memory");
+    #if defined(__clang__) // CUDA compiled by Clang
+    __half r = atomicAdd(reinterpret_cast<__half*>(buf), *reinterpret_cast<__half*>(&value));
+    return *reinterpret_cast<float16*>(&r);
+    #else // CUDA compiled by NVRTC
+    /* Define __PTR for atomicAdd prototypes below, undef after done */
+    #if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
+    #define __PTR "l"
+    #else
+    #define __PTR "r"
+    #endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
+
+    half r = 0.0;
+
+    asm volatile ("{ atom.add.noftz.f16 %0,[%1],%2; }\n"
+                  : "=h"(r.u)
+                  : __PTR(buf), "h"(value.u)
+                  : "memory");
+
+    return r;
+
+    #undef __PTR
+    #endif
+#else
+    // No native __half atomic support on compute capability < 7.0
+    return float16(0.0f);
 #endif
-
-    return r;
-
-    #undef __PTR
-
-#endif // CUDA compiled by NVRTC
+#endif
 }
 
 template<>
@@ -1508,6 +1523,129 @@ CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const
 CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
 
 
+template<typename T>
+inline CUDA_CALLABLE T atomic_cas(T* address, T compare, T val)
+{
+#if defined(__CUDA_ARCH__)
+    return atomicCAS(address, compare, val);
+#else
+    T old = *address;
+    if (old == compare)
+    {
+        *address = val;
+    }
+    return old;
+#endif
+}
+
+template<>
+inline CUDA_CALLABLE float atomic_cas(float* address, float compare, float val)
+{
+#if defined(__CUDA_ARCH__)
+    auto result = atomicCAS(reinterpret_cast<unsigned int*>(address),
+                            reinterpret_cast<unsigned int&>(compare),
+                            reinterpret_cast<unsigned int&>(val));
+    return reinterpret_cast<float&>(result);
+#else
+    float old = *address;
+    if (old == compare)
+    {
+        *address = val;
+    }
+    return old;
+#endif
+}
+
+template<>
+inline CUDA_CALLABLE double atomic_cas(double* address, double compare, double val)
+{
+#if defined(__CUDA_ARCH__)
+    auto result = atomicCAS(reinterpret_cast<unsigned long long int *>(address),
+                            reinterpret_cast<unsigned long long int &>(compare),
+                            reinterpret_cast<unsigned long long int &>(val));
+    return reinterpret_cast<double&>(result);
+#else
+    double old = *address;
+    if (old == compare)
+    {
+        *address = val;
+    }
+    return old;
+#endif
+}
+
+template<>
+inline CUDA_CALLABLE int64 atomic_cas(int64* address, int64 compare, int64 val)
+{
+#if defined(__CUDA_ARCH__)
+    auto result = atomicCAS(reinterpret_cast<unsigned long long int *>(address),
+                            reinterpret_cast<unsigned long long int &>(compare),
+                            reinterpret_cast<unsigned long long int &>(val));
+    return reinterpret_cast<int64&>(result);
+#else
+    int64 old = *address;
+    if (old == compare)
+    {
+        *address = val;
+    }
+    return old;
+#endif
+}
+
+template<typename T>
+inline CUDA_CALLABLE T atomic_exch(T* address, T val)
+{
+#if defined(__CUDA_ARCH__)
+    return atomicExch(address, val);
+#else
+    T old = *address;
+    *address = val;
+    return old;
+#endif
+}
+
+template<>
+inline CUDA_CALLABLE double atomic_exch(double* address, double val)
+{
+#if defined(__CUDA_ARCH__)
+    auto result = atomicExch(reinterpret_cast<unsigned long long int*>(address),
+                             reinterpret_cast<unsigned long long int&>(val));
+    return reinterpret_cast<double&>(result);
+#else
+    double old = *address;
+    *address = val;
+    return old;
+#endif
+}
+
+template<>
+inline CUDA_CALLABLE int64 atomic_exch(int64* address, int64 val)
+{
+#if defined(__CUDA_ARCH__)
+    auto result = atomicExch(reinterpret_cast<unsigned long long int*>(address),
+                             reinterpret_cast<unsigned long long int&>(val));
+    return reinterpret_cast<int64&>(result);
+#else
+    int64 old = *address;
+    *address = val;
+    return old;
+#endif
+}
+
+
+template<typename T>
+CUDA_CALLABLE inline void adj_atomic_cas(T* address, T compare, T val, T* adj_address, T& adj_compare, T& adj_val, T adj_ret)
+{
+    // Not implemented
+}
+
+template<typename T>
+CUDA_CALLABLE inline void adj_atomic_exch(T* address, T val, T* adj_address, T& adj_val, T adj_ret)
+{
+    // Not implemented
+}
+
+
 } // namespace wp
 
 
@@ -1778,8 +1916,9 @@ inline CUDA_CALLABLE void expect_near(const T& actual, const T& expected, const
     if (abs(actual - expected) > tolerance)
     {
         printf("Error, expect_near() failed with tolerance "); print(tolerance);
-        printf("\t Expected: "); print(expected);
-        printf("\t Actual: "); print(actual);
+        printf(" Expected: "); print(expected);
+        printf(" Actual: "); print(actual);
+        printf(" Absolute difference: "); print(abs(actual - expected));
     }
 }
 
@@ -1789,8 +1928,9 @@ inline CUDA_CALLABLE void expect_near(const vec3& actual, const vec3& expected,
     if (diff > tolerance)
     {
         printf("Error, expect_near() failed with tolerance "); print(tolerance);
-        printf("\t Expected: "); print(expected);
-        printf("\t Actual: "); print(actual);
+        printf(" Expected: "); print(expected);
+        printf(" Actual: "); print(actual);
+        printf(" Max absolute difference: "); print(diff);
     }
 }
 
@@ -1810,6 +1950,7 @@ inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expect
 
 // include array.h so we have the print, isfinite functions for the inner array types defined
 #include "array.h"
+#include "tuple.h"
 #include "mesh.h"
 #include "bvh.h"
 #include "svd.h"
@@ -1823,4 +1964,6 @@ inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expect
 #if !defined(WP_ENABLE_CUDA) // only include in kernels for now
 #include "tile.h"
 #include "tile_reduce.h"
+#include "tile_scan.h"
+#include "tile_radix_sort.h"
 #endif //!defined(WP_ENABLE_CUDA)
warp/native/coloring.cpp CHANGED
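Note: the single-line change below swaps the loop counter from size_t to int; with an unsigned counter the condition bucket_idx >= 0 is always true, so decrementing past zero would wrap around instead of terminating the countdown loop.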
@@ -372,7 +372,7 @@ public:
     // we need to update max_weight because weight_buckets[max_weight] became empty
     {
         int new_max_weight = 0;
-        for (size_t bucket_idx = max_weight - 1; bucket_idx >= 0; bucket_idx--)
+        for (int bucket_idx = max_weight - 1; bucket_idx >= 0; bucket_idx--)
         {
             if (weight_buckets[bucket_idx].size())
             {