warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; further details are available on the package's registry page.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
@@ -83,22 +83,15 @@ class NodalFieldBase(DiscreteField):
83
83
  @cache.dynamic_func(suffix=self.name)
84
84
  def eval_inner(args: self.ElementEvalArg, s: Sample):
85
85
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
86
- res = self.space.space_value(
87
- self._read_node_value(args, s.element_index, 0),
88
- self.space.element_inner_weight(
89
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
90
- ),
91
- local_value_map,
92
- )
93
-
94
86
  node_count = self.space.topology.element_node_count(
95
87
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
96
88
  )
97
- for k in range(1, node_count):
89
+ res = self.space.dtype(0.0)
90
+ for k in range(node_count):
98
91
  res += self.space.space_value(
99
92
  self._read_node_value(args, s.element_index, k),
100
93
  self.space.element_inner_weight(
101
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
94
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
102
95
  ),
103
96
  local_value_map,
104
97
  )
@@ -110,44 +103,43 @@ class NodalFieldBase(DiscreteField):
110
103
  if not self.space.gradient_valid():
111
104
  return None
112
105
 
113
- @cache.dynamic_func(suffix=self.name)
106
+ gradient_dtype = self.gradient_dtype if world_space else self.reference_gradient_dtype
107
+
108
+ @cache.dynamic_func(suffix=f"{self.name}{world_space}")
114
109
  def eval_grad_inner(args: self.ElementEvalArg, s: Sample, grad_transform: Any):
115
110
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
116
-
117
- res = self.space.space_gradient(
118
- self._read_node_value(args, s.element_index, 0),
119
- self.space.element_inner_weight_gradient(
120
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
121
- ),
122
- local_value_map,
123
- grad_transform,
124
- )
125
-
126
111
  node_count = self.space.topology.element_node_count(
127
112
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
128
113
  )
129
- for k in range(1, node_count):
114
+
115
+ res = gradient_dtype(0.0)
116
+ for k in range(node_count):
130
117
  res += self.space.space_gradient(
131
118
  self._read_node_value(args, s.element_index, k),
132
119
  self.space.element_inner_weight_gradient(
133
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
120
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
134
121
  ),
135
122
  local_value_map,
136
123
  grad_transform,
137
124
  )
138
125
  return res
139
126
 
140
- @cache.dynamic_func(suffix=self.name)
141
- def eval_grad_inner_ref_space(args: self.ElementEvalArg, s: Sample):
142
- grad_transform = 1.0
143
- return eval_grad_inner(args, s, grad_transform)
127
+ if world_space:
144
128
 
145
- @cache.dynamic_func(suffix=self.name)
146
- def eval_grad_inner_world_space(args: self.ElementEvalArg, s: Sample):
147
- grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
148
- return eval_grad_inner(args, s, grad_transform)
129
+ @cache.dynamic_func(suffix=self.name)
130
+ def eval_grad_inner_world_space(args: self.ElementEvalArg, s: Sample):
131
+ grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
132
+ return eval_grad_inner(args, s, grad_transform)
133
+
134
+ return eval_grad_inner_world_space
135
+ else:
136
+
137
+ @cache.dynamic_func(suffix=self.name)
138
+ def eval_grad_inner_ref_space(args: self.ElementEvalArg, s: Sample):
139
+ grad_transform = 1.0
140
+ return eval_grad_inner(args, s, grad_transform)
149
141
 
150
- return eval_grad_inner_world_space if world_space else eval_grad_inner_ref_space
142
+ return eval_grad_inner_ref_space
151
143
 
152
144
  def _make_eval_div_inner(self):
153
145
  if not self.divergence_valid():
@@ -157,24 +149,16 @@ class NodalFieldBase(DiscreteField):
157
149
  def eval_div_inner(args: self.ElementEvalArg, s: Sample):
158
150
  grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
159
151
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
160
-
161
- res = self.space.space_divergence(
162
- self._read_node_value(args, s.element_index, 0),
163
- self.space.element_inner_weight_gradient(
164
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
165
- ),
166
- local_value_map,
167
- grad_transform,
168
- )
169
-
170
152
  node_count = self.space.topology.element_node_count(
171
153
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
172
154
  )
173
- for k in range(1, node_count):
155
+
156
+ res = self.divergence_dtype(0.0)
157
+ for k in range(node_count):
174
158
  res += self.space.space_divergence(
175
159
  self._read_node_value(args, s.element_index, k),
176
160
  self.space.element_inner_weight_gradient(
177
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
161
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
178
162
  ),
179
163
  local_value_map,
180
164
  grad_transform,
@@ -187,23 +171,16 @@ class NodalFieldBase(DiscreteField):
187
171
  @cache.dynamic_func(suffix=self.name)
188
172
  def eval_outer(args: self.ElementEvalArg, s: Sample):
189
173
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
190
- res = self.space.space_value(
191
- self._read_node_value(args, s.element_index, 0),
192
- self.space.element_outer_weight(
193
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
194
- ),
195
- local_value_map,
196
- )
197
-
198
174
  node_count = self.space.topology.element_node_count(
199
175
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
200
176
  )
201
177
 
202
- for k in range(1, node_count):
178
+ res = self.dtype(0.0)
179
+ for k in range(node_count):
203
180
  res += self.space.space_value(
204
181
  self._read_node_value(args, s.element_index, k),
205
182
  self.space.element_outer_weight(
206
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
183
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
207
184
  ),
208
185
  local_value_map,
209
186
  )
@@ -215,44 +192,43 @@ class NodalFieldBase(DiscreteField):
215
192
  if not self.space.gradient_valid():
216
193
  return None
217
194
 
218
- @cache.dynamic_func(suffix=self.name)
195
+ gradient_dtype = self.gradient_dtype if world_space else self.reference_gradient_dtype
196
+
197
+ @cache.dynamic_func(suffix=f"{self.name}{world_space}")
219
198
  def eval_grad_outer(args: self.ElementEvalArg, s: Sample, grad_transform: Any):
220
199
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
221
-
222
- res = self.space.space_gradient(
223
- self._read_node_value(args, s.element_index, 0),
224
- self.space.element_outer_weight_gradient(
225
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
226
- ),
227
- local_value_map,
228
- grad_transform,
229
- )
230
-
231
200
  node_count = self.space.topology.element_node_count(
232
201
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
233
202
  )
234
- for k in range(1, node_count):
203
+
204
+ res = gradient_dtype(0.0)
205
+ for k in range(node_count):
235
206
  res += self.space.space_gradient(
236
207
  self._read_node_value(args, s.element_index, k),
237
208
  self.space.element_outer_weight_gradient(
238
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
209
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
239
210
  ),
240
211
  local_value_map,
241
212
  grad_transform,
242
213
  )
243
214
  return res
244
215
 
245
- @cache.dynamic_func(suffix=self.name)
246
- def eval_grad_outer_ref_space(args: self.ElementEvalArg, s: Sample):
247
- grad_transform = 1.0
248
- return eval_grad_outer_ref_space(args, s, grad_transform)
216
+ if world_space:
249
217
 
250
- @cache.dynamic_func(suffix=self.name)
251
- def eval_grad_outer_world_space(args: self.ElementEvalArg, s: Sample):
252
- grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
253
- return eval_grad_outer_ref_space(args, s, grad_transform)
218
+ @cache.dynamic_func(suffix=self.name)
219
+ def eval_grad_outer_world_space(args: self.ElementEvalArg, s: Sample):
220
+ grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
221
+ return eval_grad_outer_ref_space(args, s, grad_transform)
222
+
223
+ return eval_grad_outer_world_space
224
+ else:
225
+
226
+ @cache.dynamic_func(suffix=self.name)
227
+ def eval_grad_outer_ref_space(args: self.ElementEvalArg, s: Sample):
228
+ grad_transform = 1.0
229
+ return eval_grad_outer_ref_space(args, s, grad_transform)
254
230
 
255
- return eval_grad_outer_world_space if world_space else eval_grad_outer_ref_space
231
+ return eval_grad_outer_ref_space
256
232
 
257
233
  def _make_eval_div_outer(self):
258
234
  if not self.divergence_valid():
@@ -262,24 +238,16 @@ class NodalFieldBase(DiscreteField):
262
238
  def eval_div_outer(args: self.ElementEvalArg, s: Sample):
263
239
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
264
240
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
265
-
266
- res = self.space.space_divergence(
267
- self._read_node_value(args, s.element_index, 0),
268
- self.space.element_outer_weight_gradient(
269
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
270
- ),
271
- local_value_map,
272
- grad_transform,
273
- )
274
-
275
241
  node_count = self.space.topology.element_node_count(
276
242
  args.elt_arg, args.eval_arg.topology_arg, s.element_index
277
243
  )
278
- for k in range(1, node_count):
244
+
245
+ res = self.divergence_dtype(0.0)
246
+ for k in range(node_count):
279
247
  res += self.space.space_divergence(
280
248
  self._read_node_value(args, s.element_index, k),
281
249
  self.space.element_outer_weight_gradient(
282
- args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
250
+ args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k, s.qp_index
283
251
  ),
284
252
  local_value_map,
285
253
  grad_transform,
warp/fem/field/virtual.py CHANGED
@@ -69,7 +69,12 @@ class AdjointField(SpaceField):
69
69
  def eval_test_inner(args: self.ElementEvalArg, s: Sample):
70
70
  dof = self._get_dof(s)
71
71
  node_weight = self.space.element_inner_weight(
72
- args.elt_arg, args.eval_arg, s.element_index, s.element_coords, get_node_index_in_element(dof)
72
+ args.elt_arg,
73
+ args.eval_arg,
74
+ s.element_index,
75
+ s.element_coords,
76
+ get_node_index_in_element(dof),
77
+ s.qp_index,
73
78
  )
74
79
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
75
80
  dof_value = self.space.node_basis_element(get_node_coord(dof))
@@ -90,6 +95,7 @@ class AdjointField(SpaceField):
90
95
  s.element_index,
91
96
  s.element_coords,
92
97
  get_node_index_in_element(dof),
98
+ s.qp_index,
93
99
  )
94
100
  grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
95
101
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
@@ -111,6 +117,7 @@ class AdjointField(SpaceField):
111
117
  s.element_index,
112
118
  s.element_coords,
113
119
  get_node_index_in_element(dof),
120
+ s.qp_index,
114
121
  )
115
122
  grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
116
123
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
@@ -124,7 +131,12 @@ class AdjointField(SpaceField):
124
131
  def eval_test_outer(args: self.ElementEvalArg, s: Sample):
125
132
  dof = self._get_dof(s)
126
133
  node_weight = self.space.element_outer_weight(
127
- args.elt_arg, args.eval_arg, s.element_index, s.element_coords, get_node_index_in_element(dof)
134
+ args.elt_arg,
135
+ args.eval_arg,
136
+ s.element_index,
137
+ s.element_coords,
138
+ get_node_index_in_element(dof),
139
+ s.qp_index,
128
140
  )
129
141
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
130
142
  dof_value = self.space.node_basis_element(get_node_coord(dof))
@@ -145,6 +157,7 @@ class AdjointField(SpaceField):
145
157
  s.element_index,
146
158
  s.element_coords,
147
159
  get_node_index_in_element(dof),
160
+ s.qp_index,
148
161
  )
149
162
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
150
163
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
@@ -166,6 +179,7 @@ class AdjointField(SpaceField):
166
179
  s.element_index,
167
180
  s.element_coords,
168
181
  get_node_index_in_element(dof),
182
+ s.qp_index,
169
183
  )
170
184
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
171
185
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
@@ -370,9 +384,8 @@ class LocalAdjointField(SpaceField):
370
384
 
371
385
  @cache.dynamic_func(suffix=str(TAYLOR_DOF_COUNT))
372
386
  def split_dof(dof_index: DofIndex, dof_begin: int):
373
- dof = get_node_coord(dof_index)
374
- value_dof = dof // TAYLOR_DOF_COUNT
375
- taylor_dof = dof - value_dof * TAYLOR_DOF_COUNT - dof_begin
387
+ taylor_dof = get_node_index_in_element(dof_index) - dof_begin
388
+ value_dof = get_node_coord(dof_index)
376
389
  return value_dof, taylor_dof
377
390
 
378
391
  return split_dof
@@ -386,7 +399,7 @@ class LocalAdjointField(SpaceField):
386
399
 
387
400
  local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
388
401
  dof_value = self.space.value_basis_element(value_dof, local_value_map)
389
- return wp.select(taylor_dof == 0, self.dtype(0.0), dof_value)
402
+ return wp.where(taylor_dof == 0, dof_value, self.dtype(0.0))
390
403
 
391
404
  return eval_test_inner
392
405
 
@@ -441,7 +454,7 @@ class LocalAdjointField(SpaceField):
441
454
 
442
455
  local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
443
456
  dof_value = self.space.value_basis_element(value_dof, local_value_map)
444
- return wp.select(taylor_dof == 0, self.dtype(0.0), dof_value)
457
+ return wp.where(taylor_dof == 0, dof_value, self.dtype(0.0))
445
458
 
446
459
  return eval_test_outer
447
460
 
@@ -551,17 +564,25 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
551
564
  qp_index = quadrature.point_index(
552
565
  domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
553
566
  )
567
+ qp_eval_index = quadrature.point_evaluation_index(
568
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
569
+ )
554
570
  coords = quadrature.point_coords(
555
571
  domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
556
572
  )
557
573
 
558
- qp_result = local_result[qp_index]
574
+ qp_result = local_result[qp_eval_index]
559
575
 
560
576
  qp_sum = float(0.0)
561
577
 
562
578
  if wp.static(0 != TEST_INNER_COUNT):
563
579
  w = test.space.element_inner_weight(
564
- domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
580
+ domain_arg,
581
+ test_space_arg,
582
+ element_index,
583
+ coords,
584
+ test_element_index.node_index_in_element,
585
+ qp_index,
565
586
  )
566
587
  for val_dof in range(TEST_NODE_DOF_DIM):
567
588
  test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
@@ -569,7 +590,12 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
569
590
 
570
591
  if wp.static(0 != TEST_OUTER_COUNT):
571
592
  w = test.space.element_outer_weight(
572
- domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
593
+ domain_arg,
594
+ test_space_arg,
595
+ element_index,
596
+ coords,
597
+ test_element_index.node_index_in_element,
598
+ qp_index,
573
599
  )
574
600
  for val_dof in range(TEST_NODE_DOF_DIM):
575
601
  test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
@@ -577,7 +603,12 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
577
603
 
578
604
  if wp.static(0 != TEST_INNER_GRAD_COUNT):
579
605
  w_grad = test.space.element_inner_weight_gradient(
580
- domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
606
+ domain_arg,
607
+ test_space_arg,
608
+ element_index,
609
+ coords,
610
+ test_element_index.node_index_in_element,
611
+ qp_index,
581
612
  )
582
613
  for val_dof in range(TEST_NODE_DOF_DIM):
583
614
  test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
@@ -589,7 +620,12 @@ def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, ac
589
620
 
590
621
  if wp.static(0 != TEST_OUTER_GRAD_COUNT):
591
622
  w_grad = test.space.element_outer_weight_gradient(
592
- domain_arg, test_space_arg, element_index, coords, test_element_index.node_index_in_element
623
+ domain_arg,
624
+ test_space_arg,
625
+ element_index,
626
+ coords,
627
+ test_element_index.node_index_in_element,
628
+ qp_index,
593
629
  )
594
630
  for val_dof in range(TEST_NODE_DOF_DIM):
595
631
  test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
@@ -669,10 +705,10 @@ def make_bilinear_dispatch_kernel(
669
705
  domain_arg, trial_topology_arg, element_index
670
706
  )
671
707
 
672
- qp_point_count = wp.select(
708
+ qp_point_count = wp.where(
673
709
  trial_node < element_trial_node_count,
674
- 0,
675
710
  quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
711
+ 0,
676
712
  )
677
713
 
678
714
  val_sum = accumulate_dtype(0.0)
@@ -681,51 +717,54 @@ def make_bilinear_dispatch_kernel(
681
717
  qp_index = quadrature.point_index(
682
718
  domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
683
719
  )
720
+ qp_eval_index = quadrature.point_evaluation_index(
721
+ domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
722
+ )
684
723
  coords = quadrature.point_coords(
685
724
  domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
686
725
  )
687
726
 
688
- qp_result = local_result[qp_index]
727
+ qp_result = local_result[qp_eval_index]
689
728
  trial_result = float(0.0)
690
729
 
691
730
  if wp.static(0 != TEST_INNER_COUNT):
692
731
  w_test_inner = test.space.element_inner_weight(
693
- domain_arg, test_space_arg, element_index, coords, test_node
732
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
694
733
  )
695
734
 
696
735
  if wp.static(0 != TEST_OUTER_COUNT):
697
736
  w_test_outer = test.space.element_outer_weight(
698
- domain_arg, test_space_arg, element_index, coords, test_node
737
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
699
738
  )
700
739
 
701
740
  if wp.static(0 != TEST_INNER_GRAD_COUNT):
702
741
  w_test_grad_inner = test.space.element_inner_weight_gradient(
703
- domain_arg, test_space_arg, element_index, coords, test_node
742
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
704
743
  )
705
744
 
706
745
  if wp.static(0 != TEST_OUTER_GRAD_COUNT):
707
746
  w_test_grad_outer = test.space.element_outer_weight_gradient(
708
- domain_arg, test_space_arg, element_index, coords, test_node
747
+ domain_arg, test_space_arg, element_index, coords, test_node, qp_index
709
748
  )
710
749
 
711
750
  if wp.static(0 != TRIAL_INNER_COUNT):
712
751
  w_trial_inner = trial.space.element_inner_weight(
713
- domain_arg, trial_space_arg, element_index, coords, trial_node
752
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
714
753
  )
715
754
 
716
755
  if wp.static(0 != TRIAL_OUTER_COUNT):
717
756
  w_trial_outer = trial.space.element_outer_weight(
718
- domain_arg, trial_space_arg, element_index, coords, trial_node
757
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
719
758
  )
720
759
 
721
760
  if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
722
761
  w_trial_grad_inner = trial.space.element_inner_weight_gradient(
723
- domain_arg, trial_space_arg, element_index, coords, trial_node
762
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
724
763
  )
725
764
 
726
765
  if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
727
766
  w_trial_grad_outer = trial.space.element_outer_weight_gradient(
728
- domain_arg, trial_space_arg, element_index, coords, trial_node
767
+ domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
729
768
  )
730
769
 
731
770
  for trial_val_dof in range(TRIAL_NODE_DOF_DIM):
@@ -30,7 +30,6 @@ from .nanogrid import (
30
30
  _extract_axis_flag,
31
31
  _get_boundary_mask,
32
32
  _make_face_flags,
33
- _mat32,
34
33
  )
35
34
 
36
35
  _FACE_LEVEL_BIT = wp.constant(wp.uint8(4)) # follows nanogrid.FACE_OUTER_OFFSET_BIT
@@ -185,9 +184,13 @@ class AdaptiveNanogrid(Geometry):
185
184
 
186
185
  @wp.func
187
186
  def cell_position(args: CellArg, s: Sample):
188
- scale = AdaptiveNanogrid._cell_scale(args, s.element_index)
189
- uvw = wp.vec3(args.cell_ijk[s.element_index]) + s.element_coords * scale
190
- return wp.volume_index_to_world(args.cell_grid, uvw - wp.vec3(0.5))
187
+ cell_idx = s.element_index
188
+ scale = AdaptiveNanogrid._cell_scale(args, cell_idx)
189
+ cell_coords = s.element_coords
190
+ cell_ijk = args.cell_ijk[cell_idx]
191
+ uvw = wp.vec3(cell_ijk) + cell_coords * scale
192
+ grid_id = args.cell_grid
193
+ return wp.volume_index_to_world(grid_id, uvw - wp.vec3(0.5))
191
194
 
192
195
  @wp.func
193
196
  def cell_deformation_gradient(args: CellArg, s: Sample):
@@ -214,9 +217,9 @@ class AdaptiveNanogrid(Geometry):
214
217
  coords = uvw - wp.vec3(ijk)
215
218
 
216
219
  if wp.min(coords) == 0.0 or wp.max(coords) == 1.0:
217
- il = wp.select(coords[0] > 0.5, -1, 0)
218
- jl = wp.select(coords[1] > 0.5, -1, 0)
219
- kl = wp.select(coords[2] > 0.5, -1, 0)
220
+ il = wp.where(coords[0] > 0.5, 0, -1)
221
+ jl = wp.where(coords[1] > 0.5, 0, -1)
222
+ kl = wp.where(coords[2] > 0.5, 0, -1)
220
223
 
221
224
  for n in range(8):
222
225
  ni = n >> 2
@@ -331,7 +334,7 @@ class AdaptiveNanogrid(Geometry):
331
334
  flip = Nanogrid._get_face_inner_offset(flags)
332
335
  scale = AdaptiveNanogrid._get_face_scale(flags)
333
336
  v1, v2 = Nanogrid._face_tangent_vecs(args.cell_arg.cell_grid, axis, flip)
334
- return _mat32(v1, v2) * scale
337
+ return wp.matrix_from_cols(v1, v2) * scale
335
338
 
336
339
  @wp.func
337
340
  def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
@@ -447,10 +450,10 @@ class AdaptiveNanogrid(Geometry):
447
450
  and wp.max(same_level_cell_coords) <= 1.0
448
451
  )
449
452
 
450
- return wp.select(
453
+ return wp.where(
451
454
  on_side,
452
- Coords(OUTSIDE),
453
455
  Coords(same_level_cell_coords[(axis + 1) % 3], same_level_cell_coords[(axis + 2) % 3], 0.0),
456
+ Coords(OUTSIDE),
454
457
  )
455
458
 
456
459
  def _build_face_grid(self, temporary_store: Optional[cache.TemporaryStore] = None):
@@ -527,7 +530,7 @@ class AdaptiveNanogrid(Geometry):
527
530
  for ax in range(3):
528
531
  coord = ijk[ax]
529
532
  level_flag = ((level >> ax) & 1) << _GRID_LEVEL_BIT
530
- ijk[ax] = wp.select(coord < 0, coord | level_flag, coord & ~level_flag)
533
+ ijk[ax] = wp.where(coord < 0, coord & ~level_flag, coord | level_flag)
531
534
 
532
535
  return _add_axis_flag(ijk, axis)
533
536
 
@@ -845,8 +848,8 @@ def _build_face_indices_and_flags(
845
848
 
846
849
  plus_cell_index, minus_cell_index = _find_face_neighbours(cell_grid, ijk, axis, level_count, cell_level)
847
850
 
848
- inner_cell = wp.select(minus_cell_index == -1, minus_cell_index, plus_cell_index)
849
- outer_cell = wp.select(plus_cell_index == -1, plus_cell_index, minus_cell_index)
851
+ inner_cell = wp.where(minus_cell_index == -1, plus_cell_index, minus_cell_index)
852
+ outer_cell = wp.where(plus_cell_index == -1, minus_cell_index, plus_cell_index)
850
853
 
851
854
  face_level = wp.min(cell_level[inner_cell], cell_level[outer_cell])
852
855
 
@@ -59,7 +59,7 @@ def project_on_tri_at_origin(q: Any, e1: Any, e2: Any):
59
59
 
60
60
  @wp.func
61
61
  def project_on_tet_at_origin(q: wp.vec3, e1: wp.vec3, e2: wp.vec3, e3: wp.vec3):
62
- mat = wp.inverse(wp.mat33(e1, e2, e3))
62
+ mat = wp.inverse(wp.matrix_from_cols(e1, e2, e3))
63
63
  coords = mat * q
64
64
 
65
65
  if wp.min(coords) >= 0.0 and coords[0] + coords[1] + coords[2] <= 1.0:
@@ -46,7 +46,6 @@ class DeformedGeometry(Geometry):
46
46
  self._relative = relative
47
47
 
48
48
  self.field: GeometryField = field
49
- self.base = self.field.geometry
50
49
  self.dimension = self.base.dimension
51
50
 
52
51
  self.CellArg = self.field.ElementEvalArg
@@ -81,9 +80,13 @@ class DeformedGeometry(Geometry):
81
80
  self._make_default_dependent_implementations()
82
81
 
83
82
  @property
84
- def name(self):
83
+ def name(self) -> str:
85
84
  return f"DefGeo_{self.field.name}_{'rel' if self._relative else 'abs'}"
86
85
 
86
+ @property
87
+ def base(self) -> Geometry:
88
+ return self.field.geometry.base
89
+
87
90
  # Geometry device interface
88
91
 
89
92
  @cache.cached_arg_value
@@ -58,6 +58,11 @@ class Geometry:
58
58
  """Manifold dimension of the geometry cells"""
59
59
  return self.reference_cell().dimension
60
60
 
61
+ @property
62
+ def base(self) -> "Geometry":
63
+ """Returns the base geometry from which this geometry derives its topology. Usually `self`"""
64
+ return self
65
+
61
66
  @property
62
67
  def name(self) -> str:
63
68
  return self.__class__.__name__