warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's release page for more details.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,25 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any, Optional
2
17
 
3
18
  import warp as wp
4
19
  from warp.fem import cache
5
20
  from warp.fem.geometry import Geometry
6
21
  from warp.fem.linalg import basis_element, generalized_inner, generalized_outer
7
- from warp.fem.types import Coords, ElementIndex, make_free_sample
22
+ from warp.fem.types import NULL_QP_INDEX, Coords, ElementIndex, make_free_sample
8
23
 
9
24
  from .basis_space import BasisSpace
10
25
  from .dof_mapper import DofMapper, IdentityMapper
@@ -290,7 +305,9 @@ class VectorValuedFunctionSpace(FunctionSpace):
290
305
  space_value: self.dtype,
291
306
  ):
292
307
  coords = self.node_coords_in_element(elt_arg, space_arg, element_index, node_index_in_elt)
293
- weight = self.element_inner_weight(elt_arg, space_arg, element_index, coords, node_index_in_elt)
308
+ weight = self.element_inner_weight(
309
+ elt_arg, space_arg, element_index, coords, node_index_in_elt, NULL_QP_INDEX
310
+ )
294
311
  local_value_map = self.local_value_map_inner(elt_arg, element_index, coords)
295
312
 
296
313
  unit_value = local_value_map * weight
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Optional
2
17
 
3
18
  import numpy as np
@@ -6,7 +21,14 @@ import warp as wp
6
21
  from warp.fem import cache
7
22
  from warp.fem.geometry import Geometry
8
23
  from warp.fem.quadrature import Quadrature
9
- from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, make_free_sample
24
+ from warp.fem.types import (
25
+ NULL_ELEMENT_INDEX,
26
+ NULL_QP_INDEX,
27
+ Coords,
28
+ ElementIndex,
29
+ QuadraturePointIndex,
30
+ make_free_sample,
31
+ )
10
32
 
11
33
  from .shape import ShapeFunction
12
34
  from .topology import RegularDiscontinuousSpaceTopology, SpaceTopology
@@ -220,6 +242,7 @@ class ShapeBasisSpace(BasisSpace):
220
242
  element_index: ElementIndex,
221
243
  coords: Coords,
222
244
  node_index_in_elt: int,
245
+ qp_index: QuadraturePointIndex,
223
246
  ):
224
247
  if wp.static(self.value == ShapeFunction.Value.Scalar):
225
248
  return shape_element_inner_weight(coords, node_index_in_elt)
@@ -239,6 +262,7 @@ class ShapeBasisSpace(BasisSpace):
239
262
  element_index: ElementIndex,
240
263
  coords: Coords,
241
264
  node_index_in_elt: int,
265
+ qp_index: QuadraturePointIndex,
242
266
  ):
243
267
  if wp.static(self.value == ShapeFunction.Value.Scalar):
244
268
  return shape_element_inner_weight_gradient(coords, node_index_in_elt)
@@ -358,6 +382,7 @@ class TraceBasisSpace(BasisSpace):
358
382
  element_index: ElementIndex,
359
383
  coords: Coords,
360
384
  node_index_in_elt: int,
385
+ qp_index: QuadraturePointIndex,
361
386
  ):
362
387
  cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
363
388
  if cell_index == NULL_ELEMENT_INDEX:
@@ -366,13 +391,7 @@ class TraceBasisSpace(BasisSpace):
366
391
  cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
367
392
 
368
393
  geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
369
- return cell_inner_weight(
370
- geo_cell_arg,
371
- basis_arg,
372
- cell_index,
373
- cell_coords,
374
- index_in_cell,
375
- )
394
+ return cell_inner_weight(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX)
376
395
 
377
396
  return trace_element_inner_weight
378
397
 
@@ -386,6 +405,7 @@ class TraceBasisSpace(BasisSpace):
386
405
  element_index: ElementIndex,
387
406
  coords: Coords,
388
407
  node_index_in_elt: int,
408
+ qp_index: QuadraturePointIndex,
389
409
  ):
390
410
  cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
391
411
  if cell_index == NULL_ELEMENT_INDEX:
@@ -394,13 +414,7 @@ class TraceBasisSpace(BasisSpace):
394
414
  cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
395
415
 
396
416
  geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
397
- return cell_outer_weight(
398
- geo_cell_arg,
399
- basis_arg,
400
- cell_index,
401
- cell_coords,
402
- index_in_cell,
403
- )
417
+ return cell_outer_weight(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX)
404
418
 
405
419
  return trace_element_outer_weight
406
420
 
@@ -414,6 +428,7 @@ class TraceBasisSpace(BasisSpace):
414
428
  element_index: ElementIndex,
415
429
  coords: Coords,
416
430
  node_index_in_elt: int,
431
+ qp_index: QuadraturePointIndex,
417
432
  ):
418
433
  cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
419
434
  if cell_index == NULL_ELEMENT_INDEX:
@@ -421,7 +436,9 @@ class TraceBasisSpace(BasisSpace):
421
436
 
422
437
  cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
423
438
  geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
424
- return cell_inner_weight_gradient(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell)
439
+ return cell_inner_weight_gradient(
440
+ geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX
441
+ )
425
442
 
426
443
  return trace_element_inner_weight_gradient
427
444
 
@@ -435,6 +452,7 @@ class TraceBasisSpace(BasisSpace):
435
452
  element_index: ElementIndex,
436
453
  coords: Coords,
437
454
  node_index_in_elt: int,
455
+ qp_index: QuadraturePointIndex,
438
456
  ):
439
457
  cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
440
458
  if cell_index == NULL_ELEMENT_INDEX:
@@ -442,7 +460,9 @@ class TraceBasisSpace(BasisSpace):
442
460
 
443
461
  cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
444
462
  geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
445
- return cell_outer_weight_gradient(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell)
463
+ return cell_outer_weight_gradient(
464
+ geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell, NULL_QP_INDEX
465
+ )
446
466
 
447
467
  return trace_element_outer_weight_gradient
448
468
 
@@ -609,11 +629,12 @@ class PointBasisSpace(BasisSpace):
609
629
  element_index: ElementIndex,
610
630
  coords: Coords,
611
631
  node_index_in_elt: int,
632
+ qp_index: QuadraturePointIndex,
612
633
  ):
613
634
  qp_coord = self._quadrature.point_coords(
614
635
  elt_arg, basis_arg, element_index, element_index, node_index_in_elt
615
636
  )
616
- return wp.select(wp.length_sq(coords - qp_coord) < _DIRAC_INTEGRATION_RADIUS, 0.0, 1.0)
637
+ return wp.where(wp.length_sq(coords - qp_coord) < _DIRAC_INTEGRATION_RADIUS, 1.0, 0.0)
617
638
 
618
639
  return element_inner_weight
619
640
 
@@ -627,6 +648,7 @@ class PointBasisSpace(BasisSpace):
627
648
  element_index: ElementIndex,
628
649
  coords: Coords,
629
650
  node_index_in_elt: int,
651
+ qp_index: QuadraturePointIndex,
630
652
  ):
631
653
  return gradient_vec(0.0)
632
654
 
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
  from enum import Enum
3
18
  from typing import Any
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
3
18
  import warp as wp
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import numpy as np
2
17
 
3
18
  import warp as wp
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import numpy as np
2
17
 
3
18
  import warp as wp
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import warp as wp
2
17
  from warp.fem import cache
3
18
  from warp.fem.geometry import Hexmesh
@@ -222,13 +237,13 @@ class HexmeshSpaceTopology(SpaceTopology):
222
237
  hex_edge = _CUBE_TO_HEX_EDGE[type_instance]
223
238
  v0 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 0]]
224
239
  v1 = geo_arg.hex_vertex_indices[element_index, EDGE_VERTEX_INDICES[hex_edge, 1]]
225
- return wp.select(v0 > v1, 1.0, -1.0)
240
+ return wp.where(v0 > v1, -1.0, 1.0)
226
241
 
227
242
  if wp.static(FACE_NODE_COUNT > 0):
228
243
  if node_type == CubeShapeFunction.FACE:
229
244
  face_index_and_ori = topo_arg.hex_face_indices[element_index, type_instance]
230
245
  flip = face_index_and_ori[1] & 1
231
- return wp.select(flip == 0, -1.0, 1.0)
246
+ return wp.where(flip == 0, 1.0, -1.0)
232
247
 
233
248
  return 1.0
234
249
 
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Union
2
17
 
3
18
  import warp as wp
@@ -1,10 +1,25 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any, Optional
2
17
 
3
18
  import warp as wp
4
19
  import warp.fem.cache as cache
5
20
  from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
6
21
  from warp.fem.types import NULL_NODE_INDEX
7
- from warp.fem.utils import _iota_kernel, compress_node_indices
22
+ from warp.fem.utils import compress_node_indices
8
23
 
9
24
  from .function_space import FunctionSpace
10
25
  from .topology import SpaceTopology
@@ -72,7 +87,7 @@ class WholeSpacePartition(SpacePartition):
72
87
  """Return the global function space indices for nodes in this partition"""
73
88
  if self._node_indices is None:
74
89
  self._node_indices = cache.borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
75
- wp.launch(kernel=_iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array, 1])
90
+ wp.launch(kernel=self._iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array])
76
91
  return self._node_indices.array
77
92
 
78
93
  def partition_arg_value(self, device):
@@ -89,6 +104,10 @@ class WholeSpacePartition(SpacePartition):
89
104
  def name(self) -> str:
90
105
  return "Whole"
91
106
 
107
+ @wp.kernel
108
+ def _iota_kernel(indices: wp.array(dtype=int)):
109
+ indices[wp.tid()] = wp.tid()
110
+
92
111
 
93
112
  class NodeCategory:
94
113
  OWNED_INTERIOR = wp.constant(0)
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import warp as wp
2
17
  from warp.fem import cache
3
18
  from warp.fem.geometry import Quadmesh2D
@@ -151,10 +166,10 @@ class QuadmeshSpaceTopology(SpaceTopology):
151
166
 
152
167
  if wp.static(EDGE_NODE_COUNT > 0):
153
168
  # EDGE_X, EDGE_Y
154
- side_start = wp.select(
169
+ side_start = wp.where(
155
170
  node_type == SquareShapeFunction.EDGE_X,
156
- wp.select(type_instance == 0, 1, 3),
157
- wp.select(type_instance == 0, 2, 0),
171
+ wp.where(type_instance == 0, 0, 2),
172
+ wp.where(type_instance == 0, 3, 1),
158
173
  )
159
174
 
160
175
  side_index = topo_arg.quad_edge_indices[element_index, side_start]
@@ -163,7 +178,7 @@ class QuadmeshSpaceTopology(SpaceTopology):
163
178
 
164
179
  # Flip indexing direction
165
180
  flipped = int(side_start >= 2) ^ int(local_vs != global_vs)
166
- index_in_side = wp.select(flipped, type_index, EDGE_NODE_COUNT - 1 - type_index)
181
+ index_in_side = wp.where(flipped, EDGE_NODE_COUNT - 1 - type_index, type_index)
167
182
 
168
183
  return global_offset + EDGE_NODE_COUNT * side_index + index_in_side
169
184
 
@@ -182,10 +197,10 @@ class QuadmeshSpaceTopology(SpaceTopology):
182
197
  node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
183
198
 
184
199
  if node_type == SquareShapeFunction.EDGE_X or node_type == SquareShapeFunction.EDGE_Y:
185
- side_start = wp.select(
200
+ side_start = wp.where(
186
201
  node_type == SquareShapeFunction.EDGE_X,
187
- wp.select(type_instance == 0, 1, 3),
188
- wp.select(type_instance == 0, 2, 0),
202
+ wp.where(type_instance == 0, 0, 2),
203
+ wp.where(type_instance == 0, 3, 1),
189
204
  )
190
205
 
191
206
  side_index = topo_arg.quad_edge_indices[element_index, side_start]
@@ -194,7 +209,7 @@ class QuadmeshSpaceTopology(SpaceTopology):
194
209
 
195
210
  # Flip indexing direction
196
211
  flipped = int(side_start >= 2) ^ int(local_vs != global_vs)
197
- return wp.select(flipped, 1.0, -1.0)
212
+ return wp.where(flipped, -1.0, 1.0)
198
213
 
199
214
  return 1.0
200
215
 
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import warp as wp
2
17
  from warp.fem.cache import TemporaryStore, borrow_temporary, borrow_temporary_like, cached_arg_value
3
18
  from warp.fem.domain import GeometryDomain
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from enum import Enum
2
17
  from typing import Optional
3
18
 
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
 
3
18
  import numpy as np
@@ -48,7 +63,7 @@ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
48
63
 
49
64
  self.ORDER = wp.constant(degree)
50
65
  self.NODES_PER_ELEMENT = wp.constant((degree + 1) ** 3)
51
- self.NODES_PER_EDGE = wp.constant(degree + 1)
66
+ self.NODES_PER_SIDE = wp.constant((degree + 1) ** 2)
52
67
 
53
68
  if is_closed(self.family):
54
69
  self.VERTEX_NODE_COUNT = wp.constant(1)
@@ -137,13 +152,13 @@ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
137
152
  ):
138
153
  i, j, k = self._node_ijk(node_index_in_elt)
139
154
 
140
- zi = wp.select(i == 0, 0, 1)
141
- zj = wp.select(j == 0, 0, 1)
142
- zk = wp.select(k == 0, 0, 1)
155
+ zi = wp.where(i == 0, 1, 0)
156
+ zj = wp.where(j == 0, 1, 0)
157
+ zk = wp.where(k == 0, 1, 0)
143
158
 
144
- mi = wp.select(i == ORDER, 0, 1)
145
- mj = wp.select(j == ORDER, 0, 1)
146
- mk = wp.select(k == ORDER, 0, 1)
159
+ mi = wp.where(i == ORDER, 1, 0)
160
+ mj = wp.where(j == ORDER, 1, 0)
161
+ mk = wp.where(k == ORDER, 1, 0)
147
162
 
148
163
  if zi + mi == 1:
149
164
  if zj + mj == 1:
@@ -489,7 +504,7 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
489
504
 
490
505
  self.ORDER = wp.constant(degree)
491
506
  self.NODES_PER_ELEMENT = wp.constant(8 + 12 * (degree - 1))
492
- self.NODES_PER_EDGE = wp.constant(degree + 1)
507
+ self.NODES_PER_SIDE = wp.constant(4 * degree)
493
508
 
494
509
  self.VERTEX_NODE_COUNT = wp.constant(1)
495
510
  self.EDGE_NODE_COUNT = wp.constant(degree - 1)
@@ -619,9 +634,9 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
619
634
  if node_type == CubeSerendipityShapeFunctions.VERTEX:
620
635
  node_ijk = CubeSerendipityShapeFunctions._vertex_coords(type_instance)
621
636
 
622
- cx = wp.select(node_ijk[0] == 0, coords[0], 1.0 - coords[0])
623
- cy = wp.select(node_ijk[1] == 0, coords[1], 1.0 - coords[1])
624
- cz = wp.select(node_ijk[2] == 0, coords[2], 1.0 - coords[2])
637
+ cx = wp.where(node_ijk[0] == 0, 1.0 - coords[0], coords[0])
638
+ cy = wp.where(node_ijk[1] == 0, 1.0 - coords[1], coords[1])
639
+ cz = wp.where(node_ijk[2] == 0, 1.0 - coords[2], coords[2])
625
640
 
626
641
  w = cx * cy * cz
627
642
 
@@ -644,8 +659,8 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
644
659
  local_coords = Grid3D._world_to_local(axis, coords)
645
660
 
646
661
  w = float(1.0)
647
- w *= wp.select(node_all[1] == 0, local_coords[1], 1.0 - local_coords[1])
648
- w *= wp.select(node_all[2] == 0, local_coords[2], 1.0 - local_coords[2])
662
+ w *= wp.where(node_all[1] == 0, 1.0 - local_coords[1], local_coords[1])
663
+ w *= wp.where(node_all[2] == 0, 1.0 - local_coords[2], local_coords[2])
649
664
 
650
665
  for k in range(ORDER_PLUS_ONE):
651
666
  if k != node_all[0]:
@@ -675,13 +690,13 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
675
690
  if node_type == CubeSerendipityShapeFunctions.VERTEX:
676
691
  node_ijk = CubeSerendipityShapeFunctions._vertex_coords(type_instance)
677
692
 
678
- cx = wp.select(node_ijk[0] == 0, coords[0], 1.0 - coords[0])
679
- cy = wp.select(node_ijk[1] == 0, coords[1], 1.0 - coords[1])
680
- cz = wp.select(node_ijk[2] == 0, coords[2], 1.0 - coords[2])
693
+ cx = wp.where(node_ijk[0] == 0, 1.0 - coords[0], coords[0])
694
+ cy = wp.where(node_ijk[1] == 0, 1.0 - coords[1], coords[1])
695
+ cz = wp.where(node_ijk[2] == 0, 1.0 - coords[2], coords[2])
681
696
 
682
- gx = wp.select(node_ijk[0] == 0, 1.0, -1.0)
683
- gy = wp.select(node_ijk[1] == 0, 1.0, -1.0)
684
- gz = wp.select(node_ijk[2] == 0, 1.0, -1.0)
697
+ gx = wp.where(node_ijk[0] == 0, -1.0, 1.0)
698
+ gy = wp.where(node_ijk[1] == 0, -1.0, 1.0)
699
+ gz = wp.where(node_ijk[2] == 0, -1.0, 1.0)
685
700
 
686
701
  if wp.static(ORDER == 2):
687
702
  w = cx + cy + cz - 3.0 + LOBATTO_COORDS[1]
@@ -713,11 +728,11 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
713
728
 
714
729
  local_coords = Grid3D._world_to_local(axis, coords)
715
730
 
716
- w_long = wp.select(node_all[1] == 0, local_coords[1], 1.0 - local_coords[1])
717
- w_lat = wp.select(node_all[2] == 0, local_coords[2], 1.0 - local_coords[2])
731
+ w_long = wp.where(node_all[1] == 0, 1.0 - local_coords[1], local_coords[1])
732
+ w_lat = wp.where(node_all[2] == 0, 1.0 - local_coords[2], local_coords[2])
718
733
 
719
- g_long = wp.select(node_all[1] == 0, 1.0, -1.0)
720
- g_lat = wp.select(node_all[2] == 0, 1.0, -1.0)
734
+ g_long = wp.where(node_all[1] == 0, -1.0, 1.0)
735
+ g_lat = wp.where(node_all[2] == 0, -1.0, 1.0)
721
736
 
722
737
  w_alt = LAGRANGE_SCALE[node_all[0]]
723
738
  g_alt = float(0.0)
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from enum import Enum
2
17
 
3
18
  import numpy as np