warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (179) hide show
  1. warp/__init__.py +7 -1
  2. warp/bin/warp-clang.dll +0 -0
  3. warp/bin/warp.dll +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,7 @@ import warp as wp
19
19
  import warp.fem.cache as cache
20
20
  from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
21
21
  from warp.fem.types import NULL_NODE_INDEX
22
- from warp.fem.utils import _iota_kernel, compress_node_indices
22
+ from warp.fem.utils import compress_node_indices
23
23
 
24
24
  from .function_space import FunctionSpace
25
25
  from .topology import SpaceTopology
@@ -87,7 +87,7 @@ class WholeSpacePartition(SpacePartition):
87
87
  """Return the global function space indices for nodes in this partition"""
88
88
  if self._node_indices is None:
89
89
  self._node_indices = cache.borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
90
- wp.launch(kernel=_iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array, 1])
90
+ wp.launch(kernel=self._iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array])
91
91
  return self._node_indices.array
92
92
 
93
93
  def partition_arg_value(self, device):
@@ -104,6 +104,10 @@ class WholeSpacePartition(SpacePartition):
104
104
  def name(self) -> str:
105
105
  return "Whole"
106
106
 
107
+ @wp.kernel
108
+ def _iota_kernel(indices: wp.array(dtype=int)):
109
+ indices[wp.tid()] = wp.tid()
110
+
107
111
 
108
112
  class NodeCategory:
109
113
  OWNED_INTERIOR = wp.constant(0)
@@ -166,10 +166,10 @@ class QuadmeshSpaceTopology(SpaceTopology):
166
166
 
167
167
  if wp.static(EDGE_NODE_COUNT > 0):
168
168
  # EDGE_X, EDGE_Y
169
- side_start = wp.select(
169
+ side_start = wp.where(
170
170
  node_type == SquareShapeFunction.EDGE_X,
171
- wp.select(type_instance == 0, 1, 3),
172
- wp.select(type_instance == 0, 2, 0),
171
+ wp.where(type_instance == 0, 0, 2),
172
+ wp.where(type_instance == 0, 3, 1),
173
173
  )
174
174
 
175
175
  side_index = topo_arg.quad_edge_indices[element_index, side_start]
@@ -178,7 +178,7 @@ class QuadmeshSpaceTopology(SpaceTopology):
178
178
 
179
179
  # Flip indexing direction
180
180
  flipped = int(side_start >= 2) ^ int(local_vs != global_vs)
181
- index_in_side = wp.select(flipped, type_index, EDGE_NODE_COUNT - 1 - type_index)
181
+ index_in_side = wp.where(flipped, EDGE_NODE_COUNT - 1 - type_index, type_index)
182
182
 
183
183
  return global_offset + EDGE_NODE_COUNT * side_index + index_in_side
184
184
 
@@ -197,10 +197,10 @@ class QuadmeshSpaceTopology(SpaceTopology):
197
197
  node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
198
198
 
199
199
  if node_type == SquareShapeFunction.EDGE_X or node_type == SquareShapeFunction.EDGE_Y:
200
- side_start = wp.select(
200
+ side_start = wp.where(
201
201
  node_type == SquareShapeFunction.EDGE_X,
202
- wp.select(type_instance == 0, 1, 3),
203
- wp.select(type_instance == 0, 2, 0),
202
+ wp.where(type_instance == 0, 0, 2),
203
+ wp.where(type_instance == 0, 3, 1),
204
204
  )
205
205
 
206
206
  side_index = topo_arg.quad_edge_indices[element_index, side_start]
@@ -209,7 +209,7 @@ class QuadmeshSpaceTopology(SpaceTopology):
209
209
 
210
210
  # Flip indexing direction
211
211
  flipped = int(side_start >= 2) ^ int(local_vs != global_vs)
212
- return wp.select(flipped, 1.0, -1.0)
212
+ return wp.where(flipped, -1.0, 1.0)
213
213
 
214
214
  return 1.0
215
215
 
@@ -63,7 +63,7 @@ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
63
63
 
64
64
  self.ORDER = wp.constant(degree)
65
65
  self.NODES_PER_ELEMENT = wp.constant((degree + 1) ** 3)
66
- self.NODES_PER_EDGE = wp.constant(degree + 1)
66
+ self.NODES_PER_SIDE = wp.constant((degree + 1) ** 2)
67
67
 
68
68
  if is_closed(self.family):
69
69
  self.VERTEX_NODE_COUNT = wp.constant(1)
@@ -152,13 +152,13 @@ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
152
152
  ):
153
153
  i, j, k = self._node_ijk(node_index_in_elt)
154
154
 
155
- zi = wp.select(i == 0, 0, 1)
156
- zj = wp.select(j == 0, 0, 1)
157
- zk = wp.select(k == 0, 0, 1)
155
+ zi = wp.where(i == 0, 1, 0)
156
+ zj = wp.where(j == 0, 1, 0)
157
+ zk = wp.where(k == 0, 1, 0)
158
158
 
159
- mi = wp.select(i == ORDER, 0, 1)
160
- mj = wp.select(j == ORDER, 0, 1)
161
- mk = wp.select(k == ORDER, 0, 1)
159
+ mi = wp.where(i == ORDER, 1, 0)
160
+ mj = wp.where(j == ORDER, 1, 0)
161
+ mk = wp.where(k == ORDER, 1, 0)
162
162
 
163
163
  if zi + mi == 1:
164
164
  if zj + mj == 1:
@@ -504,7 +504,7 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
504
504
 
505
505
  self.ORDER = wp.constant(degree)
506
506
  self.NODES_PER_ELEMENT = wp.constant(8 + 12 * (degree - 1))
507
- self.NODES_PER_EDGE = wp.constant(degree + 1)
507
+ self.NODES_PER_SIDE = wp.constant(4 * degree)
508
508
 
509
509
  self.VERTEX_NODE_COUNT = wp.constant(1)
510
510
  self.EDGE_NODE_COUNT = wp.constant(degree - 1)
@@ -634,9 +634,9 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
634
634
  if node_type == CubeSerendipityShapeFunctions.VERTEX:
635
635
  node_ijk = CubeSerendipityShapeFunctions._vertex_coords(type_instance)
636
636
 
637
- cx = wp.select(node_ijk[0] == 0, coords[0], 1.0 - coords[0])
638
- cy = wp.select(node_ijk[1] == 0, coords[1], 1.0 - coords[1])
639
- cz = wp.select(node_ijk[2] == 0, coords[2], 1.0 - coords[2])
637
+ cx = wp.where(node_ijk[0] == 0, 1.0 - coords[0], coords[0])
638
+ cy = wp.where(node_ijk[1] == 0, 1.0 - coords[1], coords[1])
639
+ cz = wp.where(node_ijk[2] == 0, 1.0 - coords[2], coords[2])
640
640
 
641
641
  w = cx * cy * cz
642
642
 
@@ -659,8 +659,8 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
659
659
  local_coords = Grid3D._world_to_local(axis, coords)
660
660
 
661
661
  w = float(1.0)
662
- w *= wp.select(node_all[1] == 0, local_coords[1], 1.0 - local_coords[1])
663
- w *= wp.select(node_all[2] == 0, local_coords[2], 1.0 - local_coords[2])
662
+ w *= wp.where(node_all[1] == 0, 1.0 - local_coords[1], local_coords[1])
663
+ w *= wp.where(node_all[2] == 0, 1.0 - local_coords[2], local_coords[2])
664
664
 
665
665
  for k in range(ORDER_PLUS_ONE):
666
666
  if k != node_all[0]:
@@ -690,13 +690,13 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
690
690
  if node_type == CubeSerendipityShapeFunctions.VERTEX:
691
691
  node_ijk = CubeSerendipityShapeFunctions._vertex_coords(type_instance)
692
692
 
693
- cx = wp.select(node_ijk[0] == 0, coords[0], 1.0 - coords[0])
694
- cy = wp.select(node_ijk[1] == 0, coords[1], 1.0 - coords[1])
695
- cz = wp.select(node_ijk[2] == 0, coords[2], 1.0 - coords[2])
693
+ cx = wp.where(node_ijk[0] == 0, 1.0 - coords[0], coords[0])
694
+ cy = wp.where(node_ijk[1] == 0, 1.0 - coords[1], coords[1])
695
+ cz = wp.where(node_ijk[2] == 0, 1.0 - coords[2], coords[2])
696
696
 
697
- gx = wp.select(node_ijk[0] == 0, 1.0, -1.0)
698
- gy = wp.select(node_ijk[1] == 0, 1.0, -1.0)
699
- gz = wp.select(node_ijk[2] == 0, 1.0, -1.0)
697
+ gx = wp.where(node_ijk[0] == 0, -1.0, 1.0)
698
+ gy = wp.where(node_ijk[1] == 0, -1.0, 1.0)
699
+ gz = wp.where(node_ijk[2] == 0, -1.0, 1.0)
700
700
 
701
701
  if wp.static(ORDER == 2):
702
702
  w = cx + cy + cz - 3.0 + LOBATTO_COORDS[1]
@@ -728,11 +728,11 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
728
728
 
729
729
  local_coords = Grid3D._world_to_local(axis, coords)
730
730
 
731
- w_long = wp.select(node_all[1] == 0, local_coords[1], 1.0 - local_coords[1])
732
- w_lat = wp.select(node_all[2] == 0, local_coords[2], 1.0 - local_coords[2])
731
+ w_long = wp.where(node_all[1] == 0, 1.0 - local_coords[1], local_coords[1])
732
+ w_lat = wp.where(node_all[2] == 0, 1.0 - local_coords[2], local_coords[2])
733
733
 
734
- g_long = wp.select(node_all[1] == 0, 1.0, -1.0)
735
- g_lat = wp.select(node_all[2] == 0, 1.0, -1.0)
734
+ g_long = wp.where(node_all[1] == 0, -1.0, 1.0)
735
+ g_lat = wp.where(node_all[2] == 0, -1.0, 1.0)
736
736
 
737
737
  w_alt = LAGRANGE_SCALE[node_all[0]]
738
738
  g_alt = float(0.0)
@@ -461,8 +461,8 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
461
461
  node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
462
462
 
463
463
  if node_type == SquareSerendipityShapeFunctions.VERTEX:
464
- cx = wp.select(node_i == 0, coords[0], 1.0 - coords[0])
465
- cy = wp.select(node_j == 0, coords[1], 1.0 - coords[1])
464
+ cx = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
465
+ cy = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
466
466
 
467
467
  w = cx * cy
468
468
 
@@ -475,7 +475,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
475
475
 
476
476
  w = float(1.0)
477
477
  if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
478
- w *= wp.select(node_i == 0, coords[0], 1.0 - coords[0])
478
+ w *= wp.where(node_i == 0, 1.0 - coords[0], coords[0])
479
479
  else:
480
480
  for k in range(ORDER_PLUS_ONE):
481
481
  if k != node_i:
@@ -484,7 +484,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
484
484
  w *= LAGRANGE_SCALE[node_i]
485
485
 
486
486
  if node_type == SquareSerendipityShapeFunctions.EDGE_X:
487
- w *= wp.select(node_j == 0, coords[1], 1.0 - coords[1])
487
+ w *= wp.where(node_j == 0, 1.0 - coords[1], coords[1])
488
488
  else:
489
489
  for k in range(ORDER_PLUS_ONE):
490
490
  if k != node_j:
@@ -513,11 +513,11 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
513
513
  node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
514
514
 
515
515
  if node_type == SquareSerendipityShapeFunctions.VERTEX:
516
- cx = wp.select(node_i == 0, coords[0], 1.0 - coords[0])
517
- cy = wp.select(node_j == 0, coords[1], 1.0 - coords[1])
516
+ cx = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
517
+ cy = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
518
518
 
519
- gx = wp.select(node_i == 0, 1.0, -1.0)
520
- gy = wp.select(node_j == 0, 1.0, -1.0)
519
+ gx = wp.where(node_i == 0, -1.0, 1.0)
520
+ gy = wp.where(node_j == 0, -1.0, 1.0)
521
521
 
522
522
  if ORDER == 2:
523
523
  w = cx + cy - 2.0 + LOBATTO_COORDS[1]
@@ -537,7 +537,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
537
537
  return wp.vec2(grad_x, grad_y) * DEGREE_3_CIRCLE_SCALE
538
538
 
539
539
  if node_type == SquareSerendipityShapeFunctions.EDGE_X:
540
- prefix_x = wp.select(node_j == 0, coords[1], 1.0 - coords[1])
540
+ prefix_x = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
541
541
  else:
542
542
  prefix_x = LAGRANGE_SCALE[node_j]
543
543
  for k in range(ORDER_PLUS_ONE):
@@ -545,7 +545,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
545
545
  prefix_x *= coords[1] - LOBATTO_COORDS[k]
546
546
 
547
547
  if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
548
- prefix_y = wp.select(node_i == 0, coords[0], 1.0 - coords[0])
548
+ prefix_y = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
549
549
  else:
550
550
  prefix_y = LAGRANGE_SCALE[node_i]
551
551
  for k in range(ORDER_PLUS_ONE):
@@ -553,7 +553,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
553
553
  prefix_y *= coords[0] - LOBATTO_COORDS[k]
554
554
 
555
555
  if node_type == SquareSerendipityShapeFunctions.EDGE_X:
556
- grad_y = wp.select(node_j == 0, 1.0, -1.0) * prefix_y
556
+ grad_y = wp.where(node_j == 0, -1.0, 1.0) * prefix_y
557
557
  else:
558
558
  prefix_y *= LAGRANGE_SCALE[node_j]
559
559
  grad_y = float(0.0)
@@ -564,7 +564,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
564
564
  prefix_y *= delta_y
565
565
 
566
566
  if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
567
- grad_x = wp.select(node_i == 0, 1.0, -1.0) * prefix_x
567
+ grad_x = wp.where(node_i == 0, -1.0, 1.0) * prefix_x
568
568
  else:
569
569
  prefix_x *= LAGRANGE_SCALE[node_i]
570
570
  grad_x = float(0.0)
@@ -196,7 +196,7 @@ class TrianglePolynomialShapeFunctions(TriangleShapeFunction):
196
196
  def trace_node_quadrature_weight(node_index_in_element: int):
197
197
  node_type, type_index = self.node_type_and_type_index(node_index_in_element)
198
198
 
199
- return wp.select(node_type == TrianglePolynomialShapeFunctions.VERTEX, EDGE_WEIGHT, VERTEX_WEIGHT)
199
+ return wp.where(node_type == TrianglePolynomialShapeFunctions.VERTEX, VERTEX_WEIGHT, EDGE_WEIGHT)
200
200
 
201
201
  return trace_node_quadrature_weight
202
202
 
@@ -244,10 +244,10 @@ class TetmeshSpaceTopology(SpaceTopology):
244
244
  edge = type_index // INTERIOR_NODES_PER_EDGE
245
245
  c1, c2 = TetrahedronShapeFunction.edge_vidx(edge)
246
246
 
247
- return wp.select(
247
+ return wp.where(
248
248
  geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2],
249
- 1.0,
250
249
  -1.0,
250
+ 1.0,
251
251
  )
252
252
 
253
253
  if wp.static(INTERIOR_NODES_PER_FACE > 0):
@@ -257,7 +257,7 @@ class TetmeshSpaceTopology(SpaceTopology):
257
257
  global_face_index = topo_arg.tet_face_indices[element_index][face]
258
258
  inner = topo_arg.face_tet_indices[global_face_index][0]
259
259
 
260
- return wp.select(inner == element_index, -1.0, 1.0)
260
+ return wp.where(inner == element_index, 1.0, -1.0)
261
261
 
262
262
  return 1.0
263
263
 
@@ -175,11 +175,11 @@ class TrimeshSpaceTopology(SpaceTopology):
175
175
  edge = type_index // INTERIOR_NODES_PER_SIDE
176
176
 
177
177
  global_edge_index = topo_arg.tri_edge_indices[element_index][edge]
178
- return wp.select(
178
+ return wp.where(
179
179
  topo_arg.edge_vertex_indices[global_edge_index][0]
180
180
  == geo_arg.topology.tri_vertex_indices[element_index][edge],
181
- -1.0,
182
181
  1.0,
182
+ -1.0,
183
183
  )
184
184
 
185
185
  return 1.0
warp/fem/utils.py CHANGED
@@ -56,13 +56,11 @@ def compress_node_indices(
56
56
  sorted_node_indices = sorted_node_indices_temp.array
57
57
  sorted_array_indices = sorted_array_indices_temp.array
58
58
 
59
- wp.copy(dest=sorted_node_indices, src=node_indices, count=index_count)
60
-
61
59
  indices_per_element = 1 if node_indices.ndim == 1 else node_indices.shape[-1]
62
60
  wp.launch(
63
- kernel=_iota_kernel,
61
+ kernel=_prepare_node_sort_kernel,
64
62
  dim=index_count,
65
- inputs=[sorted_array_indices, indices_per_element],
63
+ inputs=[node_indices.flatten(), sorted_node_indices, sorted_array_indices, indices_per_element],
66
64
  )
67
65
 
68
66
  # Sort indices
@@ -169,8 +167,16 @@ def masked_indices(
169
167
 
170
168
 
171
169
  @wp.kernel
172
- def _iota_kernel(indices: wp.array(dtype=int), divisor: int):
173
- indices[wp.tid()] = wp.tid() // divisor
170
+ def _prepare_node_sort_kernel(
171
+ node_indices: wp.array(dtype=int),
172
+ sort_keys: wp.array(dtype=int),
173
+ sort_values: wp.array(dtype=int),
174
+ divisor: int,
175
+ ):
176
+ i = wp.tid()
177
+ node = node_indices[i]
178
+ sort_keys[i] = wp.where(node >= 0, node, NULL_NODE_INDEX)
179
+ sort_values[i] = i // divisor
174
180
 
175
181
 
176
182
  @wp.kernel
warp/jax.py CHANGED
@@ -58,6 +58,19 @@ def device_from_jax(jax_device) -> warp.context.Device:
58
58
  raise RuntimeError(f"Unsupported Jax device platform '{jax_device.platform}'")
59
59
 
60
60
 
61
+ def get_jax_device():
62
+ """Get the current Jax device."""
63
+ import jax
64
+
65
+ # TODO: is there a simpler way of getting the Jax "current" device?
66
+ # check if jax.default_device() context manager is active
67
+ device = jax.config.jax_default_device
68
+ # if default device is not set, use first device
69
+ if device is None:
70
+ device = jax.local_devices()[0]
71
+ return device
72
+
73
+
61
74
  def dtype_to_jax(warp_dtype):
62
75
  """Return the Jax dtype corresponding to a Warp dtype.
63
76
 
@@ -156,7 +169,7 @@ def to_jax(warp_array):
156
169
  """
157
170
  import jax.dlpack
158
171
 
159
- return jax.dlpack.from_dlpack(warp.to_dlpack(warp_array))
172
+ return jax.dlpack.from_dlpack(warp_array)
160
173
 
161
174
 
162
175
  def from_jax(jax_array, dtype=None) -> warp.array:
@@ -0,0 +1,16 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .custom_call import jax_kernel
@@ -15,10 +15,9 @@
15
15
 
16
16
  import ctypes
17
17
 
18
- import jax
19
-
20
18
  import warp as wp
21
19
  from warp.context import type_str
20
+ from warp.jax import get_jax_device
22
21
  from warp.types import array_t, launch_bounds_t, strides_from_shape
23
22
 
24
23
  _jax_warp_p = None
@@ -29,35 +28,33 @@ _registered_kernels = [None]
29
28
  _registered_kernel_to_id = {}
30
29
 
31
30
 
32
- def jax_kernel(wp_kernel, launch_dims=None):
31
+ def jax_kernel(kernel, launch_dims=None):
33
32
  """Create a Jax primitive from a Warp kernel.
34
33
 
35
34
  NOTE: This is an experimental feature under development.
36
35
 
37
36
  Args:
38
- wp_kernel: The Warp kernel to be wrapped.
37
+ kernel: The Warp kernel to be wrapped.
39
38
  launch_dims: Optional. Specify the kernel launch dimensions. If None,
40
39
  dimensions are inferred from the shape of the first argument.
41
40
  This option when set will specify the output dimensions.
42
41
 
43
- Current limitations:
44
- - All kernel arguments must be arrays.
45
- - If launch_dims is not provided, kernel launch dimensions are inferred from the shape of the first argument.
46
- - Input arguments are followed by output arguments in the Warp kernel definition.
47
- - There must be at least one input argument and at least one output argument.
48
- - All arrays must be contiguous.
49
- - Only the CUDA backend is supported.
42
+ Limitations:
43
+ - All kernel arguments must be contiguous arrays.
44
+ - Input arguments are followed by output arguments in the Warp kernel definition.
45
+ - There must be at least one input argument and at least one output argument.
46
+ - Only the CUDA backend is supported.
50
47
  """
51
48
 
52
49
  if _jax_warp_p is None:
53
50
  # Create and register the primitive
54
51
  _create_jax_warp_primitive()
55
- if wp_kernel not in _registered_kernel_to_id:
52
+ if kernel not in _registered_kernel_to_id:
56
53
  id = len(_registered_kernels)
57
- _registered_kernels.append(wp_kernel)
58
- _registered_kernel_to_id[wp_kernel] = id
54
+ _registered_kernels.append(kernel)
55
+ _registered_kernel_to_id[kernel] = id
59
56
  else:
60
- id = _registered_kernel_to_id[wp_kernel]
57
+ id = _registered_kernel_to_id[kernel]
61
58
 
62
59
  def bind(*args):
63
60
  return _jax_warp_p.bind(*args, kernel=id, launch_dims=launch_dims)
@@ -102,7 +99,7 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
102
99
  kernel_params[i + 1] = arg_ptr
103
100
 
104
101
  # Get current device.
105
- device = wp.device_from_jax(_get_jax_device())
102
+ device = wp.device_from_jax(get_jax_device())
106
103
 
107
104
  # Get kernel hooks.
108
105
  # Note: module was loaded during jit lowering.
@@ -115,16 +112,6 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
115
112
  )
116
113
 
117
114
 
118
- # TODO: is there a simpler way of getting the Jax "current" device?
119
- def _get_jax_device():
120
- # check if jax.default_device() context manager is active
121
- device = jax.config.jax_default_device
122
- # if default device is not set, use first device
123
- if device is None:
124
- device = jax.local_devices()[0]
125
- return device
126
-
127
-
128
115
  def _create_jax_warp_primitive():
129
116
  from functools import reduce
130
117
 
@@ -288,7 +275,7 @@ def _create_jax_warp_primitive():
288
275
  # TODO This may not be necessary, but it is perhaps better not to be
289
276
  # mucking with kernel loading while already running the workload.
290
277
  module = wp_kernel.module
291
- device = wp.device_from_jax(_get_jax_device())
278
+ device = wp.device_from_jax(get_jax_device())
292
279
  if not module.load(device):
293
280
  raise Exception("Could not load kernel on device")
294
281