warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the original diff page (the details link was not preserved here) for more information.

Files changed (191):
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
@@ -173,7 +173,7 @@ class Grid2D(Geometry):
173
173
  return Grid2D.Side(axis, origin)
174
174
 
175
175
  axis_side_index = side_index - 2 * arg.cell_count
176
- axis = wp.select(axis_side_index < arg.axis_offsets[1], 1, 0)
176
+ axis = wp.where(axis_side_index < arg.axis_offsets[1], 0, 1)
177
177
 
178
178
  altitude = arg.cell_arg.res[Grid2D.ROTATION[axis, 0]]
179
179
  longitude = axis_side_index - arg.axis_offsets[axis]
@@ -273,7 +273,7 @@ class Grid2D(Geometry):
273
273
  def side_position(args: SideArg, s: Sample):
274
274
  side = Grid2D.get_side(args, s.element_index)
275
275
 
276
- coord = wp.select((side.origin[0] == 0) == (side.axis == 0), 1.0 - s.element_coords[0], s.element_coords[0])
276
+ coord = wp.where((side.origin[0] == 0) == (side.axis == 0), s.element_coords[0], 1.0 - s.element_coords[0])
277
277
 
278
278
  local_pos = wp.vec2(
279
279
  float(side.origin[0]),
@@ -288,7 +288,7 @@ class Grid2D(Geometry):
288
288
  def side_deformation_gradient(args: SideArg, s: Sample):
289
289
  side = Grid2D.get_side(args, s.element_index)
290
290
 
291
- sign = wp.select((side.origin[0] == 0) == (side.axis == 0), -1.0, 1.0)
291
+ sign = wp.where((side.origin[0] == 0) == (side.axis == 0), 1.0, -1.0)
292
292
 
293
293
  return wp.cw_mul(Grid2D._rotate(side.axis, wp.vec2(0.0, sign)), args.cell_arg.cell_size)
294
294
 
@@ -316,7 +316,7 @@ class Grid2D(Geometry):
316
316
  def side_normal(args: SideArg, s: Sample):
317
317
  side = Grid2D.get_side(args, s.element_index)
318
318
 
319
- sign = wp.select(side.origin[0] == 0, 1.0, -1.0)
319
+ sign = wp.where(side.origin[0] == 0, -1.0, 1.0)
320
320
 
321
321
  local_n = wp.vec2(sign, 0.0)
322
322
  return Grid2D._rotate(side.axis, local_n)
@@ -325,7 +325,7 @@ class Grid2D(Geometry):
325
325
  def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
326
326
  side = Grid2D.get_side(arg, side_index)
327
327
 
328
- inner_alt = wp.select(side.origin[0] == 0, side.origin[0] - 1, 0)
328
+ inner_alt = wp.where(side.origin[0] == 0, 0, side.origin[0] - 1)
329
329
 
330
330
  inner_origin = wp.vec2i(inner_alt, side.origin[1])
331
331
 
@@ -337,8 +337,8 @@ class Grid2D(Geometry):
337
337
  side = Grid2D.get_side(arg, side_index)
338
338
 
339
339
  alt_axis = Grid2D.ROTATION[side.axis, 0]
340
- outer_alt = wp.select(
341
- side.origin[0] == arg.cell_arg.res[alt_axis], side.origin[0], arg.cell_arg.res[alt_axis] - 1
340
+ outer_alt = wp.where(
341
+ side.origin[0] == arg.cell_arg.res[alt_axis], arg.cell_arg.res[alt_axis] - 1, side.origin[0]
342
342
  )
343
343
 
344
344
  outer_origin = wp.vec2i(outer_alt, side.origin[1])
@@ -350,9 +350,9 @@ class Grid2D(Geometry):
350
350
  def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
351
351
  side = Grid2D.get_side(args, side_index)
352
352
 
353
- inner_alt = wp.select(side.origin[0] == 0, 1.0, 0.0)
353
+ inner_alt = wp.where(side.origin[0] == 0, 0.0, 1.0)
354
354
 
355
- side_coord = wp.select((side.origin[0] == 0) == (side.axis == 0), 1.0 - side_coords[0], side_coords[0])
355
+ side_coord = wp.where((side.origin[0] == 0) == (side.axis == 0), side_coords[0], 1.0 - side_coords[0])
356
356
 
357
357
  coords = Grid2D._rotate(side.axis, wp.vec2(inner_alt, side_coord))
358
358
  return Coords(coords[0], coords[1], 0.0)
@@ -362,9 +362,9 @@ class Grid2D(Geometry):
362
362
  side = Grid2D.get_side(args, side_index)
363
363
 
364
364
  alt_axis = Grid2D.ROTATION[side.axis, 0]
365
- outer_alt = wp.select(side.origin[0] == args.cell_arg.res[alt_axis], 0.0, 1.0)
365
+ outer_alt = wp.where(side.origin[0] == args.cell_arg.res[alt_axis], 1.0, 0.0)
366
366
 
367
- side_coord = wp.select((side.origin[0] == 0) == (side.axis == 0), 1.0 - side_coords[0], side_coords[0])
367
+ side_coord = wp.where((side.origin[0] == 0) == (side.axis == 0), side_coords[0], 1.0 - side_coords[0])
368
368
 
369
369
  coords = Grid2D._rotate(side.axis, wp.vec2(outer_alt, side_coord))
370
370
  return Coords(coords[0], coords[1], 0.0)
@@ -382,7 +382,7 @@ class Grid2D(Geometry):
382
382
  if float(side.origin[0] - cell[side.axis]) == element_coords[side.axis]:
383
383
  long_axis = Grid2D.ROTATION[side.axis, 1]
384
384
  axis_coord = element_coords[long_axis]
385
- side_coord = wp.select((side.origin[0] == 0) == (side.axis == 0), 1.0 - axis_coord, axis_coord)
385
+ side_coord = wp.where((side.origin[0] == 0) == (side.axis == 0), axis_coord, 1.0 - axis_coord)
386
386
  return Coords(side_coord, 0.0, 0.0)
387
387
 
388
388
  return Coords(OUTSIDE)
@@ -30,9 +30,6 @@ class Grid3DCellArg:
30
30
  origin: wp.vec3
31
31
 
32
32
 
33
- _mat32 = wp.mat(shape=(3, 2), dtype=float)
34
-
35
-
36
33
  class Grid3D(Geometry):
37
34
  """Three-dimensional regular grid geometry"""
38
35
 
@@ -331,7 +328,7 @@ class Grid3D(Geometry):
331
328
  def side_position(args: SideArg, s: Sample):
332
329
  side = Grid3D.get_side(args, s.element_index)
333
330
 
334
- coord0 = wp.select(side.origin[0] == 0, s.element_coords[0], 1.0 - s.element_coords[0])
331
+ coord0 = wp.where(side.origin[0] == 0, 1.0 - s.element_coords[0], s.element_coords[0])
335
332
 
336
333
  local_pos = wp.vec3(
337
334
  float(side.origin[0]),
@@ -347,9 +344,9 @@ class Grid3D(Geometry):
347
344
  def side_deformation_gradient(args: SideArg, s: Sample):
348
345
  side = Grid3D.get_side(args, s.element_index)
349
346
 
350
- sign = wp.select(side.origin[0] == 0, 1.0, -1.0)
347
+ sign = wp.where(side.origin[0] == 0, -1.0, 1.0)
351
348
 
352
- return _mat32(
349
+ return wp.matrix_from_cols(
353
350
  wp.cw_mul(Grid3D._local_to_world(side.axis, wp.vec3(0.0, sign, 0.0)), args.cell_arg.cell_size),
354
351
  wp.cw_mul(Grid3D._local_to_world(side.axis, wp.vec3(0.0, 0.0, 1.0)), args.cell_arg.cell_size),
355
352
  )
@@ -379,7 +376,7 @@ class Grid3D(Geometry):
379
376
  def side_normal(args: SideArg, s: Sample):
380
377
  side = Grid3D.get_side(args, s.element_index)
381
378
 
382
- sign = wp.select(side.origin[0] == 0, 1.0, -1.0)
379
+ sign = wp.where(side.origin[0] == 0, -1.0, 1.0)
383
380
 
384
381
  local_n = wp.vec3(sign, 0.0, 0.0)
385
382
  return Grid3D._local_to_world(side.axis, local_n)
@@ -388,7 +385,7 @@ class Grid3D(Geometry):
388
385
  def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
389
386
  side = Grid3D.get_side(arg, side_index)
390
387
 
391
- inner_alt = wp.select(side.origin[0] == 0, side.origin[0] - 1, 0)
388
+ inner_alt = wp.where(side.origin[0] == 0, 0, side.origin[0] - 1)
392
389
 
393
390
  inner_origin = wp.vec3i(inner_alt, side.origin[1], side.origin[2])
394
391
 
@@ -401,8 +398,8 @@ class Grid3D(Geometry):
401
398
 
402
399
  alt_axis = Grid3D._local_to_world_axis(side.axis, 0)
403
400
 
404
- outer_alt = wp.select(
405
- side.origin[0] == arg.cell_arg.res[alt_axis], side.origin[0], arg.cell_arg.res[alt_axis] - 1
401
+ outer_alt = wp.where(
402
+ side.origin[0] == arg.cell_arg.res[alt_axis], arg.cell_arg.res[alt_axis] - 1, side.origin[0]
406
403
  )
407
404
 
408
405
  outer_origin = wp.vec3i(outer_alt, side.origin[1], side.origin[2])
@@ -414,9 +411,9 @@ class Grid3D(Geometry):
414
411
  def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
415
412
  side = Grid3D.get_side(args, side_index)
416
413
 
417
- inner_alt = wp.select(side.origin[0] == 0, 1.0, 0.0)
414
+ inner_alt = wp.where(side.origin[0] == 0, 0.0, 1.0)
418
415
 
419
- side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])
416
+ side_coord0 = wp.where(side.origin[0] == 0, 1.0 - side_coords[0], side_coords[0])
420
417
 
421
418
  return Grid3D._local_to_world(side.axis, wp.vec3(inner_alt, side_coord0, side_coords[1]))
422
419
 
@@ -425,9 +422,9 @@ class Grid3D(Geometry):
425
422
  side = Grid3D.get_side(args, side_index)
426
423
 
427
424
  alt_axis = Grid3D._local_to_world_axis(side.axis, 0)
428
- outer_alt = wp.select(side.origin[0] == args.cell_arg.res[alt_axis], 0.0, 1.0)
425
+ outer_alt = wp.where(side.origin[0] == args.cell_arg.res[alt_axis], 1.0, 0.0)
429
426
 
430
- side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])
427
+ side_coord0 = wp.where(side.origin[0] == 0, 1.0 - side_coords[0], side_coords[0])
431
428
 
432
429
  return Grid3D._local_to_world(side.axis, wp.vec3(outer_alt, side_coord0, side_coords[1]))
433
430
 
@@ -445,7 +442,7 @@ class Grid3D(Geometry):
445
442
  long_axis = Grid3D._local_to_world_axis(side.axis, 1)
446
443
  lat_axis = Grid3D._local_to_world_axis(side.axis, 2)
447
444
  long_coord = element_coords[long_axis]
448
- long_coord = wp.select(side.origin[0] == 0, long_coord, 1.0 - long_coord)
445
+ long_coord = wp.where(side.origin[0] == 0, 1.0 - long_coord, long_coord)
449
446
  return Coords(long_coord, element_coords[lat_axis], 0.0)
450
447
 
451
448
  return Coords(OUTSIDE)
@@ -46,8 +46,6 @@ class HexmeshSideArg:
46
46
  face_hex_face_orientation: wp.array(dtype=wp.vec4i)
47
47
 
48
48
 
49
- _mat32 = wp.mat(shape=(3, 2), dtype=float)
50
-
51
49
  FACE_VERTEX_INDICES = wp.constant(
52
50
  wp.mat(shape=(6, 4), dtype=int)(
53
51
  [
@@ -322,7 +320,7 @@ class Hexmesh(Geometry):
322
320
  def side_deformation_gradient(args: SideArg, s: Sample):
323
321
  """Transposed side deformation gradient at `coords`"""
324
322
  v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
325
- return _mat32(v1, v2)
323
+ return wp.matrix_from_cols(v1, v2)
326
324
 
327
325
  @wp.func
328
326
  def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
@@ -342,7 +340,7 @@ class Hexmesh(Geometry):
342
340
  )
343
341
 
344
342
  normal_coord = hex_coords[_FACE_COORD_INDICES[face_index, 2]]
345
- normal_coord = wp.select(_FACE_COORD_INDICES[face_index, 3] == 0, normal_coord - 1.0, -normal_coord)
343
+ normal_coord = wp.where(_FACE_COORD_INDICES[face_index, 3] == 0, -normal_coord, normal_coord - 1.0)
346
344
 
347
345
  return face_coords, normal_coord
348
346
 
@@ -353,7 +351,7 @@ class Hexmesh(Geometry):
353
351
  hex_coords = Coords()
354
352
  hex_coords[_FACE_COORD_INDICES[face_index, 0]] = face_coords[0]
355
353
  hex_coords[_FACE_COORD_INDICES[face_index, 1]] = face_coords[1]
356
- hex_coords[_FACE_COORD_INDICES[face_index, 2]] = wp.select(_FACE_COORD_INDICES[face_index, 3] == 0, 1.0, 0.0)
354
+ hex_coords[_FACE_COORD_INDICES[face_index, 2]] = wp.where(_FACE_COORD_INDICES[face_index, 3] == 0, 0.0, 1.0)
357
355
 
358
356
  return hex_coords
359
357
 
@@ -402,8 +400,8 @@ class Hexmesh(Geometry):
402
400
  face_orientation = args.face_hex_face_orientation[side_index][3]
403
401
 
404
402
  face_coords, normal_coord = Hexmesh._hex_local_face_coords(hex_coords, local_face_index)
405
- return wp.select(
406
- normal_coord == 0.0, Coords(OUTSIDE), Hexmesh._local_to_oriented_face_coords(face_orientation, face_coords)
403
+ return wp.where(
404
+ normal_coord == 0.0, Hexmesh._local_to_oriented_face_coords(face_orientation, face_coords), Coords(OUTSIDE)
407
405
  )
408
406
 
409
407
  @wp.func
@@ -33,13 +33,11 @@ FACE_AXIS_MASK = wp.constant(wp.uint8((1 << 2) - 1))
33
33
  FACE_INNER_OFFSET_BIT = wp.constant(wp.uint8(2))
34
34
  FACE_OUTER_OFFSET_BIT = wp.constant(wp.uint8(3))
35
35
 
36
- _mat32 = wp.mat(shape=(3, 2), dtype=float)
37
-
38
36
 
39
37
  @wp.func
40
38
  def _add_axis_flag(ijk: wp.vec3i, axis: int):
41
39
  coord = ijk[axis]
42
- ijk[axis] = wp.select(coord < 0, coord | GRID_AXIS_FLAG, coord & (~GRID_AXIS_FLAG))
40
+ ijk[axis] = wp.where(coord < 0, coord & (~GRID_AXIS_FLAG), coord | GRID_AXIS_FLAG)
43
41
  return ijk
44
42
 
45
43
 
@@ -191,9 +189,9 @@ class Nanogrid(Geometry):
191
189
  coords = uvw - wp.vec3(ijk)
192
190
  if cell_index == -1:
193
191
  if wp.min(coords) == 0.0 or wp.max(coords) == 1.0:
194
- il = wp.select(coords[0] > 0.5, -1, 0)
195
- jl = wp.select(coords[1] > 0.5, -1, 0)
196
- kl = wp.select(coords[2] > 0.5, -1, 0)
192
+ il = wp.where(coords[0] > 0.5, 0, -1)
193
+ jl = wp.where(coords[1] > 0.5, 0, -1)
194
+ kl = wp.where(coords[2] > 0.5, 0, -1)
197
195
 
198
196
  for n in range(8):
199
197
  ni = n >> 2
@@ -327,7 +325,7 @@ class Nanogrid(Geometry):
327
325
  axis = Nanogrid._get_face_axis(flags)
328
326
  flip = Nanogrid._get_face_inner_offset(flags)
329
327
  v1, v2 = Nanogrid._face_tangent_vecs(args.cell_arg.cell_grid, axis, flip)
330
- return _mat32(v1, v2)
328
+ return wp.matrix_from_cols(v1, v2)
331
329
 
332
330
  @wp.func
333
331
  def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
@@ -409,8 +407,8 @@ class Nanogrid(Geometry):
409
407
 
410
408
  on_side = float(side_ijk[axis] - cell_ijk[axis]) == element_coords[axis]
411
409
 
412
- return wp.select(
413
- on_side, Coords(OUTSIDE), Coords(element_coords[(axis + 1) % 3], element_coords[(axis + 2) % 3], 0.0)
410
+ return wp.where(
411
+ on_side, Coords(element_coords[(axis + 1) % 3], element_coords[(axis + 2) % 3], 0.0), Coords(OUTSIDE)
414
412
  )
415
413
 
416
414
  @wp.func
@@ -538,8 +536,8 @@ def _build_edge_grid(cell_ijk, grid: wp.Volume, temporary_store: cache.Temporary
538
536
 
539
537
  @wp.func
540
538
  def _make_face_flags(axis: int, plus_cell_index: int, minus_cell_index: int):
541
- plus_boundary = wp.uint8(wp.select(plus_cell_index == -1, 0, 1)) << FACE_OUTER_OFFSET_BIT
542
- minus_boundary = wp.uint8(wp.select(minus_cell_index == -1, 0, 1)) << FACE_INNER_OFFSET_BIT
539
+ plus_boundary = wp.uint8(wp.where(plus_cell_index == -1, 1, 0)) << FACE_OUTER_OFFSET_BIT
540
+ minus_boundary = wp.uint8(wp.where(minus_cell_index == -1, 1, 0)) << FACE_INNER_OFFSET_BIT
543
541
 
544
542
  return wp.uint8(axis) | plus_boundary | minus_boundary
545
543
 
@@ -165,13 +165,13 @@ class Quadmesh(Geometry):
165
165
  s = side_coords[0]
166
166
 
167
167
  if vs == quad_vidx[0]:
168
- return wp.select(ve == quad_vidx[1], Coords(0.0, s, 0.0), Coords(s, 0.0, 0.0))
168
+ return wp.where(ve == quad_vidx[1], Coords(s, 0.0, 0.0), Coords(0.0, s, 0.0))
169
169
  elif vs == quad_vidx[1]:
170
- return wp.select(ve == quad_vidx[2], Coords(1.0 - s, 0.0, 0.0), Coords(1.0, s, 0.0))
170
+ return wp.where(ve == quad_vidx[2], Coords(1.0, s, 0.0), Coords(1.0 - s, 0.0, 0.0))
171
171
  elif vs == quad_vidx[2]:
172
- return wp.select(ve == quad_vidx[3], Coords(1.0, 1.0 - s, 0.0), Coords(1.0 - s, 1.0, 0.0))
172
+ return wp.where(ve == quad_vidx[3], Coords(1.0 - s, 1.0, 0.0), Coords(1.0, 1.0 - s, 0.0))
173
173
 
174
- return wp.select(ve == quad_vidx[0], Coords(s, 1.0, 0.0), Coords(0.0, 1.0 - s, 0.0))
174
+ return wp.where(ve == quad_vidx[0], Coords(0.0, 1.0 - s, 0.0), Coords(s, 1.0, 0.0))
175
175
 
176
176
  @wp.func
177
177
  def _quad_to_edge_coords(
@@ -190,18 +190,18 @@ class Quadmesh(Geometry):
190
190
  cy = quad_coords[1]
191
191
 
192
192
  if vs == quad_vidx[0]:
193
- oc = wp.select(ve == quad_vidx[1], cx, cy)
194
- ec = wp.select(ve == quad_vidx[1], cy, cx)
193
+ oc = wp.where(ve == quad_vidx[1], cy, cx)
194
+ ec = wp.where(ve == quad_vidx[1], cx, cy)
195
195
  elif vs == quad_vidx[1]:
196
- oc = wp.select(ve == quad_vidx[2], cy, 1.0 - cx)
197
- ec = wp.select(ve == quad_vidx[2], 1.0 - cx, cy)
196
+ oc = wp.where(ve == quad_vidx[2], 1.0 - cx, cy)
197
+ ec = wp.where(ve == quad_vidx[2], cy, 1.0 - cx)
198
198
  elif vs == quad_vidx[2]:
199
- oc = wp.select(ve == quad_vidx[3], 1.0 - cx, 1.0 - cy)
200
- ec = wp.select(ve == quad_vidx[3], 1.0 - cy, 1.0 - cx)
199
+ oc = wp.where(ve == quad_vidx[3], 1.0 - cy, 1.0 - cx)
200
+ ec = wp.where(ve == quad_vidx[3], 1.0 - cx, 1.0 - cy)
201
201
  else:
202
- oc = wp.select(ve == quad_vidx[0], 1.0 - cy, cx)
203
- ec = wp.select(ve == quad_vidx[0], cx, 1.0 - cy)
204
- return wp.select(oc == 0.0, Coords(OUTSIDE), Coords(ec, 0.0, 0.0))
202
+ oc = wp.where(ve == quad_vidx[0], cx, 1.0 - cy)
203
+ ec = wp.where(ve == quad_vidx[0], 1.0 - cy, cx)
204
+ return wp.where(oc == 0.0, Coords(ec, 0.0, 0.0), Coords(OUTSIDE))
205
205
 
206
206
  @wp.func
207
207
  def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
@@ -56,7 +56,6 @@ class TetmeshSideArg:
56
56
  face_tet_indices: wp.array(dtype=wp.vec2i)
57
57
 
58
58
 
59
- _mat32 = wp.mat(shape=(3, 2), dtype=float)
60
59
  _NULL_BVH = wp.constant(wp.uint64(-1))
61
60
 
62
61
 
@@ -203,7 +202,7 @@ class Tetmesh(Geometry):
203
202
  p1 = args.positions[args.tet_vertex_indices[s.element_index, 1]]
204
203
  p2 = args.positions[args.tet_vertex_indices[s.element_index, 2]]
205
204
  p3 = args.positions[args.tet_vertex_indices[s.element_index, 3]]
206
- return wp.mat33(p1 - p0, p2 - p0, p3 - p0)
205
+ return wp.matrix_from_cols(p1 - p0, p2 - p0, p3 - p0)
207
206
 
208
207
  @wp.func
209
208
  def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
@@ -312,7 +311,7 @@ class Tetmesh(Geometry):
312
311
  @wp.func
313
312
  def side_deformation_gradient(args: SideArg, s: Sample):
314
313
  e1, e2 = Tetmesh._side_vecs(args, s.element_index)
315
- return _mat32(e1, e2)
314
+ return wp.matrix_from_cols(e1, e2)
316
315
 
317
316
  @wp.func
318
317
  def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
@@ -389,7 +388,7 @@ class Tetmesh(Geometry):
389
388
  else:
390
389
  c2 = 1.0 - tet_coords[0] - tet_coords[1] - tet_coords[2]
391
390
 
392
- return wp.select(c0 + c1 + c2 > 0.999, Coords(OUTSIDE), Coords(c0, c1, c2))
391
+ return wp.where(c0 + c1 + c2 > 0.999, Coords(c0, c1, c2), Coords(OUTSIDE))
393
392
 
394
393
  @wp.func
395
394
  def side_to_cell_arg(side_arg: SideArg):
@@ -190,7 +190,7 @@ class Trimesh(Geometry):
190
190
  return args
191
191
 
192
192
  def _bvh_id(self, device):
193
- if self._tri_bvh is None or self._tri_bvh.device != device:
193
+ if self._tri_bvh is None or self._tri_bvh.device != wp.get_device(device):
194
194
  return _NULL_BVH
195
195
  return self._tri_bvh.id
196
196
 
@@ -325,9 +325,7 @@ class Trimesh(Geometry):
325
325
  elif edge_vidx[0] == v:
326
326
  start = k
327
327
 
328
- return wp.select(
329
- tri_coords[start] + tri_coords[end] > 0.999, Coords(OUTSIDE), Coords(tri_coords[end], 0.0, 0.0)
330
- )
328
+ return wp.where(tri_coords[start] + tri_coords[end] > 0.999, Coords(tri_coords[end], 0.0, 0.0), Coords(OUTSIDE))
331
329
 
332
330
  def _build_topology(self, temporary_store: TemporaryStore):
333
331
  from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
@@ -521,7 +519,7 @@ class Trimesh(Geometry):
521
519
  @wp.kernel
522
520
  def _compute_tri_bounds(
523
521
  tri_vertex_indices: wp.array2d(dtype=int),
524
- positions: wp.array(dtype=wp.vec2),
522
+ positions: wp.array(dtype=Any),
525
523
  lowers: wp.array(dtype=wp.vec3),
526
524
  uppers: wp.array(dtype=wp.vec3),
527
525
  ):
@@ -530,16 +528,8 @@ class Trimesh(Geometry):
530
528
  p1 = _bvh_vec(positions[tri_vertex_indices[t, 1]])
531
529
  p2 = _bvh_vec(positions[tri_vertex_indices[t, 2]])
532
530
 
533
- lowers[t] = wp.vec3(
534
- wp.min(wp.min(p0[0], p1[0]), p2[0]),
535
- wp.min(wp.min(p0[1], p1[1]), p2[1]),
536
- wp.min(wp.min(p0[2], p1[2]), p2[2]),
537
- )
538
- uppers[t] = wp.vec3(
539
- wp.max(wp.max(p0[0], p1[0]), p2[0]),
540
- wp.max(wp.max(p0[1], p1[1]), p2[1]),
541
- wp.max(wp.max(p0[2], p1[2]), p2[2]),
542
- )
531
+ lowers[t] = wp.min(wp.min(p0, p1), p2)
532
+ uppers[t] = wp.max(wp.max(p0, p1), p2)
543
533
 
544
534
 
545
535
  @wp.struct
@@ -576,7 +566,7 @@ class Trimesh2D(Trimesh):
576
566
  p0 = args.positions[tri_idx[0]]
577
567
  p1 = args.positions[tri_idx[1]]
578
568
  p2 = args.positions[tri_idx[2]]
579
- return wp.mat22(p1 - p0, p2 - p0)
569
+ return wp.matrix_from_cols(p1 - p0, p2 - p0)
580
570
 
581
571
  @wp.func
582
572
  def cell_lookup(args: CellArg, pos: wp.vec2):
@@ -689,9 +679,6 @@ class Trimesh3DSideArg:
689
679
  positions: wp.array(dtype=wp.vec3)
690
680
 
691
681
 
692
- _mat32 = wp.mat(shape=(3, 2), dtype=float)
693
-
694
-
695
682
  class Trimesh3D(Trimesh):
696
683
  """3D Triangular mesh geometry"""
697
684
 
@@ -714,7 +701,7 @@ class Trimesh3D(Trimesh):
714
701
  p0 = args.positions[tri_idx[0]]
715
702
  p1 = args.positions[tri_idx[1]]
716
703
  p2 = args.positions[tri_idx[2]]
717
- return _mat32(p1 - p0, p2 - p0)
704
+ return wp.matrix_from_cols(p1 - p0, p2 - p0)
718
705
 
719
706
  @wp.func
720
707
  def cell_lookup(args: CellArg, pos: wp.vec3):