warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -13,16 +13,38 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from enum import IntEnum
16
17
  from typing import List, Tuple
17
18
 
18
19
  import warp as wp
19
- from warp.fem.polynomial import Polynomial, quadrature_1d
20
- from warp.fem.types import Coords
20
+ from warp._src.fem.polynomial import Polynomial, quadrature_1d
21
+ from warp._src.fem.types import Coords
21
22
 
22
- _vec1 = wp.types.vector(length=1, dtype=float)
23
+ _vec1 = wp.vec(length=1, dtype=float)
23
24
 
24
25
 
25
- class Element:
26
+ class Element(IntEnum):
27
+ """Enumeration of reference element types"""
28
+
29
+ LINE_SEGMENT = 1
30
+ SQUARE = 2
31
+ CUBE = 3
32
+ TRIANGLE = 4
33
+ TETRAHEDRON = 5
34
+
35
+ @property
36
+ def prototype(self) -> "PrototypeElement":
37
+ """Prototype element for the given element type"""
38
+ return {
39
+ Element.LINE_SEGMENT: LinearEdge,
40
+ Element.SQUARE: Square,
41
+ Element.CUBE: Cube,
42
+ Element.TRIANGLE: Triangle,
43
+ Element.TETRAHEDRON: Tetrahedron,
44
+ }[self]
45
+
46
+
47
+ class PrototypeElement:
26
48
  dimension = 0
27
49
  """Intrinsic dimension of the element"""
28
50
 
@@ -76,7 +98,7 @@ def _point_count_from_order(order: int, family: Polynomial):
76
98
  return point_count
77
99
 
78
100
 
79
- class Cube(Element):
101
+ class Cube(PrototypeElement):
80
102
  dimension = 3
81
103
 
82
104
  @staticmethod
@@ -97,7 +119,7 @@ class Cube(Element):
97
119
  return coords, weights
98
120
 
99
121
 
100
- class Square(Element):
122
+ class Square(PrototypeElement):
101
123
  dimension = 2
102
124
 
103
125
  @staticmethod
@@ -118,7 +140,7 @@ class Square(Element):
118
140
  return coords, weights
119
141
 
120
142
 
121
- class LinearEdge(Element):
143
+ class LinearEdge(PrototypeElement):
122
144
  dimension = 1
123
145
 
124
146
  @staticmethod
@@ -137,7 +159,7 @@ class LinearEdge(Element):
137
159
  return coords, weights_1d
138
160
 
139
161
 
140
- class Triangle(Element):
162
+ class Triangle(PrototypeElement):
141
163
  dimension = 2
142
164
 
143
165
  @staticmethod
@@ -492,7 +514,7 @@ class Triangle(Element):
492
514
  return Coords(-ref_delta[0] - ref_delta[1], ref_delta[0], ref_delta[1])
493
515
 
494
516
 
495
- class Tetrahedron(Element):
517
+ class Tetrahedron(PrototypeElement):
496
518
  dimension = 3
497
519
 
498
520
  @staticmethod
@@ -827,4 +849,4 @@ class Tetrahedron(Element):
827
849
  c = c - Coords(n, 0.0, n)
828
850
 
829
851
  # project on cube
830
- return Element.project(c)
852
+ return PrototypeElement.project(c)
@@ -17,8 +17,9 @@ from functools import cached_property
17
17
  from typing import Any, ClassVar
18
18
 
19
19
  import warp as wp
20
- from warp.fem import cache
21
- from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, ElementKind, Sample, make_free_sample
20
+ from warp._src.codegen import Struct
21
+ from warp._src.fem import cache
22
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, ElementKind, Sample, make_free_sample
22
23
 
23
24
  from .element import Element
24
25
 
@@ -66,7 +67,7 @@ class Geometry:
66
67
  @property
67
68
  def cell_dimension(self) -> int:
68
69
  """Manifold dimension of the geometry cells"""
69
- return self.reference_cell().dimension
70
+ return self.reference_cell().prototype.dimension
70
71
 
71
72
  @property
72
73
  def base(self) -> "Geometry":
@@ -80,22 +81,27 @@ class Geometry:
80
81
  def __str__(self) -> str:
81
82
  return self.name
82
83
 
83
- CellArg: wp.codegen.Struct
84
+ CellArg: Struct
84
85
  """Structure containing arguments to be passed to device functions evaluating cell-related quantities"""
85
86
 
86
- SideArg: wp.codegen.Struct
87
+ SideArg: Struct
87
88
  """Structure containing arguments to be passed to device functions evaluating side-related quantities"""
88
89
 
89
- SideIndexArg: wp.codegen.Struct
90
+ SideIndexArg: Struct
90
91
  """Structure containing arguments to be passed to device functions for indexing sides"""
91
92
 
93
+ @cache.cached_arg_value
92
94
  def cell_arg_value(self, device) -> "Geometry.CellArg":
93
95
  """Value of the arguments to be passed to cell-related device functions"""
94
- raise NotImplementedError
96
+ args = self.CellArg()
97
+ self.fill_cell_arg(args, device)
98
+ return args
95
99
 
96
100
  def fill_cell_arg(self, args: "Geometry.CellArg", device):
97
101
  """Fill the arguments to be passed to cell-related device functions"""
98
- raise NotImplementedError
102
+ if self.cell_arg_value is __class__.cell_arg_value:
103
+ raise NotImplementedError()
104
+ args.assign(self.cell_arg_value(device))
99
105
 
100
106
  @staticmethod
101
107
  def cell_position(args: "Geometry.CellArg", s: "Sample"):
@@ -130,13 +136,31 @@ class Geometry:
130
136
  For elements with the same dimension as the embedding space, this will be zero."""
131
137
  raise NotImplementedError
132
138
 
139
+ @cache.cached_arg_value
133
140
  def side_arg_value(self, device) -> "Geometry.SideArg":
134
141
  """Value of the arguments to be passed to side-related device functions"""
135
- raise NotImplementedError
142
+ args = self.SideArg()
143
+ self.fill_side_arg(args, device)
144
+ return args
136
145
 
137
146
  def fill_side_arg(self, args: "Geometry.SideArg", device):
138
147
  """Fill the arguments to be passed to side-related device functions"""
139
- raise NotImplementedError
148
+ if self.side_arg_value is __class__.side_arg_value:
149
+ raise NotImplementedError()
150
+ args.assign(self.side_arg_value(device))
151
+
152
+ @cache.cached_arg_value
153
+ def side_index_arg_value(self, device) -> "Geometry.SideIndexArg":
154
+ """Value of the arguments to be passed to side-related device functions"""
155
+ args = self.SideIndexArg()
156
+ self.fill_side_index_arg(args, device)
157
+ return args
158
+
159
+ def fill_side_index_arg(self, args: "Geometry.SideIndexArg", device):
160
+ """Fill the arguments to be passed to side-related device functions"""
161
+ if self.side_index_arg_value is __class__.side_index_arg_value:
162
+ raise NotImplementedError()
163
+ args.assign(self.side_index_arg_value(device))
140
164
 
141
165
  @staticmethod
142
166
  def boundary_side_index(args: "Geometry.SideIndexArg", boundary_side_index: int):
@@ -269,7 +293,7 @@ class Geometry:
269
293
  return wp.normalize(Fcross)
270
294
 
271
295
  def _make_cell_measure(self):
272
- REF_MEASURE = wp.constant(self.reference_cell().measure())
296
+ REF_MEASURE = wp.constant(self.reference_cell().prototype.measure())
273
297
 
274
298
  @cache.dynamic_func(suffix=self.name)
275
299
  def cell_measure(args: self.CellArg, s: Sample):
@@ -279,7 +303,7 @@ class Geometry:
279
303
  return cell_measure
280
304
 
281
305
  def _make_cell_normal(self):
282
- cell_dim = self.reference_cell().dimension
306
+ cell_dim = self.reference_cell().prototype.dimension
283
307
  geo_dim = self.dimension
284
308
  normal_vec = wp.vec(length=geo_dim, dtype=float)
285
309
 
@@ -300,7 +324,7 @@ class Geometry:
300
324
  return None
301
325
 
302
326
  def _make_cell_inverse_deformation_gradient(self):
303
- cell_dim = self.reference_cell().dimension
327
+ cell_dim = self.reference_cell().prototype.dimension
304
328
  geo_dim = self.dimension
305
329
 
306
330
  @cache.dynamic_func(suffix=self.name)
@@ -316,7 +340,7 @@ class Geometry:
316
340
  return cell_inverse_deformation_gradient if cell_dim == geo_dim else cell_pseudoinverse_deformation_gradient
317
341
 
318
342
  def _make_side_inverse_deformation_gradient(self):
319
- side_dim = self.reference_side().dimension
343
+ side_dim = self.reference_side().prototype.dimension
320
344
  geo_dim = self.dimension
321
345
 
322
346
  if side_dim == geo_dim:
@@ -345,7 +369,7 @@ class Geometry:
345
369
  return side_pseudoinverse_deformation_gradient
346
370
 
347
371
  def _make_side_measure(self):
348
- REF_MEASURE = wp.constant(self.reference_side().measure())
372
+ REF_MEASURE = wp.constant(self.reference_side().prototype.measure())
349
373
 
350
374
  @cache.dynamic_func(suffix=self.name)
351
375
  def side_measure(args: self.SideArg, s: Sample):
@@ -370,7 +394,7 @@ class Geometry:
370
394
  return side_measure_ratio
371
395
 
372
396
  def _make_side_normal(self):
373
- side_dim = self.reference_side().dimension
397
+ side_dim = self.reference_side().prototype.dimension
374
398
  geo_dim = self.dimension
375
399
 
376
400
  @cache.dynamic_func(suffix=self.name)
@@ -407,12 +431,12 @@ class Geometry:
407
431
  pos_type = cache.cached_vec_type(self.dimension, dtype=float)
408
432
 
409
433
  if element_kind == ElementKind.CELL:
410
- ref_elt = self.reference_cell()
434
+ ref_elt = self.reference_cell().prototype
411
435
  arg_type = self.CellArg
412
436
  elt_pos = self.cell_position
413
437
  elt_inv_grad = self.cell_inverse_deformation_gradient
414
438
  else:
415
- ref_elt = self.reference_side()
439
+ ref_elt = self.reference_side().prototype
416
440
  arg_type = self.SideArg
417
441
  elt_pos = self.side_position
418
442
  elt_inv_grad = self.side_inverse_deformation_gradient
@@ -452,12 +476,12 @@ class Geometry:
452
476
  element_coordinates = self._make_element_coordinates(element_kind=element_kind, assume_linear=assume_linear)
453
477
 
454
478
  if element_kind == ElementKind.CELL:
455
- ref_elt = self.reference_cell()
479
+ ref_elt = self.reference_cell().prototype
456
480
  arg_type = self.CellArg
457
481
  elt_pos = self.cell_position
458
482
  elt_def_grad = self.cell_deformation_gradient
459
483
  else:
460
- ref_elt = self.reference_side()
484
+ ref_elt = self.reference_side().prototype
461
485
  arg_type = self.SideArg
462
486
  elt_pos = self.side_position
463
487
  elt_def_grad = self.side_deformation_gradient
@@ -636,8 +660,12 @@ class Geometry:
636
660
 
637
661
  if self._bvhs is None:
638
662
  self._bvhs = {}
663
+
639
664
  self._bvhs[device.ordinal] = wp.Bvh(lowers, uppers)
640
665
 
666
+ Geometry.cell_arg_value.invalidate(self, device)
667
+ Geometry.side_arg_value.invalidate(self, device)
668
+
641
669
  def bvh_id(self, device):
642
670
  if self._bvhs is None:
643
671
  return _NULL_BVH_ID
@@ -13,14 +13,15 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from functools import cached_property
16
17
  from typing import Any, Optional
17
18
 
18
19
  import warp as wp
19
- from warp.fem.cache import cached_arg_value, dynamic_func
20
- from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
20
+ from warp._src.fem.cache import cached_arg_value, dynamic_func
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
21
22
 
22
23
  from .closest_point import project_on_box_at_origin
23
- from .element import LinearEdge, Square
24
+ from .element import Element
24
25
  from .geometry import Geometry
25
26
 
26
27
 
@@ -59,7 +60,7 @@ class Grid2D(Geometry):
59
60
 
60
61
  self._res = res
61
62
 
62
- @property
63
+ @cached_property
63
64
  def extents(self) -> wp.vec3:
64
65
  # Avoid using native sub due to higher over of calling builtins from Python
65
66
  return wp.vec2(
@@ -67,7 +68,7 @@ class Grid2D(Geometry):
67
68
  self.bounds_hi[1] - self.bounds_lo[1],
68
69
  )
69
70
 
70
- @property
71
+ @cached_property
71
72
  def cell_size(self) -> wp.vec2:
72
73
  ex = self.extents
73
74
  return wp.vec2(
@@ -87,11 +88,11 @@ class Grid2D(Geometry):
87
88
  def boundary_side_count(self):
88
89
  return 2 * (self.res[0] + self.res[1])
89
90
 
90
- def reference_cell(self) -> Square:
91
- return Square()
91
+ def reference_cell(self) -> Element:
92
+ return Element.SQUARE
92
93
 
93
- def reference_side(self) -> LinearEdge:
94
- return LinearEdge()
94
+ def reference_side(self) -> Element:
95
+ return Element.LINE_SEGMENT
95
96
 
96
97
  @property
97
98
  def res(self):
@@ -101,7 +102,7 @@ class Grid2D(Geometry):
101
102
  def origin(self):
102
103
  return self.bounds_lo
103
104
 
104
- @property
105
+ @cached_property
105
106
  def strides(self):
106
107
  return wp.vec2i(self.res[1], 1)
107
108
 
@@ -183,12 +184,6 @@ class Grid2D(Geometry):
183
184
 
184
185
  # Geometry device interface
185
186
 
186
- @cached_arg_value
187
- def cell_arg_value(self, device) -> CellArg:
188
- args = self.CellArg()
189
- self.fill_cell_arg(args, device)
190
- return args
191
-
192
187
  def fill_cell_arg(self, args: CellArg, device):
193
188
  args.res = self.res
194
189
  args.cell_size = self.cell_size
@@ -303,23 +298,17 @@ class Grid2D(Geometry):
303
298
  @cached_arg_value
304
299
  def side_arg_value(self, device) -> SideArg:
305
300
  args = self.SideArg()
306
- self.fill_side_arg(args, device)
307
- return args
308
-
309
- def fill_side_arg(self, args: SideArg, device):
310
301
  args.axis_offsets = wp.vec2i(
311
302
  0,
312
303
  self.res[1],
313
304
  )
314
305
  args.cell_count = self.cell_count()
315
306
  args.cell_arg = self.cell_arg_value(device)
307
+ return args
316
308
 
317
309
  def side_index_arg_value(self, device) -> SideIndexArg:
318
310
  return self.side_arg_value(device)
319
311
 
320
- def fill_side_index_arg(self, args: SideIndexArg, device):
321
- self.fill_side_arg(args, device)
322
-
323
312
  @wp.func
324
313
  def boundary_side_index(args: SideArg, boundary_side_index: int):
325
314
  """Boundary side to side index"""
@@ -13,14 +13,15 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from functools import cached_property
16
17
  from typing import Any, Optional
17
18
 
18
19
  import warp as wp
19
- from warp.fem.cache import cached_arg_value, dynamic_func
20
- from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
20
+ from warp._src.fem.cache import cached_arg_value, dynamic_func
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
21
22
 
22
23
  from .closest_point import project_on_box_at_origin
23
- from .element import Cube, Square
24
+ from .element import Element
24
25
  from .geometry import Geometry
25
26
 
26
27
 
@@ -56,7 +57,7 @@ class Grid3D(Geometry):
56
57
 
57
58
  self._res = res
58
59
 
59
- @property
60
+ @cached_property
60
61
  def extents(self) -> wp.vec3:
61
62
  # Avoid using native sub due to higher over of calling builtins from Python
62
63
  return wp.vec3(
@@ -65,7 +66,7 @@ class Grid3D(Geometry):
65
66
  self.bounds_hi[2] - self.bounds_lo[2],
66
67
  )
67
68
 
68
- @property
69
+ @cached_property
69
70
  def cell_size(self) -> wp.vec3:
70
71
  ex = self.extents
71
72
  return wp.vec3(
@@ -97,11 +98,11 @@ class Grid3D(Geometry):
97
98
  def boundary_side_count(self):
98
99
  return 2 * (self.res[1]) * (self.res[2]) + (self.res[0]) * 2 * (self.res[2]) + (self.res[0]) * (self.res[1]) * 2
99
100
 
100
- def reference_cell(self) -> Cube:
101
- return Cube()
101
+ def reference_cell(self) -> Element:
102
+ return Element.CUBE
102
103
 
103
- def reference_side(self) -> Square:
104
- return Square()
104
+ def reference_side(self) -> Element:
105
+ return Element.SQUARE
105
106
 
106
107
  @property
107
108
  def res(self):
@@ -111,7 +112,7 @@ class Grid3D(Geometry):
111
112
  def origin(self):
112
113
  return self.bounds_lo
113
114
 
114
- @property
115
+ @cached_property
115
116
  def strides(self):
116
117
  return wp.vec3i(self.res[1] * self.res[2], self.res[2], 1)
117
118
 
@@ -223,12 +224,6 @@ class Grid3D(Geometry):
223
224
 
224
225
  # Geometry device interface
225
226
 
226
- @cached_arg_value
227
- def cell_arg_value(self, device) -> CellArg:
228
- args = self.CellArg()
229
- self.fill_cell_arg(args, device)
230
- return args
231
-
232
227
  def fill_cell_arg(self, args: CellArg, device):
233
228
  args.res = self.res
234
229
  args.origin = self.bounds_lo
@@ -348,10 +343,6 @@ class Grid3D(Geometry):
348
343
  @cached_arg_value
349
344
  def side_arg_value(self, device) -> SideArg:
350
345
  args = self.SideArg()
351
- self.fill_side_arg(args, device)
352
- return args
353
-
354
- def fill_side_arg(self, args: SideArg, device):
355
346
  axis_dims = wp.vec3i(
356
347
  self.res[1] * self.res[2],
357
348
  self.res[2] * self.res[0],
@@ -364,13 +355,11 @@ class Grid3D(Geometry):
364
355
  )
365
356
  args.cell_count = self.cell_count()
366
357
  args.cell_arg = self.cell_arg_value(device)
358
+ return args
367
359
 
368
360
  def side_index_arg_value(self, device) -> SideIndexArg:
369
361
  return self.side_arg_value(device)
370
362
 
371
- def fill_side_index_arg(self, args: SideIndexArg, device):
372
- self.fill_side_arg(args, device)
373
-
374
363
  @wp.func
375
364
  def boundary_side_index(args: SideArg, boundary_side_index: int):
376
365
  """Boundary side to side index"""