warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/warp-clang.dll +0 -0
  3. warp/bin/warp.dll +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -30,9 +30,6 @@ class Grid3DCellArg:
30
30
  origin: wp.vec3
31
31
 
32
32
 
33
- _mat32 = wp.mat(shape=(3, 2), dtype=float)
34
-
35
-
36
33
  class Grid3D(Geometry):
37
34
  """Three-dimensional regular grid geometry"""
38
35
 
@@ -331,7 +328,7 @@ class Grid3D(Geometry):
331
328
  def side_position(args: SideArg, s: Sample):
332
329
  side = Grid3D.get_side(args, s.element_index)
333
330
 
334
- coord0 = wp.select(side.origin[0] == 0, s.element_coords[0], 1.0 - s.element_coords[0])
331
+ coord0 = wp.where(side.origin[0] == 0, 1.0 - s.element_coords[0], s.element_coords[0])
335
332
 
336
333
  local_pos = wp.vec3(
337
334
  float(side.origin[0]),
@@ -347,9 +344,9 @@ class Grid3D(Geometry):
347
344
  def side_deformation_gradient(args: SideArg, s: Sample):
348
345
  side = Grid3D.get_side(args, s.element_index)
349
346
 
350
- sign = wp.select(side.origin[0] == 0, 1.0, -1.0)
347
+ sign = wp.where(side.origin[0] == 0, -1.0, 1.0)
351
348
 
352
- return _mat32(
349
+ return wp.matrix_from_cols(
353
350
  wp.cw_mul(Grid3D._local_to_world(side.axis, wp.vec3(0.0, sign, 0.0)), args.cell_arg.cell_size),
354
351
  wp.cw_mul(Grid3D._local_to_world(side.axis, wp.vec3(0.0, 0.0, 1.0)), args.cell_arg.cell_size),
355
352
  )
@@ -379,7 +376,7 @@ class Grid3D(Geometry):
379
376
  def side_normal(args: SideArg, s: Sample):
380
377
  side = Grid3D.get_side(args, s.element_index)
381
378
 
382
- sign = wp.select(side.origin[0] == 0, 1.0, -1.0)
379
+ sign = wp.where(side.origin[0] == 0, -1.0, 1.0)
383
380
 
384
381
  local_n = wp.vec3(sign, 0.0, 0.0)
385
382
  return Grid3D._local_to_world(side.axis, local_n)
@@ -388,7 +385,7 @@ class Grid3D(Geometry):
388
385
  def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
389
386
  side = Grid3D.get_side(arg, side_index)
390
387
 
391
- inner_alt = wp.select(side.origin[0] == 0, side.origin[0] - 1, 0)
388
+ inner_alt = wp.where(side.origin[0] == 0, 0, side.origin[0] - 1)
392
389
 
393
390
  inner_origin = wp.vec3i(inner_alt, side.origin[1], side.origin[2])
394
391
 
@@ -401,8 +398,8 @@ class Grid3D(Geometry):
401
398
 
402
399
  alt_axis = Grid3D._local_to_world_axis(side.axis, 0)
403
400
 
404
- outer_alt = wp.select(
405
- side.origin[0] == arg.cell_arg.res[alt_axis], side.origin[0], arg.cell_arg.res[alt_axis] - 1
401
+ outer_alt = wp.where(
402
+ side.origin[0] == arg.cell_arg.res[alt_axis], arg.cell_arg.res[alt_axis] - 1, side.origin[0]
406
403
  )
407
404
 
408
405
  outer_origin = wp.vec3i(outer_alt, side.origin[1], side.origin[2])
@@ -414,9 +411,9 @@ class Grid3D(Geometry):
414
411
  def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
415
412
  side = Grid3D.get_side(args, side_index)
416
413
 
417
- inner_alt = wp.select(side.origin[0] == 0, 1.0, 0.0)
414
+ inner_alt = wp.where(side.origin[0] == 0, 0.0, 1.0)
418
415
 
419
- side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])
416
+ side_coord0 = wp.where(side.origin[0] == 0, 1.0 - side_coords[0], side_coords[0])
420
417
 
421
418
  return Grid3D._local_to_world(side.axis, wp.vec3(inner_alt, side_coord0, side_coords[1]))
422
419
 
@@ -425,9 +422,9 @@ class Grid3D(Geometry):
425
422
  side = Grid3D.get_side(args, side_index)
426
423
 
427
424
  alt_axis = Grid3D._local_to_world_axis(side.axis, 0)
428
- outer_alt = wp.select(side.origin[0] == args.cell_arg.res[alt_axis], 0.0, 1.0)
425
+ outer_alt = wp.where(side.origin[0] == args.cell_arg.res[alt_axis], 1.0, 0.0)
429
426
 
430
- side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])
427
+ side_coord0 = wp.where(side.origin[0] == 0, 1.0 - side_coords[0], side_coords[0])
431
428
 
432
429
  return Grid3D._local_to_world(side.axis, wp.vec3(outer_alt, side_coord0, side_coords[1]))
433
430
 
@@ -445,7 +442,7 @@ class Grid3D(Geometry):
445
442
  long_axis = Grid3D._local_to_world_axis(side.axis, 1)
446
443
  lat_axis = Grid3D._local_to_world_axis(side.axis, 2)
447
444
  long_coord = element_coords[long_axis]
448
- long_coord = wp.select(side.origin[0] == 0, long_coord, 1.0 - long_coord)
445
+ long_coord = wp.where(side.origin[0] == 0, 1.0 - long_coord, long_coord)
449
446
  return Coords(long_coord, element_coords[lat_axis], 0.0)
450
447
 
451
448
  return Coords(OUTSIDE)
@@ -46,8 +46,6 @@ class HexmeshSideArg:
46
46
  face_hex_face_orientation: wp.array(dtype=wp.vec4i)
47
47
 
48
48
 
49
- _mat32 = wp.mat(shape=(3, 2), dtype=float)
50
-
51
49
  FACE_VERTEX_INDICES = wp.constant(
52
50
  wp.mat(shape=(6, 4), dtype=int)(
53
51
  [
@@ -322,7 +320,7 @@ class Hexmesh(Geometry):
322
320
  def side_deformation_gradient(args: SideArg, s: Sample):
323
321
  """Transposed side deformation gradient at `coords`"""
324
322
  v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
325
- return _mat32(v1, v2)
323
+ return wp.matrix_from_cols(v1, v2)
326
324
 
327
325
  @wp.func
328
326
  def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
@@ -342,7 +340,7 @@ class Hexmesh(Geometry):
342
340
  )
343
341
 
344
342
  normal_coord = hex_coords[_FACE_COORD_INDICES[face_index, 2]]
345
- normal_coord = wp.select(_FACE_COORD_INDICES[face_index, 3] == 0, normal_coord - 1.0, -normal_coord)
343
+ normal_coord = wp.where(_FACE_COORD_INDICES[face_index, 3] == 0, -normal_coord, normal_coord - 1.0)
346
344
 
347
345
  return face_coords, normal_coord
348
346
 
@@ -353,7 +351,7 @@ class Hexmesh(Geometry):
353
351
  hex_coords = Coords()
354
352
  hex_coords[_FACE_COORD_INDICES[face_index, 0]] = face_coords[0]
355
353
  hex_coords[_FACE_COORD_INDICES[face_index, 1]] = face_coords[1]
356
- hex_coords[_FACE_COORD_INDICES[face_index, 2]] = wp.select(_FACE_COORD_INDICES[face_index, 3] == 0, 1.0, 0.0)
354
+ hex_coords[_FACE_COORD_INDICES[face_index, 2]] = wp.where(_FACE_COORD_INDICES[face_index, 3] == 0, 0.0, 1.0)
357
355
 
358
356
  return hex_coords
359
357
 
@@ -402,8 +400,8 @@ class Hexmesh(Geometry):
402
400
  face_orientation = args.face_hex_face_orientation[side_index][3]
403
401
 
404
402
  face_coords, normal_coord = Hexmesh._hex_local_face_coords(hex_coords, local_face_index)
405
- return wp.select(
406
- normal_coord == 0.0, Coords(OUTSIDE), Hexmesh._local_to_oriented_face_coords(face_orientation, face_coords)
403
+ return wp.where(
404
+ normal_coord == 0.0, Hexmesh._local_to_oriented_face_coords(face_orientation, face_coords), Coords(OUTSIDE)
407
405
  )
408
406
 
409
407
  @wp.func
@@ -33,13 +33,11 @@ FACE_AXIS_MASK = wp.constant(wp.uint8((1 << 2) - 1))
33
33
  FACE_INNER_OFFSET_BIT = wp.constant(wp.uint8(2))
34
34
  FACE_OUTER_OFFSET_BIT = wp.constant(wp.uint8(3))
35
35
 
36
- _mat32 = wp.mat(shape=(3, 2), dtype=float)
37
-
38
36
 
39
37
  @wp.func
40
38
  def _add_axis_flag(ijk: wp.vec3i, axis: int):
41
39
  coord = ijk[axis]
42
- ijk[axis] = wp.select(coord < 0, coord | GRID_AXIS_FLAG, coord & (~GRID_AXIS_FLAG))
40
+ ijk[axis] = wp.where(coord < 0, coord & (~GRID_AXIS_FLAG), coord | GRID_AXIS_FLAG)
43
41
  return ijk
44
42
 
45
43
 
@@ -191,9 +189,9 @@ class Nanogrid(Geometry):
191
189
  coords = uvw - wp.vec3(ijk)
192
190
  if cell_index == -1:
193
191
  if wp.min(coords) == 0.0 or wp.max(coords) == 1.0:
194
- il = wp.select(coords[0] > 0.5, -1, 0)
195
- jl = wp.select(coords[1] > 0.5, -1, 0)
196
- kl = wp.select(coords[2] > 0.5, -1, 0)
192
+ il = wp.where(coords[0] > 0.5, 0, -1)
193
+ jl = wp.where(coords[1] > 0.5, 0, -1)
194
+ kl = wp.where(coords[2] > 0.5, 0, -1)
197
195
 
198
196
  for n in range(8):
199
197
  ni = n >> 2
@@ -327,7 +325,7 @@ class Nanogrid(Geometry):
327
325
  axis = Nanogrid._get_face_axis(flags)
328
326
  flip = Nanogrid._get_face_inner_offset(flags)
329
327
  v1, v2 = Nanogrid._face_tangent_vecs(args.cell_arg.cell_grid, axis, flip)
330
- return _mat32(v1, v2)
328
+ return wp.matrix_from_cols(v1, v2)
331
329
 
332
330
  @wp.func
333
331
  def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
@@ -409,8 +407,8 @@ class Nanogrid(Geometry):
409
407
 
410
408
  on_side = float(side_ijk[axis] - cell_ijk[axis]) == element_coords[axis]
411
409
 
412
- return wp.select(
413
- on_side, Coords(OUTSIDE), Coords(element_coords[(axis + 1) % 3], element_coords[(axis + 2) % 3], 0.0)
410
+ return wp.where(
411
+ on_side, Coords(element_coords[(axis + 1) % 3], element_coords[(axis + 2) % 3], 0.0), Coords(OUTSIDE)
414
412
  )
415
413
 
416
414
  @wp.func
@@ -538,8 +536,8 @@ def _build_edge_grid(cell_ijk, grid: wp.Volume, temporary_store: cache.Temporary
538
536
 
539
537
  @wp.func
540
538
  def _make_face_flags(axis: int, plus_cell_index: int, minus_cell_index: int):
541
- plus_boundary = wp.uint8(wp.select(plus_cell_index == -1, 0, 1)) << FACE_OUTER_OFFSET_BIT
542
- minus_boundary = wp.uint8(wp.select(minus_cell_index == -1, 0, 1)) << FACE_INNER_OFFSET_BIT
539
+ plus_boundary = wp.uint8(wp.where(plus_cell_index == -1, 1, 0)) << FACE_OUTER_OFFSET_BIT
540
+ minus_boundary = wp.uint8(wp.where(minus_cell_index == -1, 1, 0)) << FACE_INNER_OFFSET_BIT
543
541
 
544
542
  return wp.uint8(axis) | plus_boundary | minus_boundary
545
543
 
@@ -165,13 +165,13 @@ class Quadmesh(Geometry):
165
165
  s = side_coords[0]
166
166
 
167
167
  if vs == quad_vidx[0]:
168
- return wp.select(ve == quad_vidx[1], Coords(0.0, s, 0.0), Coords(s, 0.0, 0.0))
168
+ return wp.where(ve == quad_vidx[1], Coords(s, 0.0, 0.0), Coords(0.0, s, 0.0))
169
169
  elif vs == quad_vidx[1]:
170
- return wp.select(ve == quad_vidx[2], Coords(1.0 - s, 0.0, 0.0), Coords(1.0, s, 0.0))
170
+ return wp.where(ve == quad_vidx[2], Coords(1.0, s, 0.0), Coords(1.0 - s, 0.0, 0.0))
171
171
  elif vs == quad_vidx[2]:
172
- return wp.select(ve == quad_vidx[3], Coords(1.0, 1.0 - s, 0.0), Coords(1.0 - s, 1.0, 0.0))
172
+ return wp.where(ve == quad_vidx[3], Coords(1.0 - s, 1.0, 0.0), Coords(1.0, 1.0 - s, 0.0))
173
173
 
174
- return wp.select(ve == quad_vidx[0], Coords(s, 1.0, 0.0), Coords(0.0, 1.0 - s, 0.0))
174
+ return wp.where(ve == quad_vidx[0], Coords(0.0, 1.0 - s, 0.0), Coords(s, 1.0, 0.0))
175
175
 
176
176
  @wp.func
177
177
  def _quad_to_edge_coords(
@@ -190,18 +190,18 @@ class Quadmesh(Geometry):
190
190
  cy = quad_coords[1]
191
191
 
192
192
  if vs == quad_vidx[0]:
193
- oc = wp.select(ve == quad_vidx[1], cx, cy)
194
- ec = wp.select(ve == quad_vidx[1], cy, cx)
193
+ oc = wp.where(ve == quad_vidx[1], cy, cx)
194
+ ec = wp.where(ve == quad_vidx[1], cx, cy)
195
195
  elif vs == quad_vidx[1]:
196
- oc = wp.select(ve == quad_vidx[2], cy, 1.0 - cx)
197
- ec = wp.select(ve == quad_vidx[2], 1.0 - cx, cy)
196
+ oc = wp.where(ve == quad_vidx[2], 1.0 - cx, cy)
197
+ ec = wp.where(ve == quad_vidx[2], cy, 1.0 - cx)
198
198
  elif vs == quad_vidx[2]:
199
- oc = wp.select(ve == quad_vidx[3], 1.0 - cx, 1.0 - cy)
200
- ec = wp.select(ve == quad_vidx[3], 1.0 - cy, 1.0 - cx)
199
+ oc = wp.where(ve == quad_vidx[3], 1.0 - cy, 1.0 - cx)
200
+ ec = wp.where(ve == quad_vidx[3], 1.0 - cx, 1.0 - cy)
201
201
  else:
202
- oc = wp.select(ve == quad_vidx[0], 1.0 - cy, cx)
203
- ec = wp.select(ve == quad_vidx[0], cx, 1.0 - cy)
204
- return wp.select(oc == 0.0, Coords(OUTSIDE), Coords(ec, 0.0, 0.0))
202
+ oc = wp.where(ve == quad_vidx[0], cx, 1.0 - cy)
203
+ ec = wp.where(ve == quad_vidx[0], 1.0 - cy, cx)
204
+ return wp.where(oc == 0.0, Coords(ec, 0.0, 0.0), Coords(OUTSIDE))
205
205
 
206
206
  @wp.func
207
207
  def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
@@ -56,7 +56,6 @@ class TetmeshSideArg:
56
56
  face_tet_indices: wp.array(dtype=wp.vec2i)
57
57
 
58
58
 
59
- _mat32 = wp.mat(shape=(3, 2), dtype=float)
60
59
  _NULL_BVH = wp.constant(wp.uint64(-1))
61
60
 
62
61
 
@@ -203,7 +202,7 @@ class Tetmesh(Geometry):
203
202
  p1 = args.positions[args.tet_vertex_indices[s.element_index, 1]]
204
203
  p2 = args.positions[args.tet_vertex_indices[s.element_index, 2]]
205
204
  p3 = args.positions[args.tet_vertex_indices[s.element_index, 3]]
206
- return wp.mat33(p1 - p0, p2 - p0, p3 - p0)
205
+ return wp.matrix_from_cols(p1 - p0, p2 - p0, p3 - p0)
207
206
 
208
207
  @wp.func
209
208
  def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
@@ -312,7 +311,7 @@ class Tetmesh(Geometry):
312
311
  @wp.func
313
312
  def side_deformation_gradient(args: SideArg, s: Sample):
314
313
  e1, e2 = Tetmesh._side_vecs(args, s.element_index)
315
- return _mat32(e1, e2)
314
+ return wp.matrix_from_cols(e1, e2)
316
315
 
317
316
  @wp.func
318
317
  def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
@@ -389,7 +388,7 @@ class Tetmesh(Geometry):
389
388
  else:
390
389
  c2 = 1.0 - tet_coords[0] - tet_coords[1] - tet_coords[2]
391
390
 
392
- return wp.select(c0 + c1 + c2 > 0.999, Coords(OUTSIDE), Coords(c0, c1, c2))
391
+ return wp.where(c0 + c1 + c2 > 0.999, Coords(c0, c1, c2), Coords(OUTSIDE))
393
392
 
394
393
  @wp.func
395
394
  def side_to_cell_arg(side_arg: SideArg):
@@ -325,9 +325,7 @@ class Trimesh(Geometry):
325
325
  elif edge_vidx[0] == v:
326
326
  start = k
327
327
 
328
- return wp.select(
329
- tri_coords[start] + tri_coords[end] > 0.999, Coords(OUTSIDE), Coords(tri_coords[end], 0.0, 0.0)
330
- )
328
+ return wp.where(tri_coords[start] + tri_coords[end] > 0.999, Coords(tri_coords[end], 0.0, 0.0), Coords(OUTSIDE))
331
329
 
332
330
  def _build_topology(self, temporary_store: TemporaryStore):
333
331
  from warp.fem.utils import compress_node_indices, host_read_at_index, masked_indices
@@ -576,7 +574,7 @@ class Trimesh2D(Trimesh):
576
574
  p0 = args.positions[tri_idx[0]]
577
575
  p1 = args.positions[tri_idx[1]]
578
576
  p2 = args.positions[tri_idx[2]]
579
- return wp.mat22(p1 - p0, p2 - p0)
577
+ return wp.matrix_from_cols(p1 - p0, p2 - p0)
580
578
 
581
579
  @wp.func
582
580
  def cell_lookup(args: CellArg, pos: wp.vec2):
@@ -689,9 +687,6 @@ class Trimesh3DSideArg:
689
687
  positions: wp.array(dtype=wp.vec3)
690
688
 
691
689
 
692
- _mat32 = wp.mat(shape=(3, 2), dtype=float)
693
-
694
-
695
690
  class Trimesh3D(Trimesh):
696
691
  """3D Triangular mesh geometry"""
697
692
 
@@ -714,7 +709,7 @@ class Trimesh3D(Trimesh):
714
709
  p0 = args.positions[tri_idx[0]]
715
710
  p1 = args.positions[tri_idx[1]]
716
711
  p2 = args.positions[tri_idx[2]]
717
- return _mat32(p1 - p0, p2 - p0)
712
+ return wp.matrix_from_cols(p1 - p0, p2 - p0)
718
713
 
719
714
  @wp.func
720
715
  def cell_lookup(args: CellArg, pos: wp.vec3):