warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -83,22 +83,15 @@ class NodalFieldBase(DiscreteField):
83
83
  @cache.dynamic_func(suffix=self.name)
84
84
  def eval_inner(args: self.ElementEvalArg, s: Sample):
85
85
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
86
- res = self.space.space_value(
87
- self._read_node_value(args, s.element_index, 0),
88
- self.space.element_inner_weight(
89
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
90
- ),
91
- local_value_map,
92
- )
93
-
94
86
  node_count = self.space.topology.element_node_count(
95
87
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
96
88
  )
97
- for k in range(1, node_count):
89
+ res = self.space.dtype(0.0)
90
+ for k in range(node_count):
98
91
  res += self.space.space_value(
99
92
  self._read_node_value(args, s.element_index, k),
100
93
  self.space.element_inner_weight(
101
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
94
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
102
95
  ),
103
96
  local_value_map,
104
97
  )
@@ -113,24 +106,16 @@ class NodalFieldBase(DiscreteField):
113
106
  @cache.dynamic_func(suffix=self.name)
114
107
  def eval_grad_inner(args: self.ElementEvalArg, s: Sample, grad_transform: Any):
115
108
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
116
-
117
- res = self.space.space_gradient(
118
- self._read_node_value(args, s.element_index, 0),
119
- self.space.element_inner_weight_gradient(
120
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
121
- ),
122
- local_value_map,
123
- grad_transform,
124
- )
125
-
126
109
  node_count = self.space.topology.element_node_count(
127
110
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
128
111
  )
129
- for k in range(1, node_count):
112
+
113
+ res = self.gradient_dtype(0.0)
114
+ for k in range(node_count):
130
115
  res += self.space.space_gradient(
131
116
  self._read_node_value(args, s.element_index, k),
132
117
  self.space.element_inner_weight_gradient(
133
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
118
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
134
119
  ),
135
120
  local_value_map,
136
121
  grad_transform,
@@ -157,24 +142,16 @@ class NodalFieldBase(DiscreteField):
157
142
  def eval_div_inner(args: self.ElementEvalArg, s: Sample):
158
143
  grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
159
144
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
160
-
161
- res = self.space.space_divergence(
162
- self._read_node_value(args, s.element_index, 0),
163
- self.space.element_inner_weight_gradient(
164
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
165
- ),
166
- local_value_map,
167
- grad_transform,
168
- )
169
-
170
145
  node_count = self.space.topology.element_node_count(
171
146
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
172
147
  )
173
- for k in range(1, node_count):
148
+
149
+ res = self.divergence_dtype(0.0)
150
+ for k in range(node_count):
174
151
  res += self.space.space_divergence(
175
152
  self._read_node_value(args, s.element_index, k),
176
153
  self.space.element_inner_weight_gradient(
177
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
154
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
178
155
  ),
179
156
  local_value_map,
180
157
  grad_transform,
@@ -187,23 +164,16 @@ class NodalFieldBase(DiscreteField):
187
164
  @cache.dynamic_func(suffix=self.name)
188
165
  def eval_outer(args: self.ElementEvalArg, s: Sample):
189
166
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
190
- res = self.space.space_value(
191
- self._read_node_value(args, s.element_index, 0),
192
- self.space.element_outer_weight(
193
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
194
- ),
195
- local_value_map,
196
- )
197
-
198
167
  node_count = self.space.topology.element_node_count(
199
168
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
200
169
  )
201
170
 
202
- for k in range(1, node_count):
171
+ res = self.dtype(0.0)
172
+ for k in range(node_count):
203
173
  res += self.space.space_value(
204
174
  self._read_node_value(args, s.element_index, k),
205
175
  self.space.element_outer_weight(
206
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
176
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
207
177
  ),
208
178
  local_value_map,
209
179
  )
@@ -218,24 +188,16 @@ class NodalFieldBase(DiscreteField):
218
188
  @cache.dynamic_func(suffix=self.name)
219
189
  def eval_grad_outer(args: self.ElementEvalArg, s: Sample, grad_transform: Any):
220
190
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
221
-
222
- res = self.space.space_gradient(
223
- self._read_node_value(args, s.element_index, 0),
224
- self.space.element_outer_weight_gradient(
225
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
226
- ),
227
- local_value_map,
228
- grad_transform,
229
- )
230
-
231
191
  node_count = self.space.topology.element_node_count(
232
192
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
233
193
  )
234
- for k in range(1, node_count):
194
+
195
+ res = self.gradient_dtype(0.0)
196
+ for k in range(node_count):
235
197
  res += self.space.space_gradient(
236
198
  self._read_node_value(args, s.element_index, k),
237
199
  self.space.element_outer_weight_gradient(
238
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
200
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
239
201
  ),
240
202
  local_value_map,
241
203
  grad_transform,
@@ -262,24 +224,16 @@ class NodalFieldBase(DiscreteField):
262
224
  def eval_div_outer(args: self.ElementEvalArg, s: Sample):
263
225
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
264
226
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
265
-
266
- res = self.space.space_divergence(
267
- self._read_node_value(args, s.element_index, 0),
268
- self.space.element_outer_weight_gradient(
269
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
270
- ),
271
- local_value_map,
272
- grad_transform,
273
- )
274
-
275
227
  node_count = self.space.topology.element_node_count(
276
228
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
277
229
  )
278
- for k in range(1, node_count):
230
+
231
+ res = self.divergence_dtype(0.0)
232
+ for k in range(node_count):
279
233
  res += self.space.space_divergence(
280
234
  self._read_node_value(args, s.element_index, k),
281
235
  self.space.element_outer_weight_gradient(
282
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
236
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
283
237
  ),
284
238
  local_value_map,
285
239
  grad_transform,
warp/fem/field/virtual.py CHANGED
@@ -69,7 +69,12 @@ class AdjointField(SpaceField):
69
69
  def eval_test_inner(args: self.ElementEvalArg, s: Sample):
70
70
  dof = self._get_dof(s)
71
71
  node_weight = self.space.element_inner_weight(
72
- args.elt_arg, args.eval_arg, s.element_index, s.element_coords, get_node_index_in_element(dof)
72
+ args.elt_arg,
73
+ args.eval_arg,
74
+ s.element_index,
75
+ s.element_coords,
76
+ get_node_index_in_element(dof),
77
+ s.qp_index,
73
78
  )
74
79
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
75
80
  dof_value = self.space.node_basis_element(get_node_coord(dof))
@@ -90,6 +95,7 @@ class AdjointField(SpaceField):
90
95
  s.element_index,
91
96
  s.element_coords,
92
97
  get_node_index_in_element(dof),
98
+ s.qp_index,
93
99
  )
94
100
  grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
95
101
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
@@ -111,6 +117,7 @@ class AdjointField(SpaceField):
111
117
  s.element_index,
112
118
  s.element_coords,
113
119
  get_node_index_in_element(dof),
120
+ s.qp_index,
114
121
  )
115
122
  grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
116
123
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
@@ -124,7 +131,12 @@ class AdjointField(SpaceField):
124
131
  def eval_test_outer(args: self.ElementEvalArg, s: Sample):
125
132
  dof = self._get_dof(s)
126
133
  node_weight = self.space.element_outer_weight(
127
- args.elt_arg, args.eval_arg, s.element_index, s.element_coords, get_node_index_in_element(dof)
134
+ args.elt_arg,
135
+ args.eval_arg,
136
+ s.element_index,
137
+ s.element_coords,
138
+ get_node_index_in_element(dof),
139
+ s.qp_index,
128
140
  )
129
141
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
130
142
  dof_value = self.space.node_basis_element(get_node_coord(dof))
@@ -145,6 +157,7 @@ class AdjointField(SpaceField):
145
157
  s.element_index,
146
158
  s.element_coords,
147
159
  get_node_index_in_element(dof),
160
+ s.qp_index,
148
161
  )
149
162
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
150
163
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
@@ -166,6 +179,7 @@ class AdjointField(SpaceField):
166
179
  s.element_index,
167
180
  s.element_coords,
168
181
  get_node_index_in_element(dof),
182
+ s.qp_index,
169
183
  )
170
184
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
171
185
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
@@ -370,9 +384,8 @@ class LocalAdjointField(SpaceField):
370
384
 
371
385
  @cache.dynamic_func(suffix=str(TAYLOR_DOF_COUNT))
372
386
  def split_dof(dof_index: DofIndex, dof_begin: int):
373
- dof = get_node_coord(dof_index)
374
- value_dof = dof // TAYLOR_DOF_COUNT
375
- taylor_dof = dof - value_dof * TAYLOR_DOF_COUNT - dof_begin
387
+ taylor_dof = get_node_index_in_element(dof_index) - dof_begin
388
+ value_dof = get_node_coord(dof_index)
376
389
  return value_dof, taylor_dof
377
390
 
378
391
  return split_dof
@@ -386,7 +399,7 @@ class LocalAdjointField(SpaceField):
386
399
 
387
400
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
388
401
  dof_value = self.space.value_basis_element(value_dof, local_value_map)
389
- return wp.select(taylor_dof == 0, self.dtype(0.0), dof_value)
402
+ return wp.where(taylor_dof == 0, dof_value, self.dtype(0.0))
390
403
 
391
404
  return eval_test_inner
392
405
 
@@ -441,7 +454,7 @@ class LocalAdjointField(SpaceField):
441
454
 
442
455
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
443
456
  dof_value = self.space.value_basis_element(value_dof, local_value_map)
444
- return wp.select(taylor_dof == 0, self.dtype(0.0), dof_value)
457
+ return wp.where(taylor_dof == 0, dof_value, self.dtype(0.0))
445
458
 
446
459
  return eval_test_outer
447
460
 
@@ -551,17 +564,25 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
551
564
  qp_index = quadrature.point_index(
552
565
  domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
553
566
  )
567
+ qp_eval_index = quadrature.point_evaluation_index(
568
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
569
+ )
554
570
  coords = quadrature.point_coords(
555
571
  domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
556
572
  )
557
573
 
558
- qp_result = local_result[qp_index]
574
+ qp_result = local_result[qp_eval_index]
559
575
 
560
576
  qp_sum = float(0.0)
561
577
 
562
578
  if wp.static(0 != TEST_INNER_COUNT):
563
579
  w = test.space.element_inner_weight(
564
- domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
580
+ domain_arg,
581
+ test_space_arg,
582
+ element_index,
583
+ coords,
584
+ test_element_index.node_index_in_element,
585
+ qp_index,
565
586
  )
566
587
  for val_dof in range(TEST_NODE_DOF_DIM):
567
588
  test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
@@ -569,7 +590,12 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
569
590
 
570
591
  if wp.static(0 != TEST_OUTER_COUNT):
571
592
  w = test.space.element_outer_weight(
572
- domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
593
+ domain_arg,
594
+ test_space_arg,
595
+ element_index,
596
+ coords,
597
+ test_element_index.node_index_in_element,
598
+ qp_index,
573
599
  )
574
600
  for val_dof in range(TEST_NODE_DOF_DIM):
575
601
  test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
@@ -577,7 +603,12 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
577
603
 
578
604
  if wp.static(0 != TEST_INNER_GRAD_COUNT):
579
605
  w_grad = test.space.element_inner_weight_gradient(
580
- domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
606
+ domain_arg,
607
+ test_space_arg,
608
+ element_index,
609
+ coords,
610
+ test_element_index.node_index_in_element,
611
+ qp_index,
581
612
  )
582
613
  for val_dof in range(TEST_NODE_DOF_DIM):
583
614
  test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
@@ -589,7 +620,12 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
589
620
 
590
621
  if wp.static(0 != TEST_OUTER_GRAD_COUNT):
591
622
  w_grad = test.space.element_outer_weight_gradient(
592
- domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
623
+ domain_arg,
624
+ test_space_arg,
625
+ element_index,
626
+ coords,
627
+ test_element_index.node_index_in_element,
628
+ qp_index,
593
629
  )
594
630
  for val_dof in range(TEST_NODE_DOF_DIM):
595
631
  test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
@@ -669,10 +705,10 @@ def make_bilinear_dispatch_kernel(
669
705
  domain_arg, trial_topology_arg, element_index
670
706
  )
671
707
 
672
- qp_point_count = wp.select(
708
+ qp_point_count = wp.where(
673
709
  trial_node < element_trial_node_count,
674
- 0,
675
710
  quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
711
+ 0,
676
712
  )
677
713
 
678
714
  val_sum = accumulate_dtype(0.0)
@@ -681,51 +717,54 @@ def make_bilinear_dispatch_kernel(
681
717
  qp_index = quadrature.point_index(
682
718
  domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
683
719
  )
720
+ qp_eval_index = quadrature.point_evaluation_index(
721
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
722
+ )
684
723
  coords = quadrature.point_coords(
685
724
  domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
686
725
  )
687
726
 
688
- qp_result = local_result[qp_index]
727
+ qp_result = local_result[qp_eval_index]
689
728
  trial_result = float(0.0)
690
729
 
691
730
  if wp.static(0 != TEST_INNER_COUNT):
692
731
  w_test_inner = test.space.element_inner_weight(
693
- domain_arg, test_space_arg, element_index, coords, test_node
732
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
694
733
  )
695
734
 
696
735
  if wp.static(0 != TEST_OUTER_COUNT):
697
736
  w_test_outer = test.space.element_outer_weight(
698
- domain_arg, test_space_arg, element_index, coords, test_node
737
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
699
738
  )
700
739
 
701
740
  if wp.static(0 != TEST_INNER_GRAD_COUNT):
702
741
  w_test_grad_inner = test.space.element_inner_weight_gradient(
703
- domain_arg, test_space_arg, element_index, coords, test_node
742
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
704
743
  )
705
744
 
706
745
  if wp.static(0 != TEST_OUTER_GRAD_COUNT):
707
746
  w_test_grad_outer = test.space.element_outer_weight_gradient(
708
- domain_arg, test_space_arg, element_index, coords, test_node
747
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
709
748
  )
710
749
 
711
750
  if wp.static(0 != TRIAL_INNER_COUNT):
712
751
  w_trial_inner = trial.space.element_inner_weight(
713
- domain_arg, trial_space_arg, element_index, coords, trial_node
752
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
714
753
  )
715
754
 
716
755
  if wp.static(0 != TRIAL_OUTER_COUNT):
717
756
  w_trial_outer = trial.space.element_outer_weight(
718
- domain_arg, trial_space_arg, element_index, coords, trial_node
757
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
719
758
  )
720
759
 
721
760
  if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
722
761
  w_trial_grad_inner = trial.space.element_inner_weight_gradient(
723
- domain_arg, trial_space_arg, element_index, coords, trial_node
762
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
724
763
  )
725
764
 
726
765
  if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
727
766
  w_trial_grad_outer = trial.space.element_outer_weight_gradient(
728
- domain_arg, trial_space_arg, element_index, coords, trial_node
767
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
729
768
  )
730
769
 
731
770
  for trial_val_dof in range(TRIAL_NODE_DOF_DIM):
@@ -30,7 +30,6 @@ from .nanogrid import (
30
30
  _extract_axis_flag,
31
31
  _get_boundary_mask,
32
32
  _make_face_flags,
33
- _mat32,
34
33
  )
35
34
 
36
35
  _FACE_LEVEL_BIT = wp.constant(wp.uint8(4)) # follows nanogrid.FACE_OUTER_OFFSET_BIT
@@ -214,9 +213,9 @@ class AdaptiveNanogrid(Geometry):
214
213
  coords = uvw - wp.vec3(ijk)
215
214
 
216
215
  if wp.min(coords) == 0.0 or wp.max(coords) == 1.0:
217
- il = wp.select(coords[0] > 0.5, -1, 0)
218
- jl = wp.select(coords[1] > 0.5, -1, 0)
219
- kl = wp.select(coords[2] > 0.5, -1, 0)
216
+ il = wp.where(coords[0] > 0.5, 0, -1)
217
+ jl = wp.where(coords[1] > 0.5, 0, -1)
218
+ kl = wp.where(coords[2] > 0.5, 0, -1)
220
219
 
221
220
  for n in range(8):
222
221
  ni = n >> 2
@@ -331,7 +330,7 @@ class AdaptiveNanogrid(Geometry):
331
330
  flip = Nanogrid._get_face_inner_offset(flags)
332
331
  scale = AdaptiveNanogrid._get_face_scale(flags)
333
332
  v1, v2 = Nanogrid._face_tangent_vecs(args.cell_arg.cell_grid, axis, flip)
334
- return _mat32(v1, v2) * scale
333
+ return wp.matrix_from_cols(v1, v2) * scale
335
334
 
336
335
  @wp.func
337
336
  def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
@@ -447,10 +446,10 @@ class AdaptiveNanogrid(Geometry):
447
446
  and wp.max(same_level_cell_coords) <= 1.0
448
447
  )
449
448
 
450
- return wp.select(
449
+ return wp.where(
451
450
  on_side,
452
- Coords(OUTSIDE),
453
451
  Coords(same_level_cell_coords[(axis + 1) % 3], same_level_cell_coords[(axis + 2) % 3], 0.0),
452
+ Coords(OUTSIDE),
454
453
  )
455
454
 
456
455
  def _build_face_grid(self, temporary_store: Optional[cache.TemporaryStore] = None):
@@ -527,7 +526,7 @@ class AdaptiveNanogrid(Geometry):
527
526
  for ax in range(3):
528
527
  coord = ijk[ax]
529
528
  level_flag = ((level >> ax) & 1) << _GRID_LEVEL_BIT
530
- ijk[ax] = wp.select(coord < 0, coord | level_flag, coord & ~level_flag)
529
+ ijk[ax] = wp.where(coord < 0, coord & ~level_flag, coord | level_flag)
531
530
 
532
531
  return _add_axis_flag(ijk, axis)
533
532
 
@@ -845,8 +844,8 @@ def _build_face_indices_and_flags(
845
844
 
846
845
  plus_cell_index, minus_cell_index = _find_face_neighbours(cell_grid, ijk, axis, level_count, cell_level)
847
846
 
848
- inner_cell = wp.select(minus_cell_index == -1, minus_cell_index, plus_cell_index)
849
- outer_cell = wp.select(plus_cell_index == -1, plus_cell_index, minus_cell_index)
847
+ inner_cell = wp.where(minus_cell_index == -1, plus_cell_index, minus_cell_index)
848
+ outer_cell = wp.where(plus_cell_index == -1, minus_cell_index, plus_cell_index)
850
849
 
851
850
  face_level = wp.min(cell_level[inner_cell], cell_level[outer_cell])
852
851
 
@@ -59,7 +59,7 @@ def project_on_tri_at_origin(q: Any, e1: Any, e2: Any):
59
59
 
60
60
  @wp.func
61
61
  def project_on_tet_at_origin(q: wp.vec3, e1: wp.vec3, e2: wp.vec3, e3: wp.vec3):
62
- mat = wp.inverse(wp.mat33(e1, e2, e3))
62
+ mat = wp.inverse(wp.matrix_from_cols(e1, e2, e3))
63
63
  coords = mat * q
64
64
 
65
65
  if wp.min(coords) >= 0.0 and coords[0] + coords[1] + coords[2] <= 1.0:
@@ -46,7 +46,6 @@ class DeformedGeometry(Geometry):
46
46
  self._relative = relative
47
47
 
48
48
  self.field: GeometryField = field
49
- self.base = self.field.geometry
50
49
  self.dimension = self.base.dimension
51
50
 
52
51
  self.CellArg = self.field.ElementEvalArg
@@ -81,9 +80,13 @@ class DeformedGeometry(Geometry):
81
80
  self._make_default_dependent_implementations()
82
81
 
83
82
  @property
84
- def name(self):
83
+ def name(self) -> str:
85
84
  return f"DefGeo_{self.field.name}_{'rel' if self._relative else 'abs'}"
86
85
 
86
+ @property
87
+ def base(self) -> Geometry:
88
+ return self.field.geometry.base
89
+
87
90
  # Geometry device interface
88
91
 
89
92
  @cache.cached_arg_value
@@ -58,6 +58,11 @@ class Geometry:
58
58
  """Manifold dimension of the geometry cells"""
59
59
  return self.reference_cell().dimension
60
60
 
61
+ @property
62
+ def base(self) -> "Geometry":
63
+ """Returns the base geometry from which this geometry derives its topology. Usually `self`"""
64
+ return self
65
+
61
66
  @property
62
67
  def name(self) -> str:
63
68
  return self.__class__.__name__
@@ -173,7 +173,7 @@ class Grid2D(Geometry):
173
173
  return Grid2D.Side(axis, origin)
174
174
 
175
175
  axis_side_index = side_index - 2 * arg.cell_count
176
- axis = wp.select(axis_side_index < arg.axis_offsets[1], 1, 0)
176
+ axis = wp.where(axis_side_index < arg.axis_offsets[1], 0, 1)
177
177
 
178
178
  altitude = arg.cell_arg.res[Grid2D.ROTATION[axis, 0]]
179
179
  longitude = axis_side_index - arg.axis_offsets[axis]
@@ -273,7 +273,7 @@ class Grid2D(Geometry):
273
273
  def side_position(args: SideArg, s: Sample):
274
274
  side = Grid2D.get_side(args, s.element_index)
275
275
 
276
- coord = wp.select((side.origin[0] == 0) == (side.axis == 0), 1.0 - s.element_coords[0], s.element_coords[0])
276
+ coord = wp.where((side.origin[0] == 0) == (side.axis == 0), s.element_coords[0], 1.0 - s.element_coords[0])
277
277
 
278
278
  local_pos = wp.vec2(
279
279
  float(side.origin[0]),
@@ -288,7 +288,7 @@ class Grid2D(Geometry):
288
288
  def side_deformation_gradient(args: SideArg, s: Sample):
289
289
  side = Grid2D.get_side(args, s.element_index)
290
290
 
291
- sign = wp.select((side.origin[0] == 0) == (side.axis == 0), -1.0, 1.0)
291
+ sign = wp.where((side.origin[0] == 0) == (side.axis == 0), 1.0, -1.0)
292
292
 
293
293
  return wp.cw_mul(Grid2D._rotate(side.axis, wp.vec2(0.0, sign)), args.cell_arg.cell_size)
294
294
 
@@ -316,7 +316,7 @@ class Grid2D(Geometry):
316
316
  def side_normal(args: SideArg, s: Sample):
317
317
  side = Grid2D.get_side(args, s.element_index)
318
318
 
319
- sign = wp.select(side.origin[0] == 0, 1.0, -1.0)
319
+ sign = wp.where(side.origin[0] == 0, -1.0, 1.0)
320
320
 
321
321
  local_n = wp.vec2(sign, 0.0)
322
322
  return Grid2D._rotate(side.axis, local_n)
@@ -325,7 +325,7 @@ class Grid2D(Geometry):
325
325
  def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
326
326
  side = Grid2D.get_side(arg, side_index)
327
327
 
328
- inner_alt = wp.select(side.origin[0] == 0, side.origin[0] - 1, 0)
328
+ inner_alt = wp.where(side.origin[0] == 0, 0, side.origin[0] - 1)
329
329
 
330
330
  inner_origin = wp.vec2i(inner_alt, side.origin[1])
331
331
 
@@ -337,8 +337,8 @@ class Grid2D(Geometry):
337
337
  side = Grid2D.get_side(arg, side_index)
338
338
 
339
339
  alt_axis = Grid2D.ROTATION[side.axis, 0]
340
- outer_alt = wp.select(
341
- side.origin[0] == arg.cell_arg.res[alt_axis], side.origin[0], arg.cell_arg.res[alt_axis] - 1
340
+ outer_alt = wp.where(
341
+ side.origin[0] == arg.cell_arg.res[alt_axis], arg.cell_arg.res[alt_axis] - 1, side.origin[0]
342
342
  )
343
343
 
344
344
  outer_origin = wp.vec2i(outer_alt, side.origin[1])
@@ -350,9 +350,9 @@ class Grid2D(Geometry):
350
350
  def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
351
351
  side = Grid2D.get_side(args, side_index)
352
352
 
353
- inner_alt = wp.select(side.origin[0] == 0, 1.0, 0.0)
353
+ inner_alt = wp.where(side.origin[0] == 0, 0.0, 1.0)
354
354
 
355
- side_coord = wp.select((side.origin[0] == 0) == (side.axis == 0), 1.0 - side_coords[0], side_coords[0])
355
+ side_coord = wp.where((side.origin[0] == 0) == (side.axis == 0), side_coords[0], 1.0 - side_coords[0])
356
356
 
357
357
  coords = Grid2D._rotate(side.axis, wp.vec2(inner_alt, side_coord))
358
358
  return Coords(coords[0], coords[1], 0.0)
@@ -362,9 +362,9 @@ class Grid2D(Geometry):
362
362
  side = Grid2D.get_side(args, side_index)
363
363
 
364
364
  alt_axis = Grid2D.ROTATION[side.axis, 0]
365
- outer_alt = wp.select(side.origin[0] == args.cell_arg.res[alt_axis], 0.0, 1.0)
365
+ outer_alt = wp.where(side.origin[0] == args.cell_arg.res[alt_axis], 1.0, 0.0)
366
366
 
367
- side_coord = wp.select((side.origin[0] == 0) == (side.axis == 0), 1.0 - side_coords[0], side_coords[0])
367
+ side_coord = wp.where((side.origin[0] == 0) == (side.axis == 0), side_coords[0], 1.0 - side_coords[0])
368
368
 
369
369
  coords = Grid2D._rotate(side.axis, wp.vec2(outer_alt, side_coord))
370
370
  return Coords(coords[0], coords[1], 0.0)
@@ -382,7 +382,7 @@ class Grid2D(Geometry):
382
382
  if float(side.origin[0] - cell[side.axis]) == element_coords[side.axis]:
383
383
  long_axis = Grid2D.ROTATION[side.axis, 1]
384
384
  axis_coord = element_coords[long_axis]
385
- side_coord = wp.select((side.origin[0] == 0) == (side.axis == 0), 1.0 - axis_coord, axis_coord)
385
+ side_coord = wp.where((side.origin[0] == 0) == (side.axis == 0), axis_coord, 1.0 - axis_coord)
386
386
  return Coords(side_coord, 0.0, 0.0)
387
387
 
388
388
  return Coords(OUTSIDE)